#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#include <ATen/ops/_fft_c2r_native.h>
#include <ATen/ops/_fft_r2c_native.h>
#include <ATen/ops/empty.h>
#endif

#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>

namespace at { namespace native {
// In a real-to-complex transform, MKL FFT only fills half of the values due to
// conjugate symmetry. See native/SpectralOpsUtils.h for more details.
// The following function fills in the other half using that symmetry in the
// case of a real-to-complex transform with the onesided=False flag.
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.
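//
// Illustrative sketch (hypothetical helper, not part of this file): for a 1-D
// signal of length n, given the onesided half `half[0..n/2]`, the remaining
// bins follow from X[n - k] = conj(X[k]), roughly:
//
//   template <typename T>
//   void fill_full_spectrum_1d(const std::complex<T>* half,
//                              std::complex<T>* full, int64_t n) {
//     for (int64_t k = 0; k <= n / 2; ++k) {
//       full[k] = half[k];                 // copy the computed half
//     }
//     for (int64_t k = 1; k < n - n / 2; ++k) {
//       full[n - k] = std::conj(half[k]);  // mirror with conjugation
//     }
//   }
//
// The kernels below implement the same idea for arbitrary strides and
// multiple mirrored dimensions.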

template <typename scalar_t>
static __ubsan_ignore_undefined__  // UBSAN gives false positives on using negative indexes with a pointer
void _fft_fill_with_conjugate_symmetry_slice(
    Range range, at::ArrayRef<bool> is_mirrored_dim, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides, const scalar_t * in_ptr,
    IntArrayRef out_strides, scalar_t * out_ptr) {
  const auto ndim = signal_half_sizes.size();
  DimVector iter_index(ndim, 0);

  // We explicitly loop over one row, then use this lambda to iterate over
  // n-dimensions. This advances iter_index by one row, while updating in_ptr
  // and out_ptr to point to the new row of data.
  auto advance_index = [&] () __ubsan_ignore_undefined__ {
    for (const auto i : c10::irange(1, iter_index.size())) {
      if (iter_index[i] + 1 < signal_half_sizes[i]) {
        ++iter_index[i];
        in_ptr += in_strides[i];
        if (is_mirrored_dim[i]) {
          if (iter_index[i] == 1) {
            out_ptr += (signal_half_sizes[i] - 1) * out_strides[i];
          } else {
            out_ptr -= out_strides[i];
          }
        } else {
          out_ptr += out_strides[i];
        }
        return;
      }

      in_ptr -= in_strides[i] * iter_index[i];
      if (is_mirrored_dim[i]) {
        out_ptr -= out_strides[i];
      } else {
        out_ptr -= out_strides[i] * iter_index[i];
      }
      iter_index[i] = 0;
    }
  };
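
  // Illustrative walk-through (assumed sizes, not from this file): with
  // signal_half_sizes = {4, 2, 3}, successive advance_index() calls step the
  // outer indices like an odometer:
  //   iter_index[1..] : {0,0} -> {1,0} -> {0,1} -> {1,1} -> {0,2} -> {1,2}
  // For a mirrored dimension the output pointer walks backwards from the end
  // of that dimension, so element i of the input lands at (size - i) in the
  // output (to be conjugated by the loops below).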

  // The data slice we operate on may start part-way into the data
  // Update iter_index and pointers to reference the start of the slice
  if (range.begin > 0) {
    iter_index[0] = range.begin % signal_half_sizes[0];
    auto linear_idx = range.begin / signal_half_sizes[0];

    for (size_t i = 1; i < ndim && linear_idx > 0; ++i) {
      iter_index[i] = linear_idx % signal_half_sizes[i];
      linear_idx = linear_idx / signal_half_sizes[i];

      if (iter_index[i] > 0) {
        in_ptr += in_strides[i] * iter_index[i];
        if (is_mirrored_dim[i]) {
          out_ptr += out_strides[i] * (signal_half_sizes[i] - iter_index[i]);
        } else {
          out_ptr += out_strides[i] * iter_index[i];
        }
      }
    }
  }

  auto numel_remaining = range.end - range.begin;

  if (is_mirrored_dim[0]) {
    // Explicitly loop over a Hermitian mirrored dimension
    if (iter_index[0] > 0) {
      auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
      for (const auto i : c10::irange(iter_index[0], end)) {
        out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= (end - iter_index[0]);
      iter_index[0] = 0;
      advance_index();
    }

    while (numel_remaining > 0) {
      auto end = std::min(signal_half_sizes[0], numel_remaining);
      out_ptr[0] = std::conj(in_ptr[0]);
      for (const auto i : c10::irange(1, end)) {
        out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= end;
      advance_index();
    }
  } else {
    // Explicit loop over a non-mirrored dimension, so just a simple conjugated copy
    while (numel_remaining > 0) {
      auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
      for (int64_t i = iter_index[0]; i != end; ++i) {
        out_ptr[i * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= (end - iter_index[0]);
      iter_index[0] = 0;
      advance_index();
    }
  }
}

static void _fft_fill_with_conjugate_symmetry_cpu_(
    ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides_bytes, const void * in_data,
    IntArrayRef out_strides_bytes, void * out_data) {

  // Convert strides from bytes to elements
  const auto element_size = scalarTypeToTypeMeta(dtype).itemsize();
  const auto ndim = signal_half_sizes.size();
  DimVector in_strides(ndim), out_strides(ndim);
  for (const auto i : c10::irange(ndim)) {
    TORCH_INTERNAL_ASSERT(in_strides_bytes[i] % element_size == 0);
    in_strides[i] = in_strides_bytes[i] / element_size;
    TORCH_INTERNAL_ASSERT(out_strides_bytes[i] % element_size == 0);
    out_strides[i] = out_strides_bytes[i] / element_size;
  }

  // Construct boolean mask for mirrored dims
  c10::SmallVector<bool, at::kDimVectorStaticSize> is_mirrored_dim(ndim, false);
  for (const auto& dim : mirror_dims) {
    is_mirrored_dim[dim] = true;
  }

  const auto numel = c10::multiply_integers(signal_half_sizes);
  AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
    at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
        [&](int64_t begin, int64_t end) {
          _fft_fill_with_conjugate_symmetry_slice(
              {begin, end}, is_mirrored_dim, signal_half_sizes,
              in_strides, static_cast<const scalar_t*>(in_data),
              out_strides, static_cast<scalar_t*>(out_data));
        });
  });
}

// Register this one implementation for all cpu types instead of compiling multiple times
REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
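
// Illustrative usage (a sketch; the caller lives outside this file): every
// registration points at the same scalar kernel since it has no vectorized
// variants. A typical call path is
//
//   Tensor out = ...;  // full-size complex output, onesided half already written
//   at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
//
// which computes the half sizes and byte strides and reaches
// _fft_fill_with_conjugate_symmetry_cpu_ through fft_fill_with_conjugate_symmetry_stub.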

// _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
  if (onesided) {
    resize_output(out, result.sizes());
    return out.copy_(result);
  }

  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  out_slice.copy_(result);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}
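
// Worked example (assumed shapes, not from this file): for a float `self` of
// shape (4, 10) and dim = {1}, the onesided result has shape (4, 6) since
// 10 / 2 + 1 == 6. With onesided=false, `out` is resized to (4, 10), the first
// 6 columns are copied from the onesided result, and the remaining columns are
// filled in by conjugate symmetry.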

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  auto result = _fft_c2c_mkl(self, dim, normalization, forward);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

}} // namespace at::native
#endif /* AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() */

#if AT_POCKETFFT_ENABLED()
#include <pocketfft_hdronly.h>

namespace at { namespace native {

namespace {
using namespace pocketfft;

stride_t stride_from_tensor(const Tensor& t) {
  stride_t stride(t.strides().begin(), t.strides().end());
  for(auto& s: stride) {
    s *= t.element_size();
  }
  return stride;
}

inline shape_t shape_from_tensor(const Tensor& t) {
  return shape_t(t.sizes().begin(), t.sizes().end());
}

template<typename T>
inline std::complex<T> *tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(t.data_ptr<c10::complex<T>>());
}

template<typename T>
inline const std::complex<T> *tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(t.const_data_ptr<c10::complex<T>>());
}
template<typename T>
T compute_fct(int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none: return one;
    case fft_norm_mode::by_n: return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast<T>(size));
  }
  AT_ERROR("Unsupported normalization type: ", normalization);
}
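
// Worked example (assumed size, not from this file): for size = 8,
//   fft_norm_mode::none      -> 1
//   fft_norm_mode::by_n      -> 1/8       = 0.125
//   fft_norm_mode::by_root_n -> 1/sqrt(8) ~= 0.3536
// which is the scale factor applied to the output by the pocketfft calls below.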

template<typename T>
T compute_fct(const Tensor& t, IntArrayRef dim, int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for(auto idx: dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(n, normalization);
}

} // anonymous namespace

Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<float>(self),
                   out.data_ptr<float>(), compute_fct<float>(out, dim, normalization));
  } else {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<double>(self),
                   out.data_ptr<double>(), compute_fct<double>(out, dim, normalization));
  }
  return out;
}


Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }

  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kFloat) {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.const_data_ptr<float>(),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.const_data_ptr<double>(),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}
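
// Worked example (assumed shapes, not from this file): for a contiguous float
// `self` of shape (4, 10) with dim = {1} and onesided=true, pocketfft::r2c sees
//   shape      = {4, 10}   (full real input shape)
//   stride_in  = {40, 4}   bytes (10 floats per row)
//   stride_out = {48, 8}   bytes (6 complex<float> per row)
// and writes only the 6 onesided bins per row. With onesided=false the output
// is full-size and the trailing bins are filled by _fft_fill_with_conjugate_symmetry_.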

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (dim.empty()) {
    return self.clone();
  }

  auto out = at::empty(self.sizes(), self.options());
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<float>(self),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<double>(self),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  return out;
}

}} // namespace at::native

#elif AT_MKL_ENABLED()
#include <ATen/Dispatch.h>

#include <algorithm>
#include <numeric>
#include <cmath>

#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>


namespace at { namespace native {

// Constructs an MKL-FFT plan descriptor representing the desired transform.
// For complex types, strides are in units of 2 * element_size(dtype).
// Sizes are for the full signal, including batch size, and are always two-sided.
static DftiDescriptor _plan_mkl_fft(
    IntArrayRef in_strides, IntArrayRef out_strides, IntArrayRef sizes,
    bool complex_input, bool complex_output,
    int64_t normalization, bool forward, ScalarType dtype) {
  const int64_t signal_ndim = sizes.size() - 1;
  TORCH_INTERNAL_ASSERT(in_strides.size() == sizes.size());
  TORCH_INTERNAL_ASSERT(out_strides.size() == sizes.size());

  // precision
  const DFTI_CONFIG_VALUE prec = [&]{
    switch (c10::toRealValueType(dtype)) {
      case ScalarType::Float: return DFTI_SINGLE;
      case ScalarType::Double: return DFTI_DOUBLE;
      default: TORCH_CHECK(false, "MKL FFT doesn't support tensors of type: ", dtype);
    }
  }();
  // signal type
  const DFTI_CONFIG_VALUE signal_type = [&]{
    if (forward) {
      return complex_input ? DFTI_COMPLEX : DFTI_REAL;
    } else {
      return complex_output ? DFTI_COMPLEX : DFTI_REAL;
    }
  }();
  // create descriptor with signal size
  using MklDimVector = c10::SmallVector<MKL_LONG, at::kDimVectorStaticSize>;
  MklDimVector mkl_signal_sizes(sizes.begin() + 1, sizes.end());
  DftiDescriptor descriptor;
  descriptor.init(prec, signal_type, signal_ndim, mkl_signal_sizes.data());
  // out of place FFT
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
  // batch mode
  MKL_LONG mkl_batch_size = sizes[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, mkl_batch_size));

  // batch dim stride, i.e., the distance between consecutive batch elements
  TORCH_CHECK(in_strides[0] <= MKL_LONG_MAX && out_strides[0] <= MKL_LONG_MAX);
  MKL_LONG idist = in_strides[0];
  MKL_LONG odist = out_strides[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));

  // signal strides
  // the first value is the offset and is set to zero (ignored)
  MklDimVector mkl_istrides(1 + signal_ndim, 0), mkl_ostrides(1 + signal_ndim, 0);
  for (int64_t i = 1; i <= signal_ndim; i++) {
    TORCH_CHECK(in_strides[i] <= MKL_LONG_MAX && out_strides[i] <= MKL_LONG_MAX);
    mkl_istrides[i] = in_strides[i];
    mkl_ostrides[i] = out_strides[i];
  }
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_istrides.data()));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_ostrides.data()));
  // if the conjugate domain of a real transform is involved, use the standard
  // CCE storage type (this will become the default in MKL in the future)
  if (!complex_input || !complex_output) {
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
  }
  // rescale if requested
  const auto norm = static_cast<fft_norm_mode>(normalization);
  int64_t signal_numel = c10::multiply_integers(IntArrayRef(sizes.data() + 1, signal_ndim));
  if (norm != fft_norm_mode::none) {
    const double scale = (
      (norm == fft_norm_mode::by_root_n) ?
      1.0 / std::sqrt(static_cast<double>(signal_numel)) :
      1.0 / static_cast<double>(signal_numel));
    const auto scale_direction = forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE;
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
  }

  if (sizeof(MKL_LONG) < sizeof(int64_t)) {
    TORCH_CHECK(signal_numel <= MKL_LONG_MAX,
                "MKL FFT: input signal numel exceeds allowed range [1, ", MKL_LONG_MAX, "]");
  }

  // finalize
  MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));

  return descriptor;
}
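
// Worked example (assumed shapes, not from this file): a batched 1-D c2c
// transform over a contiguous complex<float> tensor of shape (B, n) reaches
// this function with
//   sizes      = {B, n}    (batch size first, then the signal)
//   in_strides = {n, 1}    out_strides = {n, 1}   (in complex elements)
// so signal_ndim = 1, mkl_batch_size = B and idist = odist = n, i.e. MKL runs
// B independent length-n transforms from a single committed descriptor.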

// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes,
                         IntArrayRef dim, int64_t normalization, bool forward) {
  const auto ndim = self.dim();
  const int64_t signal_ndim = dim.size();
  const auto batch_dims = ndim - signal_ndim;

  // Permute dimensions so batch dimensions come first, and in stride order
  // This maximizes data locality when collapsing to a single batch dimension
  DimVector dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});

  c10::SmallVector<bool, kDimVectorStaticSize> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
                                  [&](int64_t d) { return !is_transformed_dim[d]; });
  auto self_strides = self.strides();
  std::sort(dim_permute.begin(), batch_end,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  std::copy(dim.cbegin(), dim.cend(), batch_end);
  auto input = self.permute(dim_permute);

  // Collapse batch dimensions into a single dimension
  DimVector batched_sizes(signal_ndim + 1);
  batched_sizes[0] = -1;
  std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1);
  input = input.reshape(batched_sizes);

  const auto batch_size = input.sizes()[0];
  DimVector signal_size(signal_ndim + 1);
  signal_size[0] = batch_size;
  for (const auto i : c10::irange(signal_ndim)) {
    auto in_size = input.sizes()[i + 1];
    auto out_size = out_sizes[dim[i]];
    signal_size[i + 1] = std::max(in_size, out_size);
    TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] ||
                          in_size == (signal_size[i + 1] / 2) + 1);
    TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] ||
                          out_size == (signal_size[i + 1] / 2) + 1);
  }

  batched_sizes[0] = batch_size;
  DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
  for (const auto i : c10::irange(dim.size())) {
    batched_out_sizes[i + 1] = out_sizes[dim[i]];
  }

  const auto value_type = c10::toRealValueType(input.scalar_type());
  out.resize_(batched_out_sizes, MemoryFormat::Contiguous);

  auto descriptor = _plan_mkl_fft(
      input.strides(), out.strides(), signal_size, input.is_complex(),
      out.is_complex(), normalization, forward, value_type);

  // run the FFT
  if (forward) {
    MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), const_cast<void*>(input.const_data_ptr()), out.data_ptr()));
  } else {
    MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), const_cast<void*>(input.const_data_ptr()), out.data_ptr()));
  }

  // Inplace reshaping to original batch shape and inverting the dimension permutation
  DimVector out_strides(ndim);
  int64_t batch_numel = 1;
  for (int64_t i = batch_dims - 1; i >= 0; --i) {
    out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
    batch_numel *= out_sizes[dim_permute[i]];
  }
  for (const auto i : c10::irange(batch_dims, ndim)) {
    out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
  }
  out.as_strided_(out_sizes, out_strides, out.storage_offset());
  return out;
}
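
// Worked example (assumed shapes, not from this file): for a contiguous
// complex `self` of shape (2, 3, 4, 5) with dim = {2, 3}, the batch dims
// {0, 1} are already in stride order, the input is reshaped to (6, 4, 5), and
// MKL executes 6 batched 2-D transforms of size 4x5. The final as_strided_
// restores the original (2, 3, 4, 5) layout of the output.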

// Sort transform dimensions by input layout, for best performance
// exclude_last is for onesided transforms where the last dimension cannot be reordered
static DimVector _sort_dims(const Tensor& self, IntArrayRef dim, bool exclude_last=false) {
  DimVector sorted_dims(dim.begin(), dim.end());
  auto self_strides = self.strides();
  std::sort(sorted_dims.begin(), sorted_dims.end() - exclude_last,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  return sorted_dims;
}
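
// Worked example (assumed strides, not from this file): for an input with
// strides {1, 100, 10} and dim = {0, 1, 2}, sorting by descending stride gives
// {1, 2, 0}. With exclude_last=true only {0, 1} are reordered and the last
// entry stays put, giving {1, 0, 2}.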

// n-dimensional complex to real IFFT
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  TORCH_CHECK(self.is_complex());
  // NOTE: Multi-dimensional C2R transforms don't agree with numpy in cases
  // where the input isn't strictly Hermitian-symmetric. Instead, we use a
  // multi-dim C2C transform followed by a 1D C2R transform.
  //
  // Such inputs are technically out of contract though, so maybe a disagreement
  // is okay.
  auto input = self;
  if (dim.size() > 1) {
    auto c2c_dims = dim.slice(0, dim.size() - 1);
    input = _fft_c2c_mkl(self, c2c_dims, normalization, /*forward=*/false);
    dim = dim.slice(dim.size() - 1);
  }

  auto in_sizes = input.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  return _exec_fft(out, input, out_sizes, dim, normalization, /*forward=*/false);
}
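
// Worked example (assumed dims, not from this file): for dim = {1, 2, 3} this
// becomes a C2C IFFT over {1, 2} followed by a 1-D C2R over {3}. Normalization
// still comes out right because the per-step scale factors multiply, e.g.
// by_n gives (1/(n1*n2)) * (1/n3) == 1/(n1*n2*n3).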

// n-dimensional real to complex FFT
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }

  auto sorted_dims = _sort_dims(self, dim, /*exclude_last=*/true);
  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  _exec_fft(out, self, out_sizes, sorted_dims, normalization, /*forward=*/true);

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}

// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (dim.empty()) {
    return self.clone();
  }

  const auto sorted_dims = _sort_dims(self, dim);
  auto out = at::empty(self.sizes(), self.options());
  return _exec_fft(out, self, self.sizes(), sorted_dims, normalization, forward);
}

}} // namespace at::native

#else

namespace at { namespace native {
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub);

Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

}} // namespace at::native
#endif