#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#include <ATen/ops/_fft_c2r_native.h>
#include <ATen/ops/_fft_r2c_native.h>
#include <ATen/ops/empty.h>
#endif

#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>

namespace at { namespace native {
// In a real-to-complex transform, MKL FFT only fills half of the values due to
// conjugate symmetry. See native/SpectralOpsUtils.h for more details.
// The following functions are used to fill in the other half with symmetry in
// the case of a real-to-complex transform with the onesided=False flag.
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.
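// For example, a length-4 real signal has a full DFT satisfying X[3] == conj(X[1]),
// so the r2c transform only computes bins 0..2 and bin 3 is filled in by mirroring.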

template <typename scalar_t>
static __ubsan_ignore_undefined__  // UBSAN gives false positives on using negative indexes with a pointer
void _fft_fill_with_conjugate_symmetry_slice(
    Range range, at::ArrayRef<bool> is_mirrored_dim, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides, const scalar_t * in_ptr,
    IntArrayRef out_strides, scalar_t * out_ptr) {
  const auto ndim = signal_half_sizes.size();
  DimVector iter_index(ndim, 0);

  // We explicitly loop over one row, then use this lambda to iterate over
  // n-dimensions. This advances iter_index by one row, while updating in_ptr
  // and out_ptr to point to the new row of data.
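  // It behaves like an odometer over dimensions 1..ndim-1: dimension 1 is advanced
  // first, and when a dimension wraps back to index 0 the carry propagates to the
  // next dimension. Dimension 0 is handled by the explicit loops further below.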
  auto advance_index = [&] () __ubsan_ignore_undefined__ {
    for (const auto i : c10::irange(1, iter_index.size())) {
      if (iter_index[i] + 1 < signal_half_sizes[i]) {
        ++iter_index[i];
        in_ptr += in_strides[i];
        if (is_mirrored_dim[i]) {
          if (iter_index[i] == 1) {
            out_ptr += (signal_half_sizes[i] - 1) * out_strides[i];
          } else {
            out_ptr -= out_strides[i];
          }
        } else {
          out_ptr += out_strides[i];
        }
        return;
      }

      in_ptr -= in_strides[i] * iter_index[i];
      if (is_mirrored_dim[i]) {
        out_ptr -= out_strides[i];
      } else {
        out_ptr -= out_strides[i] * iter_index[i];
      }
      iter_index[i] = 0;
    }
  };

  // The data slice we operate on may start part-way into the data
  // Update iter_index and pointers to reference the start of the slice
  if (range.begin > 0) {
    iter_index[0] = range.begin % signal_half_sizes[0];
    auto linear_idx = range.begin / signal_half_sizes[0];

    for (size_t i = 1; i < ndim && linear_idx > 0; ++i) {
      iter_index[i] = linear_idx % signal_half_sizes[i];
      linear_idx = linear_idx / signal_half_sizes[i];

      if (iter_index[i] > 0) {
        in_ptr += in_strides[i] * iter_index[i];
        if (is_mirrored_dim[i]) {
          out_ptr += out_strides[i] * (signal_half_sizes[i] - iter_index[i]);
        } else {
          out_ptr += out_strides[i] * iter_index[i];
        }
      }
    }
  }

  auto numel_remaining = range.end - range.begin;

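  // Dimension 0 is the innermost loop: along a mirrored dimension, element i of the
  // row is conjugated into output offset (signal_half_sizes[0] - i) * out_strides[0]
  // (with i == 0 mapping onto itself); otherwise it is a plain conjugated copy.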
  if (is_mirrored_dim[0]) {
    // Explicitly loop over a Hermitian mirrored dimension
    if (iter_index[0] > 0) {
      auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
      for (const auto i : c10::irange(iter_index[0], end)) {
        out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= (end - iter_index[0]);
      iter_index[0] = 0;
      advance_index();
    }

    while (numel_remaining > 0) {
      auto end = std::min(signal_half_sizes[0], numel_remaining);
      out_ptr[0] = std::conj(in_ptr[0]);
      for (const auto i : c10::irange(1, end)) {
        out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= end;
      advance_index();
    }
  } else {
    // Explicit loop over a non-mirrored dimension, so just a simple conjugated copy
    while (numel_remaining > 0) {
      auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
      for (int64_t i = iter_index[0]; i != end; ++i) {
        out_ptr[i * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= (end - iter_index[0]);
      iter_index[0] = 0;
      advance_index();
    }
  }
}

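// CPU kernel registered for fft_fill_with_conjugate_symmetry_stub: converts byte strides
// to element strides, builds the mirrored-dimension mask, and fills the redundant half of
// the spectrum in parallel slices via at::parallel_for.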
static void _fft_fill_with_conjugate_symmetry_cpu_(
    ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides_bytes, const void * in_data,
    IntArrayRef out_strides_bytes, void * out_data) {

  // Convert strides from bytes to elements
  const auto element_size = scalarTypeToTypeMeta(dtype).itemsize();
  const auto ndim = signal_half_sizes.size();
  DimVector in_strides(ndim), out_strides(ndim);
  for (const auto i : c10::irange(ndim)) {
    TORCH_INTERNAL_ASSERT(in_strides_bytes[i] % element_size == 0);
    in_strides[i] = in_strides_bytes[i] / element_size;
    TORCH_INTERNAL_ASSERT(out_strides_bytes[i] % element_size == 0);
    out_strides[i] = out_strides_bytes[i] / element_size;
  }

  // Construct boolean mask for mirrored dims
  c10::SmallVector<bool, at::kDimVectorStaticSize> is_mirrored_dim(ndim, false);
  for (const auto& dim : mirror_dims) {
    is_mirrored_dim[dim] = true;
  }

  const auto numel = c10::multiply_integers(signal_half_sizes);
  AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
    at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
        [&](int64_t begin, int64_t end) {
          _fft_fill_with_conjugate_symmetry_slice(
              {begin, end}, is_mirrored_dim, signal_half_sizes,
              in_strides, static_cast<const scalar_t*>(in_data),
              out_strides, static_cast<scalar_t*>(out_data));
        });
  });
}

// Register this one implementation for all cpu types instead of compiling multiple times
REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)

// _out variants can be shared between PocketFFT and MKL
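// For the two-sided r2c case (onesided=false), the onesided result is copied into the
// leading half of `out` along the last transformed dimension and the remaining bins are
// reconstructed with _fft_fill_with_conjugate_symmetry_.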
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
  if (onesided) {
    resize_output(out, result.sizes());
    return out.copy_(result);
  }

  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  out_slice.copy_(result);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  auto result = _fft_c2c_mkl(self, dim, normalization, forward);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

}} // namespace at::native
#endif /* AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() */

#if AT_POCKETFFT_ENABLED()
#include <pocketfft_hdronly.h>

namespace at { namespace native {

namespace {
using namespace pocketfft;

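// Helpers adapting ATen tensor metadata to pocketfft's types: pocketfft takes sizes as
// shape_t and strides as stride_t measured in bytes, hence the element_size() scaling below.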
stride_t stride_from_tensor(const Tensor& t) {
  stride_t stride(t.strides().begin(), t.strides().end());
  for (auto& s : stride) {
    s *= t.element_size();
  }
  return stride;
}

inline shape_t shape_from_tensor(const Tensor& t) {
  return shape_t(t.sizes().begin(), t.sizes().end());
}

template<typename T>
inline std::complex<T> *tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(t.data_ptr<c10::complex<T>>());
}

template<typename T>
inline const std::complex<T> *tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(t.const_data_ptr<c10::complex<T>>());
}

template<typename T>
T compute_fct(int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none: return one;
    case fft_norm_mode::by_n: return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast<T>(size));
  }
  AT_ERROR("Unsupported normalization type", normalization);
}

template<typename T>
T compute_fct(const Tensor& t, IntArrayRef dim, int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for (auto idx : dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(n, normalization);
}

} // anonymous namespace

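// Complex-to-real inverse transform. The onesided input stores n/2 + 1 bins along the last
// transformed dimension, which is ambiguous between even and odd n, so the caller supplies
// last_dim_size explicitly.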
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<float>(self),
                   out.data_ptr<float>(), compute_fct<float>(out, dim, normalization));
  } else {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<double>(self),
                   out.data_ptr<double>(), compute_fct<double>(out, dim, normalization));
  }
  return out;
}

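// Real-to-complex transform. pocketfft only produces the onesided half (n/2 + 1 bins along
// the last transformed dimension); for onesided=false the output is allocated at full size
// and the redundant half is filled in afterwards via conjugate symmetry.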
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }

  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kFloat) {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.const_data_ptr<float>(),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.const_data_ptr<double>(),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (dim.empty()) {
    return self.clone();
  }

  auto out = at::empty(self.sizes(), self.options());
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<float>(self),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<double>(self),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  return out;
}

}} // namespace at::native

#elif AT_MKL_ENABLED()
#include <ATen/Dispatch.h>

#include <algorithm>
#include <numeric>
#include <cmath>

#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>

namespace at { namespace native {

// Constructs an mkl-fft plan descriptor representing the desired transform
// For complex types, strides are in units of 2 * element_size(dtype)
// sizes are for the full signal, including batch size and always two-sided
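// For example, a batched 2-D transform over data viewed as (batch, n1, n2) passes
// sizes = {batch, n1, n2} with signal_ndim = 2.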
static DftiDescriptor _plan_mkl_fft(
    IntArrayRef in_strides, IntArrayRef out_strides, IntArrayRef sizes,
    bool complex_input, bool complex_output,
    int64_t normalization, bool forward, ScalarType dtype) {
  const int64_t signal_ndim = sizes.size() - 1;
  TORCH_INTERNAL_ASSERT(in_strides.size() == sizes.size());
  TORCH_INTERNAL_ASSERT(out_strides.size() == sizes.size());

  // precision
  const DFTI_CONFIG_VALUE prec = [&]{
    switch (c10::toRealValueType(dtype)) {
      case ScalarType::Float: return DFTI_SINGLE;
      case ScalarType::Double: return DFTI_DOUBLE;
      default: TORCH_CHECK(false, "MKL FFT doesn't support tensors of type: ", dtype);
    }
  }();
  // signal type
  const DFTI_CONFIG_VALUE signal_type = [&]{
    if (forward) {
      return complex_input ? DFTI_COMPLEX : DFTI_REAL;
    } else {
      return complex_output ? DFTI_COMPLEX : DFTI_REAL;
    }
  }();
  // create descriptor with signal size
  using MklDimVector = c10::SmallVector<MKL_LONG, at::kDimVectorStaticSize>;
  MklDimVector mkl_signal_sizes(sizes.begin() + 1, sizes.end());
  DftiDescriptor descriptor;
  descriptor.init(prec, signal_type, signal_ndim, mkl_signal_sizes.data());
  // out of place FFT
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
  // batch mode
  MKL_LONG mkl_batch_size = sizes[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, mkl_batch_size));

  // batch dim stride, i.e., the distance between consecutive signals in the batch
  TORCH_CHECK(in_strides[0] <= MKL_LONG_MAX && out_strides[0] <= MKL_LONG_MAX);
  MKL_LONG idist = in_strides[0];
  MKL_LONG odist = out_strides[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));

  // signal strides
  // the first value is the offset, set to zero (ignored)
  MklDimVector mkl_istrides(1 + signal_ndim, 0), mkl_ostrides(1 + signal_ndim, 0);
  for (int64_t i = 1; i <= signal_ndim; i++) {
    TORCH_CHECK(in_strides[i] <= MKL_LONG_MAX && out_strides[i] <= MKL_LONG_MAX);
    mkl_istrides[i] = in_strides[i];
    mkl_ostrides[i] = out_strides[i];
  }
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_istrides.data()));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_ostrides.data()));
  // if the conjugate domain of a real transform is involved, set the standard CCE storage type;
  // this will become the default in MKL in the future
  if (!complex_input || !complex_output) {
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
  }
  // rescale if requested
  const auto norm = static_cast<fft_norm_mode>(normalization);
  int64_t signal_numel = c10::multiply_integers(IntArrayRef(sizes.data() + 1, signal_ndim));
  if (norm != fft_norm_mode::none) {
    const double scale = (
      (norm == fft_norm_mode::by_root_n) ?
      1.0 / std::sqrt(static_cast<double>(signal_numel)) :
      1.0 / static_cast<double>(signal_numel));
    const auto scale_direction = forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE;
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
  }

  if (sizeof(MKL_LONG) < sizeof(int64_t)) {
    TORCH_CHECK(signal_numel <= MKL_LONG_MAX,
                "MKL FFT: input signal numel exceeds allowed range [1, ", MKL_LONG_MAX, "]");
  }

  // finalize
  MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));

  return descriptor;
}

// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes,
                         IntArrayRef dim, int64_t normalization, bool forward) {
  const auto ndim = self.dim();
  const int64_t signal_ndim = dim.size();
  const auto batch_dims = ndim - signal_ndim;

  // Permute dimensions so batch dimensions come first, and in stride order
  // This maximizes data locality when collapsing to a single batch dimension
  DimVector dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});

  c10::SmallVector<bool, kDimVectorStaticSize> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
                                  [&](int64_t d) { return !is_transformed_dim[d]; });
  auto self_strides = self.strides();
  std::sort(dim_permute.begin(), batch_end,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  std::copy(dim.cbegin(), dim.cend(), batch_end);
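  // dim_permute now lists the batch dims (in decreasing stride order) followed by the
  // transform dims in their given order, e.g. for ndim = 4 and dim = {2, 3} it is
  // {0, 1, 2, 3} or {1, 0, 2, 3} depending on the input strides.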
  auto input = self.permute(dim_permute);

  // Collapse batch dimensions into a single dimension
  DimVector batched_sizes(signal_ndim + 1);
  batched_sizes[0] = -1;
  std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1);
  input = input.reshape(batched_sizes);

  const auto batch_size = input.sizes()[0];
  DimVector signal_size(signal_ndim + 1);
  signal_size[0] = batch_size;
  for (const auto i : c10::irange(signal_ndim)) {
    auto in_size = input.sizes()[i + 1];
    auto out_size = out_sizes[dim[i]];
    signal_size[i + 1] = std::max(in_size, out_size);
    TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] ||
                          in_size == (signal_size[i + 1] / 2) + 1);
    TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] ||
                          out_size == (signal_size[i + 1] / 2) + 1);
  }

  batched_sizes[0] = batch_size;
  DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
  for (const auto i : c10::irange(dim.size())) {
    batched_out_sizes[i + 1] = out_sizes[dim[i]];
  }

  const auto value_type = c10::toRealValueType(input.scalar_type());
  out.resize_(batched_out_sizes, MemoryFormat::Contiguous);

  auto descriptor = _plan_mkl_fft(
      input.strides(), out.strides(), signal_size, input.is_complex(),
      out.is_complex(), normalization, forward, value_type);

  // run the FFT
  if (forward) {
    MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), const_cast<void*>(input.const_data_ptr()), out.data_ptr()));
  } else {
    MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), const_cast<void*>(input.const_data_ptr()), out.data_ptr()));
  }

  // Inplace reshaping to original batch shape and inverting the dimension permutation
  DimVector out_strides(ndim);
  int64_t batch_numel = 1;
  for (int64_t i = batch_dims - 1; i >= 0; --i) {
    out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
    batch_numel *= out_sizes[dim_permute[i]];
  }
  for (const auto i : c10::irange(batch_dims, ndim)) {
    out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
  }
  out.as_strided_(out_sizes, out_strides, out.storage_offset());
  return out;
}

// Sort transform dimensions by input layout, for best performance
// exclude_last is for onesided transforms where the last dimension cannot be reordered
static DimVector _sort_dims(const Tensor& self, IntArrayRef dim, bool exclude_last=false) {
  DimVector sorted_dims(dim.begin(), dim.end());
  auto self_strides = self.strides();
  std::sort(sorted_dims.begin(), sorted_dims.end() - exclude_last,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  return sorted_dims;
}

// n-dimensional complex to real IFFT
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  TORCH_CHECK(self.is_complex());
  // NOTE: Multi-dimensional C2R transforms don't agree with numpy in cases
  // where the input isn't strictly Hermitian-symmetric. Instead, we use a
  // multi-dim C2C transform followed by a 1D C2R transform.
  //
  // Such inputs are technically out of contract though, so maybe a disagreement
  // is okay.
  auto input = self;
  if (dim.size() > 1) {
    auto c2c_dims = dim.slice(0, dim.size() - 1);
    input = _fft_c2c_mkl(self, c2c_dims, normalization, /*forward=*/false);
    dim = dim.slice(dim.size() - 1);
  }

  auto in_sizes = input.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  return _exec_fft(out, input, out_sizes, dim, normalization, /*forward=*/false);
}

// n-dimensional real to complex FFT
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }

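  // The halved (onesided) dimension must remain last, so it is excluded from the stride-based sort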
  auto sorted_dims = _sort_dims(self, dim, /*exclude_last=*/true);
  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  _exec_fft(out, self, out_sizes, sorted_dims, normalization, /*forward=*/true);

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}

// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (dim.empty()) {
    return self.clone();
  }

  const auto sorted_dims = _sort_dims(self, dim);
  auto out = at::empty(self.sizes(), self.options());
  return _exec_fft(out, self, self.sizes(), sorted_dims, normalization, forward);
}

}} // namespace at::native

#else

namespace at { namespace native {
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub);

Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

}} // namespace at::native
#endif