#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/cuda/CuFFTUtils.h>
#include <ATen/native/cuda/CuFFTPlanCache.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#include <ATen/ops/_fft_c2r_native.h>
#include <ATen/ops/_fft_r2c_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/mul.h>
#endif

#include <cufft.h>
#include <cufftXt.h>

#include <cmath>


namespace at::native {

using namespace at::native::detail;

// Execute a pre-planned transform
static void exec_cufft_plan(
    const CuFFTConfig &config, void* in_data, void* out_data, bool forward) {
  auto& plan = config.plan();
  CUFFT_CHECK(cufftXtExec(plan, in_data, out_data,
                          forward ? CUFFT_FORWARD : CUFFT_INVERSE));
}

// NOTE [ cuFFT Embedded Strides ]
//
// cuFFT supports a subset of arbitrary strides via its "advanced data layout"
// option (http://docs.nvidia.com/cuda/cufft/index.html#advanced-data-layout).
// Specifically, these are tensors that can be viewed as subtensors resulting
// from slicing a larger contiguous tensor. For such input tensors, let the
// sizes of the enclosing tensor be `inembed`; then, in the 3d case,
//
//     input[x, y, z] = input[((x * inembed[1] + y) * inembed[2] + z)]
//
// The above is the simplified formula ignoring the batch dimension. In fact,
// the last dimension of the enclosing tensor doesn't have to be contiguous,
// i.e., its stride can be greater than 1. One can then set the base stride of
// the enclosing tensor with `istride`, giving
//
//     input[x, y, z] = input[((x * inembed[1] + y) * inembed[2] + z) * istride]
//
// For example, consider
//
//     enclosing = torch.zeros(6, 8, 10)  # contiguous
//     input = enclosing[:4, 2:6, 6:]
//     input.size()                       # [ 4,  4,  4]
//     input.stride()                     # [80, 10,  1]
//     # inembed = [6, 8, 10]
//     input[2, 1, 3] = input[((2 * 8) + 1) * 10 + 3]  # using above formula
//                    = input[173]
//                    = input[2 * 80 + 1 * 10 + 3 * 1]  # using strides directly
//
// Generally, the embedded strides can be computed as
//
//     embed[i] = stride[i - 1] / stride[i].
//
// Note that the value of embed[0] isn't used to compute indices and doesn't
// matter.
//
// In contrast to the advanced data layout, "simple layout" means that the
// embeds have unit strides, i.e., the input and output tensors are contiguous
// and the stride at the innermost signal dimension is 1 w.r.t. the
// corresponding data type.
// The cuFFT plan cache
// unique_ptr for nullability and to avoid reference invalidation on vector resize
static std::vector<std::unique_ptr<CuFFTParamsLRUCache>> plan_caches;
static std::mutex plan_caches_mutex;

static inline
CuFFTParamsLRUCache &cufft_get_plan_cache(DeviceIndex device_index) {
  std::lock_guard<std::mutex> guard(plan_caches_mutex);

  AT_ASSERT(device_index >= 0);

  if (device_index >= static_cast<int64_t>(plan_caches.size())) {
    plan_caches.resize(device_index + 1);
  }

  if (!plan_caches[device_index]) {
    plan_caches[device_index] = std::make_unique<CuFFTParamsLRUCache>();
  }

  return *plan_caches[device_index];
}

namespace detail {

int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_get_plan_cache_max_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), ", but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).max_size();
}

void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_set_plan_cache_max_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), ", but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).resize(max_size);
}

int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_get_plan_cache_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), ", but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).size();
}

void cufft_clear_plan_cache_impl(DeviceIndex device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_clear_plan_cache: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), ", but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).clear();
}

} // namespace at::native::detail

namespace {
constexpr int64_t cufft_max_ndim = 3;

// "Large" here means a prime factor not special-cased by cuFFT
// Ref: https://docs.nvidia.com/cuda/cufft/index.html#accuracy-and-performance
bool has_large_prime_factor(int64_t n) {
  constexpr int64_t first_large_prime = 11;
  const std::array<int64_t, 4> prime_radices{{2, 3, 5, 7}};
  for (auto prime : prime_radices) {
    if (n < first_large_prime) {
      return false;
    }

    while (n % prime == 0) {
      n /= prime;
    }
  }
  return n != 1;
}
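// Illustrative sketch of the helper above (values traced by hand, not used below):
//   has_large_prime_factor(16) == false   // 16 = 2^4, fully covered by the radices
//   has_large_prime_factor(30) == false   // 30 = 2 * 3 * 5
//   has_large_prime_factor(22) == true    // 22 = 2 * 11, and 11 is a "large" prime
//   has_large_prime_factor(7)  == false   // below first_large_prime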

// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
static const Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes,
                               IntArrayRef dim, bool forward) {
  const auto ndim = self.dim();
  const int64_t signal_ndim = dim.size();
  const auto batch_dims = ndim - signal_ndim;

  // Permute dimensions so batch dimensions come first, and in stride order
  // This maximizes data locality when collapsing to a single batch dimension
  DimVector dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});

  c10::SmallVector<bool, kDimVectorStaticSize> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
                                  [&](int64_t d) { return !is_transformed_dim[d]; });
  auto self_strides = self.strides();
  std::sort(dim_permute.begin(), batch_end,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  std::copy(dim.cbegin(), dim.cend(), batch_end);
  auto input = self.permute(dim_permute);

  // Collapse batch dimensions into a single dimension
  DimVector batched_sizes(signal_ndim + 1);
  batched_sizes[0] = -1;
  std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1);
  input = input.reshape(batched_sizes);

  const auto batch_size = input.sizes()[0];
  DimVector signal_size(signal_ndim + 1);
  signal_size[0] = batch_size;
  for (const auto i : c10::irange(signal_ndim)) {
    auto in_size = input.sizes()[i + 1];
    auto out_size = out_sizes[dim[i]];
    signal_size[i + 1] = std::max(in_size, out_size);
    TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] ||
                          in_size == (signal_size[i + 1] / 2) + 1);
    TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] ||
                          out_size == (signal_size[i + 1] / 2) + 1);
  }

  batched_sizes[0] = batch_size;
  DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
  for (const auto i : c10::irange(dim.size())) {
    batched_out_sizes[i + 1] = out_sizes[dim[i]];
  }
  out.resize_(batched_out_sizes, MemoryFormat::Contiguous);

  // Create the transform plan (either from cache or locally)
  const auto value_type = c10::toRealValueType(input.scalar_type());
  auto fft_type = GetCuFFTTransformType(input.is_complex(), out.is_complex());
  CuFFTParams Params(input.strides(), out.strides(), signal_size, fft_type, value_type);
  CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(input.device().index());
  std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
  std::optional<CuFFTConfig> uncached_plan;
  const CuFFTConfig * config = nullptr;

  // Workaround for gh-63152, gh-58724
  // Bluestein plans in CUDA 11.1 (cuFFT 10.3) cannot be re-used.
  // Bluestein's algorithm is only used when a size has large prime factors;
  // sizes with only small prime factors can still be cached.
  bool use_caching = true;
#ifdef CUFFT_VERSION
  if constexpr (10300 <= CUFFT_VERSION && CUFFT_VERSION < 10400) {
    // Only cache plans for transforms with small prime factors
    use_caching = std::none_of(
        signal_size.begin() + 1, signal_size.end(), [](int64_t dim_size) {
      return has_large_prime_factor(dim_size);
    });
  }
#endif

  if (use_caching && plan_cache.max_size() > 0) {
    guard.lock();
    if (plan_cache.max_size() > 0) {  // check again after acquiring the lock
      config = &plan_cache.lookup(Params);
    }
  }

  if (config == nullptr) {
    uncached_plan.emplace(Params);
    config = &uncached_plan.value();
  }

  auto& plan = config->plan();

  if (config->should_clone_input()) {
    input = input.clone(MemoryFormat::Contiguous);
  }

  // prepare cufft for execution
  CUFFT_CHECK(cufftSetStream(plan, at::cuda::getCurrentCUDAStream()));
  auto workspace = at::empty({ config->workspace_size() }, at::device(at::kCUDA).dtype(at::kByte));
  CUFFT_CHECK(cufftSetWorkArea(plan, workspace.mutable_data_ptr()));

  // execute transform plan
#if !defined(USE_ROCM)
  CUcontext pctx = nullptr;
  at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
  if (C10_UNLIKELY(!pctx)) {
    // workaround for corner case where a primary context exists but is not
    // the current context
    TORCH_WARN_ONCE("Attempting to run cuFFT, but there was no current CUDA context! Attempting to set the primary context...");
    at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, 0);
    at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
  }
#endif /* !defined(USE_ROCM) */
  exec_cufft_plan(*config, const_cast<void*>(input.const_data_ptr()), out.data_ptr(), forward);

  // Inplace reshaping to original batch shape and inverting the dimension permutation
  DimVector out_strides(ndim);
  int64_t batch_numel = 1;
  for (int64_t i = batch_dims - 1; i >= 0; --i) {
    out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
    batch_numel *= out_sizes[dim_permute[i]];
  }
  for (const auto i : c10::irange(batch_dims, ndim)) {
    out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
  }
  return out.as_strided_(out_sizes, out_strides, out.storage_offset());
}

// Calculates the normalization constant (applied to self in-place by
// _fft_apply_normalization below).
// sizes are the sizes of the twosided tensor and dims are all transformed dims
double _fft_normalization_scale(int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto norm = static_cast<fft_norm_mode>(normalization);
  if (norm == fft_norm_mode::none) {
    return 1.0;
  }

  int64_t signal_numel = 1;
  for (auto dim : dims) {
    signal_numel *= sizes[dim];
  }
  const double scale_denom = (norm == fft_norm_mode::by_root_n) ?
    std::sqrt(signal_numel) : static_cast<double>(signal_numel);
  return 1.0 / scale_denom;
}
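
// Illustrative sketch of _fft_normalization_scale (values traced by hand):
// for a twosided tensor of sizes [4, 8, 16] transformed over dims = {1, 2},
// signal_numel = 8 * 16 = 128, so the returned scale is
//   1.0              for fft_norm_mode::none,
//   1.0 / sqrt(128)  for fft_norm_mode::by_root_n,
//   1.0 / 128        otherwise.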

const Tensor& _fft_apply_normalization(const Tensor& self, int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto scale = _fft_normalization_scale(normalization, sizes, dims);
  return (scale == 1.0) ? self : self.mul_(scale);
}

Tensor& _fft_apply_normalization_out(Tensor& out, const Tensor& self, int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto scale = _fft_normalization_scale(normalization, sizes, dims);
  return at::mul_out(out, self, c10::scalar_to_tensor(scale));
}

} // namespace (anonymous)

// Whether to use the optimized path for a single R2C or C2R transform, i.e.
// when the transformed dims can be handled by cuFFT in one go.
bool use_optimized_cufft_path(IntArrayRef dim) {
  // For performance reasons, when dim starts with (0, 1), do not use the optimized path.
  if (dim.size() > cufft_max_ndim || (
      dim.size() >= 2 && dim[0] == 0 && dim[1] == 1
    )) {
    return false;
  } else {
    return true;
  }
}
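// Illustrative sketch of use_optimized_cufft_path (values traced by hand):
//   use_optimized_cufft_path({1})          == true
//   use_optimized_cufft_path({0, 1})       == false  // dim starts with (0, 1)
//   use_optimized_cufft_path({1, 2, 3, 4}) == false  // more dims than cufft_max_ndim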

// n-dimensional real to complex FFT
Tensor _fft_r2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector onesided_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  onesided_sizes[last_dim] = last_dim_halfsize;
  IntArrayRef out_sizes = onesided ? onesided_sizes : input_sizes;

  const auto out_options = self.options().dtype(c10::toComplexType(self.scalar_type()));
  auto output = at::empty(out_sizes, out_options);

  // cuFFT requires real input to be over-aligned, as if it were complex
  const auto complex_size = 2 * self.element_size();
  const bool complex_aligned = (
      reinterpret_cast<std::uintptr_t>(self.const_data_ptr()) % complex_size == 0);
  auto working_tensor = self;
  if (!complex_aligned) {
    working_tensor = self.movedim(last_dim, -1)
                         .clone(MemoryFormat::Contiguous)
                         .movedim(-1, last_dim);
  }

  if (use_optimized_cufft_path(dim)) {
    _exec_fft(output, working_tensor, out_sizes, dim, /*forward=*/true);
  } else {
    // First do the R2C transform on the last dimension
    {
      auto target_sizes = dim.size() == 1 ? out_sizes : onesided_sizes;
      _exec_fft(output, working_tensor, target_sizes, last_dim, /*forward=*/true);
      if (dim.size() > 1) {
        working_tensor = at::empty(out_sizes, out_options);
      }
    }

    // Then any remaining C2C transforms
    DimVector sorted_dims(dim.begin(), dim.end() - 1);
    while (!sorted_dims.empty()) {
      std::swap(output, working_tensor);

      // Resort dimensions every time as _exec_fft re-strides the output
      auto strides = working_tensor.strides();
      std::sort(sorted_dims.begin(), sorted_dims.end(),
                [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });

      const auto max_dims = std::min(static_cast<size_t>(cufft_max_ndim), sorted_dims.size());
      auto last_dims = IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);

      // Intermediate results are always onesided
      _exec_fft(output, working_tensor, onesided_sizes, last_dims, /*forward=*/true);
      sorted_dims.resize(sorted_dims.size() - max_dims);
    }
  }

  // Only need to normalize the onesided slice since data in the other half is overwritten
  auto out_slice = output.slice(last_dim, 0, last_dim_halfsize);
  _fft_apply_normalization(out_slice, normalization, input_sizes, dim);

  if (!onesided) {
    if (output.sizes()[last_dim] != out_sizes[last_dim]) {
      working_tensor.resize_(out_sizes, MemoryFormat::Contiguous);
      working_tensor.slice(last_dim, 0, last_dim_halfsize).copy_(output);
      output = std::move(working_tensor);
    }
    at::native::_fft_fill_with_conjugate_symmetry_(output, dim);
  }
  return output;
}

Tensor& _fft_r2c_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, bool onesided, Tensor& out) {
  auto result = _fft_r2c_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);
  if (onesided) {
    return _fft_apply_normalization_out(out, result, normalization, self.sizes(), dim);
  }

  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  _fft_apply_normalization_out(out_slice, result, normalization, self.sizes(), dim);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}

// n-dimensional complex to real IFFT
Tensor _fft_c2r_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t lastdim) {
  TORCH_CHECK(self.is_complex());
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = lastdim;

  auto output = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));

  if (use_optimized_cufft_path(dim)) {
    Tensor temp;
    // Complex to real FFTs may overwrite the input buffer, so must always clone (gh-34551)
    temp = self.clone(MemoryFormat::Contiguous);
    _exec_fft(output, temp, out_sizes, dim, /*forward=*/false);
  } else {
    // First complete any C2C transforms
    Tensor temp;
    if (dim.size() > 1) {
      temp = _fft_c2c_cufft(
          self, dim.slice(0, dim.size() - 1),
          static_cast<int64_t>(fft_norm_mode::none), /*forward=*/false);
    } else {
      // Complex to real FFTs may overwrite the input buffer, so must always clone (gh-34551)
      temp = self.clone(MemoryFormat::Contiguous);
    }

    // Finally, do a 1D C2R transform
    // TODO: could transform up to 2 other dims in the same cuFFT operation
    _exec_fft(output, temp, out_sizes, dim.back(), /*forward=*/false);
  }

  return _fft_apply_normalization(output, normalization, out_sizes, dim);
}

Tensor& _fft_c2r_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, int64_t lastdim, Tensor& out) {
  auto result = _fft_c2r_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), lastdim);
  return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim);
}

// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (dim.empty()) {
    return self.clone();
  }

  auto out_sizes = self.sizes();
  auto output = at::empty(out_sizes, self.options());

  // Perform any number of C2C transforms
  DimVector sorted_dims(dim.begin(), dim.end());
  auto working_tensor = self;
  while (true) {
    // Sort dimensions every time as _exec_fft re-strides the output
    auto strides = working_tensor.strides();
    std::sort(sorted_dims.begin(), sorted_dims.end(),
              [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });

    const auto max_dims = std::min(static_cast<size_t>(cufft_max_ndim), sorted_dims.size());
    auto first_dims = IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);

    _exec_fft(output, working_tensor, out_sizes, first_dims, forward);
    sorted_dims.resize(sorted_dims.size() - max_dims);

    if (sorted_dims.empty()) {
      break;
    }

    if (working_tensor.is_same(self)) {
      working_tensor = std::move(output);
      output = at::empty(out_sizes, self.options());
    } else {
      std::swap(output, working_tensor);
    }
  }

  return _fft_apply_normalization(output, normalization, out_sizes, dim);
}

Tensor& _fft_c2c_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, bool forward, Tensor& out) {
  auto result = _fft_c2c_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
  return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim);
}


} // namespace at::native