// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/SpectralOps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/cuda/CuFFTUtils.h>
#include <ATen/native/cuda/CuFFTPlanCache.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#include <ATen/ops/_fft_c2r_native.h>
#include <ATen/ops/_fft_r2c_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/mul.h>
#endif

#include <cufft.h>
#include <cufftXt.h>

#include <cmath>


namespace at::native {

using namespace at::native::detail;

// Execute a pre-planned transform
static void exec_cufft_plan(
    const CuFFTConfig &config, void* in_data, void* out_data, bool forward) {
  auto& plan = config.plan();
  CUFFT_CHECK(cufftXtExec(plan, in_data, out_data,
                          forward ? CUFFT_FORWARD : CUFFT_INVERSE));
}

// NOTE [ cuFFT Embedded Strides ]
//
// cuFFT supports a subset of arbitrary strides via its "advanced data layout"
// option (http://docs.nvidia.com/cuda/cufft/index.html#advanced-data-layout).
// Specifically, these are tensors that can be viewed as subtensors resulting
// from slicing a larger contiguous tensor. For such input tensors, let the
// sizes of the enclosing tensor be `inembed`; in the 3d case we have:
//
//     input[x, y, z] = input[((x * inembed[1] + y) * inembed[2] + z)]
//
// Above is the simplified formula ignoring the batch dimension. In fact, the
// last dimension of the enclosing tensor doesn't have to be contiguous, i.e.,
// its stride can be greater than 1. One can then set the base stride for the
// enclosing tensor with `istride`, giving
//
//     input[x, y, z] = input[((x * inembed[1] + y) * inembed[2] + z) * istride]
//
// For example, consider
//
//     enclosing = torch.zeros(6, 8, 10)  # contiguous
//     input = enclosing[:4, 2:6, 6:]
//     input.size()                       # [ 4,  4,  4]
//     input.stride()                     # [80, 10,  1]
//     # inembed = [6, 8, 10]
//     input[2, 1, 3] = input[((2 * 8) + 1) * 10 + 3]   # using above formula
//                    = input[173]
//                    = input[2 * 80 + 1 * 10 + 3 * 1]  # using strides directly
//
// Generally, the embedded strides can be computed as
//
//     embed[i] = stride[i - 1] / stride[i].
//
// Note that the value of embed[0] isn't used to compute indices and doesn't
// matter.
//
// In contrast to the advanced data layout, the simple layout means that the
// *embed arrays describe unit strides: the input and output tensors are
// contiguous, and the stride at the innermost signal dimension is unit (1)
// w.r.t. the corresponding data type.
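
// A minimal illustrative sketch (not part of this file's logic) of how the
// embedded sizes from the formula above could be derived from a tensor's
// sizes and strides. The helper name `compute_embed` is hypothetical; the
// real plan construction lives in CuFFTPlanCache.h.
//
//     // embed[i] = stride[i - 1] / stride[i]; embed[0] is unused, so keep size[0]
//     std::vector<int64_t> compute_embed(at::IntArrayRef sizes, at::IntArrayRef strides) {
//       std::vector<int64_t> embed(sizes.begin(), sizes.end());
//       for (size_t i = 1; i < sizes.size(); ++i) {
//         embed[i] = strides[i - 1] / strides[i];
//       }
//       return embed;
//     }
//
// For the example above, sizes = [4, 4, 4] and strides = [80, 10, 1] give
// embed = [4, 8, 10], matching inembed[1:] = [8, 10].
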
// The cuFFT plan cache
// unique_ptr for nullability and to avoid reference invalidation on vector resize
static std::vector<std::unique_ptr<CuFFTParamsLRUCache>> plan_caches;
static std::mutex plan_caches_mutex;

static inline
CuFFTParamsLRUCache &cufft_get_plan_cache(DeviceIndex device_index) {
  std::lock_guard<std::mutex> guard(plan_caches_mutex);

  AT_ASSERT(device_index >= 0);

  if (device_index >= static_cast<int64_t>(plan_caches.size())) {
    plan_caches.resize(device_index + 1);
  }

  if (!plan_caches[device_index]) {
    plan_caches[device_index] = std::make_unique<CuFFTParamsLRUCache>();
  }

  return *plan_caches[device_index];
}


namespace detail {

int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_get_plan_cache_max_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), ", but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).max_size();
}

void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_set_plan_cache_max_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), ", but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).resize(max_size);
}

int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_get_plan_cache_size: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), ", but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).size();
}

void cufft_clear_plan_cache_impl(DeviceIndex device_index) {
  TORCH_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
    "cufft_clear_plan_cache: expected 0 <= device_index < ",
    at::detail::getCUDAHooks().getNumGPUs(), ", but got device_index=",
    device_index);
  return cufft_get_plan_cache(device_index).clear();
}

} // namespace at::native::detail

namespace {
constexpr int64_t cufft_max_ndim = 3;

// "Large" here means a prime factor not special-cased by cuFFT
// Ref: https://docs.nvidia.com/cuda/cufft/index.html#accuracy-and-performance
bool has_large_prime_factor(int64_t n) {
  constexpr int64_t first_large_prime = 11;
  const std::array<int64_t, 4> prime_radices{{2, 3, 5, 7}};
  for (auto prime : prime_radices) {
    if (n < first_large_prime) {
      return false;
    }

    while (n % prime == 0) {
      n /= prime;
    }
  }
  return n != 1;
}
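
// For example (illustrative only): has_large_prime_factor(22) is true, since
// 22 = 2 * 11 leaves the residual 11 after dividing out {2, 3, 5, 7};
// has_large_prime_factor(120) is false, since 120 = 2^3 * 3 * 5 drops below
// the first large prime before the radices are exhausted.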

// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
static const Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes,
                         IntArrayRef dim, bool forward) {
  const auto ndim = self.dim();
  const int64_t signal_ndim = dim.size();
  const auto batch_dims = ndim - signal_ndim;

  // Permute dimensions so batch dimensions come first, and in stride order
  // This maximizes data locality when collapsing to a single batch dimension
  DimVector dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});

  c10::SmallVector<bool, kDimVectorStaticSize> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
                                  [&](int64_t d) {return !is_transformed_dim[d]; });
  auto self_strides = self.strides();
  std::sort(dim_permute.begin(), batch_end,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  std::copy(dim.cbegin(), dim.cend(), batch_end);
  auto input = self.permute(dim_permute);

  // Collapse batch dimensions into a single dimension
  DimVector batched_sizes(signal_ndim + 1);
  batched_sizes[0] = -1;
  std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1);
  input = input.reshape(batched_sizes);
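
  // For instance (illustrative only): a contiguous `self` of shape [5, 6, 7, 8]
  // transformed over dim = {2, 3} keeps batch dims {0, 1} in front, so
  // batched_sizes = {-1, 7, 8} and `input` is reshaped to [30, 7, 8].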

  const auto batch_size = input.sizes()[0];
  DimVector signal_size(signal_ndim + 1);
  signal_size[0] = batch_size;
  for (const auto i : c10::irange(signal_ndim)) {
    auto in_size = input.sizes()[i + 1];
    auto out_size = out_sizes[dim[i]];
    signal_size[i + 1] = std::max(in_size, out_size);
    TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] ||
                          in_size == (signal_size[i + 1] / 2) + 1);
    TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] ||
                          out_size == (signal_size[i + 1] / 2) + 1);
  }
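  // The asserts above cover the onesided cases: e.g. for an r2c transform of
  // logical length 10, in_size == 10 and out_size == 10 / 2 + 1 == 6, so
  // signal_size == 10; a c2r transform has the two sizes swapped.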

  batched_sizes[0] = batch_size;
  DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
  for (const auto i : c10::irange(dim.size())) {
    batched_out_sizes[i + 1] = out_sizes[dim[i]];
  }
  out.resize_(batched_out_sizes, MemoryFormat::Contiguous);

  // Create the transform plan (either from cache or locally)
  const auto value_type = c10::toRealValueType(input.scalar_type());
  auto fft_type = GetCuFFTTransformType(input.is_complex(), out.is_complex());
  CuFFTParams Params(input.strides(), out.strides(), signal_size, fft_type, value_type);
  CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(input.device().index());
  std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
  std::optional<CuFFTConfig> uncached_plan;
  const CuFFTConfig * config = nullptr;

  // Workaround for gh-63152, gh-58724
  // Bluestein plans in CUDA 11.1 (cuFFT 10.3) cannot be re-used.
  // Bluestein's algorithm is only used when a size has large prime factors;
  // sizes with only small prime factors can still be cached.
  bool use_caching = true;
#ifdef CUFFT_VERSION
  if constexpr (10300 <= CUFFT_VERSION && CUFFT_VERSION < 10400) {
    // Only cache plans for transforms with small prime factors
    use_caching = std::none_of(
        signal_size.begin() + 1, signal_size.end(), [](int64_t dim_size) {
      return has_large_prime_factor(dim_size);
    });
  }
#endif
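  // For example (illustrative only): with an affected cuFFT version, a signal
  // size of 22 (the prime factor 11 triggers Bluestein) bypasses the cache,
  // while a size of 120 = 2^3 * 3 * 5 is still cached.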

  if (use_caching && plan_cache.max_size() > 0) {
    guard.lock();
    if (plan_cache.max_size() > 0) {  // check again after acquiring the lock
      config = &plan_cache.lookup(Params);
    }
  }

  if (config == nullptr) {
    uncached_plan.emplace(Params);
    config = &uncached_plan.value();
  }

  auto & plan = config->plan();

  if (config->should_clone_input()) {
    input = input.clone(MemoryFormat::Contiguous);
  }

  // prepare cufft for execution
  CUFFT_CHECK(cufftSetStream(plan, at::cuda::getCurrentCUDAStream()));
  auto workspace = at::empty({ config->workspace_size() }, at::device(at::kCUDA).dtype(at::kByte));
  CUFFT_CHECK(cufftSetWorkArea(plan, workspace.mutable_data_ptr()));

  // execute transform plan
#if !defined(USE_ROCM)
  CUcontext pctx = nullptr;
  at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
  if (C10_UNLIKELY(!pctx)) {
    // workaround for corner case where a primary context exists but is not
    // the current context
    TORCH_WARN_ONCE("Attempting to run cuFFT, but there was no current CUDA context! Attempting to set the primary context...");
    at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, 0);
    at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
  }
#endif /* !defined(USE_ROCM) */
  exec_cufft_plan(*config, const_cast<void*>(input.const_data_ptr()), out.data_ptr(), forward);

  // In-place reshaping to the original batch shape, inverting the dimension permutation
  DimVector out_strides(ndim);
  int64_t batch_numel = 1;
  for (int64_t i = batch_dims - 1; i >= 0; --i) {
    out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
    batch_numel *= out_sizes[dim_permute[i]];
  }
  for (const auto i : c10::irange(batch_dims, ndim)) {
    out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
  }
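  // For instance (illustrative only): with out_sizes = [5, 6, 8] and dim = {2},
  // the batched output has shape [30, 8] and strides [8, 1]; the two loops
  // recover out_strides = [48, 8, 1] for the [5, 6, 8] view restored below.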
  return out.as_strided_(out_sizes, out_strides, out.storage_offset());
}

// Calculates the normalization constant; the helpers below apply it in-place to self.
// `sizes` is the sizes of a twosided tensor and `dims` are all transformed dims.
double _fft_normalization_scale(int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto norm = static_cast<fft_norm_mode>(normalization);
  if (norm == fft_norm_mode::none) {
    return 1.0;
  }

  int64_t signal_numel = 1;
  for (auto dim : dims) {
    signal_numel *= sizes[dim];
  }
  const double scale_denom = (norm == fft_norm_mode::by_root_n) ?
    std::sqrt(signal_numel) : static_cast<double>(signal_numel);
  return 1.0 / scale_denom;
}
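
// For example (illustrative only): for a transform over dims of sizes {8, 16},
// signal_numel == 128, so by_n scales by 1/128 and by_root_n by 1/sqrt(128).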

const Tensor& _fft_apply_normalization(const Tensor& self, int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto scale = _fft_normalization_scale(normalization, sizes, dims);
  return (scale == 1.0) ? self : self.mul_(scale);
}

Tensor& _fft_apply_normalization_out(Tensor& out, const Tensor& self, int64_t normalization, IntArrayRef sizes, IntArrayRef dims) {
  auto scale = _fft_normalization_scale(normalization, sizes, dims);
  return at::mul_out(out, self, c10::scalar_to_tensor(scale));
}

}  // namespace (anonymous)

// Use the optimized path to perform a single R2C or C2R transform if the transformed dims are supported by cuFFT
bool use_optimized_cufft_path(IntArrayRef dim) {
  // For performance reasons, do not use the optimized path when dim starts with (0, 1).
  if (dim.size() > cufft_max_ndim || (
    dim.size() >= 2 && dim[0] == 0 && dim[1] == 1
  )) {
    return false;
  } else {
    return true;
  }
}
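
// For example (illustrative only): dim = {2} or dim = {1, 2} take the optimized
// path, while dim = {0, 1} or any dim list longer than cufft_max_ndim (3) do not.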

// n-dimensional real to complex FFT
Tensor _fft_r2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector onesided_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  onesided_sizes[last_dim] = last_dim_halfsize;
  IntArrayRef out_sizes = onesided ? onesided_sizes : input_sizes;
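
  // For instance (illustrative only): a real input of shape [4, 10] transformed
  // over dim = {0, 1} has last_dim = 1 and last_dim_halfsize = 6, so the output
  // is [4, 6] when onesided and [4, 10] otherwise.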

  const auto out_options = self.options().dtype(c10::toComplexType(self.scalar_type()));
  auto output = at::empty(out_sizes, out_options);

  // cuFFT requires the real input to be over-aligned, as if it were complex
  const auto complex_size = 2 * self.element_size();
  const bool complex_aligned = (
      reinterpret_cast<std::uintptr_t>(self.const_data_ptr()) % complex_size == 0);
  auto working_tensor = self;
  if (!complex_aligned) {
    working_tensor = self.movedim(last_dim, -1)
                         .clone(MemoryFormat::Contiguous)
                         .movedim(-1, last_dim);
  }

  if (use_optimized_cufft_path(dim)) {
    _exec_fft(output, working_tensor, out_sizes, dim, /*forward=*/true);
  } else {
    // First do the R2C transform on the last dimension
    {
      auto target_sizes = dim.size() == 1 ? out_sizes : onesided_sizes;
      _exec_fft(output, working_tensor, target_sizes, last_dim, /*forward=*/true);
      if (dim.size() > 1) {
        working_tensor = at::empty(out_sizes, out_options);
      }
    }

    // Then any remaining C2C transforms
    DimVector sorted_dims(dim.begin(), dim.end() - 1);
    while (!sorted_dims.empty()) {
      std::swap(output, working_tensor);

      // Re-sort the dimensions every time, as _exec_fft re-strides the output
      auto strides = working_tensor.strides();
      std::sort(sorted_dims.begin(), sorted_dims.end(),
                [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });

      const auto max_dims = std::min(static_cast<size_t>(cufft_max_ndim), sorted_dims.size());
      auto last_dims = IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);

      // Intermediate results are always onesided
      _exec_fft(output, working_tensor, onesided_sizes, last_dims, /*forward=*/true);
      sorted_dims.resize(sorted_dims.size() - max_dims);
    }
  }

  // Only need to normalize the onesided slice since data in the other half is overwritten
  auto out_slice = output.slice(last_dim, 0, last_dim_halfsize);
  _fft_apply_normalization(out_slice, normalization, input_sizes, dim);

  if (!onesided) {
    if (output.sizes()[last_dim] != out_sizes[last_dim]) {
      working_tensor.resize_(out_sizes, MemoryFormat::Contiguous);
      working_tensor.slice(last_dim, 0, last_dim_halfsize).copy_(output);
      output = std::move(working_tensor);
    }
    at::native::_fft_fill_with_conjugate_symmetry_(output, dim);
  }
  return output;
}

Tensor& _fft_r2c_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, bool onesided, Tensor& out) {
  auto result = _fft_r2c_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);
  if (onesided) {
    return _fft_apply_normalization_out(out, result, normalization, self.sizes(), dim);
  }

  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  _fft_apply_normalization_out(out_slice, result, normalization, self.sizes(), dim);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}

// n-dimensional complex to real IFFT
Tensor _fft_c2r_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t lastdim) {
  TORCH_CHECK(self.is_complex());
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = lastdim;
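
  // For instance (illustrative only): a onesided complex input of shape [4, 6]
  // with lastdim = 10 produces a real output of shape [4, 10].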

  auto output = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));

  if (use_optimized_cufft_path(dim)) {
    Tensor temp;
    // Complex to real FFTs may overwrite the input buffer, so must always clone (gh-34551)
    temp = self.clone(MemoryFormat::Contiguous);
    _exec_fft(output, temp, out_sizes, dim, /*forward=*/false);
  } else {
    // First complete any C2C transforms
    Tensor temp;
    if (dim.size() > 1) {
      temp = _fft_c2c_cufft(
          self, dim.slice(0, dim.size() - 1),
          static_cast<int64_t>(fft_norm_mode::none), /*forward=*/false);
    } else {
      // Complex to real FFTs may overwrite the input buffer, so must always clone (gh-34551)
      temp = self.clone(MemoryFormat::Contiguous);
    }

    // Finally, do a 1D C2R transform
    // TODO: could transform up to 2 other dims in the same cuFFT operation
    _exec_fft(output, temp, out_sizes, dim.back(), /*forward=*/false);
  }

  return _fft_apply_normalization(output, normalization, out_sizes, dim);
}

Tensor& _fft_c2r_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, int64_t lastdim, Tensor& out) {
  auto result = _fft_c2r_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), lastdim);
  return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim);
}

// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (dim.empty()) {
    return self.clone();
  }

  auto out_sizes = self.sizes();
  auto output = at::empty(out_sizes, self.options());

  // Perform any number of C2C transforms
  DimVector sorted_dims(dim.begin(), dim.end());
  auto working_tensor = self;
  while (true) {
    // Re-sort the dimensions every time, as _exec_fft re-strides the output
    auto strides = working_tensor.strides();
    std::sort(sorted_dims.begin(), sorted_dims.end(),
              [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });

    const auto max_dims = std::min(static_cast<size_t>(cufft_max_ndim), sorted_dims.size());
    auto first_dims = IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);

    _exec_fft(output, working_tensor, out_sizes, first_dims, forward);
    sorted_dims.resize(sorted_dims.size() - max_dims);

    if (sorted_dims.empty()) {
      break;
    }

    if (working_tensor.is_same(self)) {
      working_tensor = std::move(output);
      output = at::empty(out_sizes, self.options());
    } else {
      std::swap(output, working_tensor);
    }
  }
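
  // For example (illustrative only): with 5 transform dims, the first pass
  // handles 3 of them (cufft_max_ndim); the intermediate result then becomes
  // the input for a second pass over the remaining 2 dims.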

  return _fft_apply_normalization(output, normalization, out_sizes, dim);
}

Tensor& _fft_c2c_cufft_out(const Tensor& self, IntArrayRef dim,
                           int64_t normalization, bool forward, Tensor& out) {
  auto result = _fft_c2c_cufft(self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
  return _fft_apply_normalization_out(out, result, normalization, result.sizes(), dim);
}


} // namespace at::native