#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_cufft_clear_plan_cache_native.h>
#include <ATen/ops/_cufft_get_plan_cache_max_size_native.h>
#include <ATen/ops/_cufft_get_plan_cache_size_native.h>
#include <ATen/ops/_cufft_set_plan_cache_max_size_native.h>
#include <ATen/ops/_fft_c2c.h>
#include <ATen/ops/_fft_c2r.h>
#include <ATen/ops/_fft_r2c.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/arange_native.h>
#include <ATen/ops/conj.h>
#include <ATen/ops/conj_physical.h>
#include <ATen/ops/constant_pad_nd.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/fft_fft2_native.h>
#include <ATen/ops/fft_fft_native.h>
#include <ATen/ops/fft_fftfreq_native.h>
#include <ATen/ops/fft_fftn_native.h>
#include <ATen/ops/fft_fftshift_native.h>
#include <ATen/ops/fft_hfft2_native.h>
#include <ATen/ops/fft_hfft_native.h>
#include <ATen/ops/fft_hfftn_native.h>
#include <ATen/ops/fft_ifft2_native.h>
#include <ATen/ops/fft_ifft_native.h>
#include <ATen/ops/fft_ifftn_native.h>
#include <ATen/ops/fft_ifftshift_native.h>
#include <ATen/ops/fft_ihfft2_native.h>
#include <ATen/ops/fft_ihfft_native.h>
#include <ATen/ops/fft_ihfftn_native.h>
#include <ATen/ops/fft_irfft2_native.h>
#include <ATen/ops/fft_irfft_native.h>
#include <ATen/ops/fft_irfftn_native.h>
#include <ATen/ops/fft_rfft2_native.h>
#include <ATen/ops/fft_rfft_native.h>
#include <ATen/ops/fft_rfftfreq_native.h>
#include <ATen/ops/fft_rfftn_native.h>
#include <ATen/ops/istft_native.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/pad.h>
#include <ATen/ops/roll.h>
#include <ATen/ops/stft.h>
#include <ATen/ops/stft_native.h>
#include <ATen/ops/unfold_backward.h>
#include <ATen/ops/view_as_complex.h>
#include <ATen/ops/view_as_real.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like_ops.h>
#endif

#include <algorithm>

namespace at::native {

namespace {

// Promote inputs to FFT functions
// * Integers are promoted to the default floating type
// * If require_complex=True, all types are promoted to complex
// * Half-precision dtypes raise an error except on CUDA (and meta),
//   leaving room for future support on other devices
ScalarType promote_type_fft(ScalarType type, bool require_complex, Device device) {
  if (at::isComplexType(type)) {
    return type;
  }
  // Promote integral to default float type
  if (!at::isFloatingType(type)) {
    type = c10::typeMetaToScalarType(c10::get_default_dtype());
  }

  const bool maybe_support_half = (
    // Only CUDA supports half precision, but since meta tensors don't have a
    // device we err on the side of accepting it
    device.is_cuda() || device.is_meta()
  );
  if (maybe_support_half) {
    TORCH_CHECK(type == kHalf || type == kFloat || type == kDouble, "Unsupported dtype ", type);
  } else {
    TORCH_CHECK(type == kFloat || type == kDouble, "Unsupported dtype ", type);
  }

  if (!require_complex) {
    return type;
  }

  // Promote to complex
  switch (type) {
  case kHalf: return kComplexHalf;
  case kFloat: return kComplexFloat;
  case kDouble: return kComplexDouble;
  default: TORCH_INTERNAL_ASSERT(false, "Unhandled dtype");
  }
}
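
// A few illustrative cases (sketch only; assumes the default dtype is float):
//   promote_type_fft(kLong,  /*require_complex=*/false, kCPU)  -> kFloat
//   promote_type_fft(kFloat, /*require_complex=*/true,  kCPU)  -> kComplexFloat
//   promote_type_fft(kHalf,  /*require_complex=*/true,  kCUDA) -> kComplexHalf
//   promote_type_fft(kHalf,  /*require_complex=*/false, kCPU)  -> error (half is CUDA/meta only)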

// Promote a tensor's dtype according to promote_type_fft
Tensor promote_tensor_fft(const Tensor& t, bool require_complex=false) {
  auto cur_type = t.scalar_type();
  auto new_type = promote_type_fft(cur_type, require_complex, t.device());
  return (cur_type == new_type) ? t : t.to(new_type);
}

// Convert a NumPy-compatible normalization mode string to enum values
// NOTE: NumPy's normalization modes have direction-specific meanings. For example,
// "forward" translates to `by_n` for a forward transform and `none` for backward.
fft_norm_mode norm_from_string(std::optional<c10::string_view> norm, bool forward) {
  if (!norm || *norm == "backward") {
    return forward ? fft_norm_mode::none : fft_norm_mode::by_n;
  }

  if (*norm == "forward") {
    return forward ? fft_norm_mode::by_n : fft_norm_mode::none;
  }

  if (*norm == "ortho") {
    return fft_norm_mode::by_root_n;
  }

  TORCH_CHECK(false, "Invalid normalization mode: \"", *norm, "\"");
}
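
// The full mapping, for reference (read off the branches above;
// "backward" is the default when no norm string is given):
//   norm        forward transform      inverse transform
//   "backward"  none      (no scale)   by_n      (1/n)
//   "forward"   by_n      (1/n)        none
//   "ortho"     by_root_n (1/sqrt(n))  by_root_n (1/sqrt(n))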

// Fixes the shape of x such that x.size(dims[i]) == sizes[i],
// either by zero-padding, or by slicing x starting from 0.
Tensor resize_fft_input(Tensor x, IntArrayRef dims, SymIntArrayRef sizes) {
  TORCH_INTERNAL_ASSERT(dims.size() == sizes.size());
  bool must_copy = false;
  auto x_sizes = x.sym_sizes();
  SymDimVector pad_amount(x_sizes.size() * 2);
  for (const auto i : c10::irange(dims.size())) {
    if (sizes[i] == -1) {
      continue;
    }

    if (x_sizes[dims[i]] < sizes[i]) {
      must_copy = true;
      auto pad_idx = pad_amount.size() - 2 * dims[i] - 1;
      pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]];
    }

    if (x_sizes[dims[i]] > sizes[i]) {
      x = x.slice_symint(dims[i], 0, sizes[i]);
    }
  }

  // Only call pad if necessary since pad copies the entire tensor
  return must_copy ? at::constant_pad_nd_symint(x, pad_amount) : x;
}
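
// Worked example (sketch): for x of shape [3, 5], dims={1}, sizes={8},
// pad_amount becomes {0, 3, 0, 0} (constant_pad_nd pads the last dim first),
// appending 3 zeros for shape [3, 8]; with sizes={2} the dim is instead
// sliced down to [3, 2].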

Tensor fft_r2c_maybe_out(
    c10::string_view fname, const Tensor& out, const Tensor& input,
    IntArrayRef dim, int64_t norm, bool onesided) {
  if (out.defined()) {
    TORCH_CHECK(out.is_complex(), fname,
                " expects a complex output tensor, but got ", out.scalar_type());
    auto out_mut = out;
    return at::_fft_r2c_outf(input, dim, norm, onesided, out_mut);
  }
  return at::_fft_r2c(input, dim, norm, onesided);
}

Tensor fft_c2r_maybe_out(
    c10::string_view fname, const Tensor& out, const Tensor& input,
    IntArrayRef dim, int64_t norm, SymInt last_dim_size) {
  // Support out argument if defined, otherwise call functional
  // variant so autograd works properly.
  if (out.defined()) {
    TORCH_CHECK(out.is_floating_point(), fname,
                " expects a floating point output tensor, but got ", out.scalar_type());
    auto out_mut = out;
    return at::_fft_c2r_symint_outf(input, dim, norm, last_dim_size, out_mut);
  }
  return at::_fft_c2r_symint(input, dim, norm, last_dim_size);
}

Tensor fft_c2c_maybe_out(
    c10::string_view fname, const Tensor& out, const Tensor& input,
    IntArrayRef dim, int64_t norm, bool forward) {
  if (out.defined()) {
    TORCH_CHECK(out.is_complex(), fname,
                " expects a complex output tensor, but got ", out.scalar_type());
    auto out_mut = out;
    return at::_fft_c2c_outf(input, dim, norm, forward, out_mut);
  }
  return at::_fft_c2c(input, dim, norm, forward);
}

// Complex to real FFT
Tensor fft_c2r(c10::string_view function_name,
               Tensor out, Tensor input, std::optional<SymInt> n_opt,
               int64_t unwrapped_dim, std::optional<c10::string_view> norm_str,
               bool forward) {
  TORCH_CHECK(!out.defined() || out.is_floating_point(), function_name,
              " expects a floating point output tensor, but got ", out.scalar_type());
  input = promote_tensor_fft(input, /*require_complex=*/true);
  const auto input_dim = input.dim();
  const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false);
  const auto n = n_opt.value_or(2*(input.sym_sizes()[dim] - 1));
  TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
  if (n_opt) {
    input = resize_fft_input(input, dim, n/2 + 1);
  }
  const auto norm = norm_from_string(norm_str, forward);
  if (forward) {
    // FIXME: _fft does not support complex_output=false with inverse=false
    input = input.conj();
  }
  return fft_c2r_maybe_out(
      function_name, out, input, dim, static_cast<int64_t>(norm), n);
}
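
// Sketch of the c2r size defaulting: for a one-sided input of length m along
// `dim`, the default output length is n = 2*(m - 1); when n is given
// explicitly, the input is first resized to n/2 + 1 Hermitian coefficients.
// E.g. irfft of a length-5 spectrum defaults to an 8-sample signal.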

// Real to complex FFT
Tensor fft_r2c(c10::string_view function_name,
               Tensor out, Tensor input, std::optional<SymInt> n_opt,
               int64_t unwrapped_dim, std::optional<c10::string_view> norm_str,
               bool forward, bool onesided) {
  TORCH_CHECK(!input.is_complex(), function_name,
              " expects a real input tensor, but got ", input.scalar_type());
  TORCH_CHECK(!out.defined() || out.is_complex(), function_name,
              " expects a complex output tensor, but got ", out.scalar_type());
  input = promote_tensor_fft(input);
  const auto input_dim = input.dim();
  const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false);
  const auto n = n_opt.value_or(input.sym_sizes()[dim]);
  TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
  if (n_opt) {
    input = resize_fft_input(input, dim, n);
  }

  const auto norm = norm_from_string(norm_str, forward);

  Tensor ret;
  if (out.defined() && forward) {
    ret = at::_fft_r2c_out(out, input, dim, static_cast<int64_t>(norm), onesided);
  } else {
    ret = at::_fft_r2c(input, dim, static_cast<int64_t>(norm), onesided);
  }

  if (!forward) {
    // FIXME: _fft_r2c doesn't support native r2c IFFT
    return out.defined() ? at::conj_physical_out(out, ret) : ret.conj();
  } else {
    return ret;
  }
}
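
// The inverse path above relies on the identity ifft(x) = conj(fft(conj(x)))
// (up to normalization) with conj(x) == x for real input, so ihfft is
// computed as the conjugate of a forward r2c transform, with the scaling
// folded into the norm argument. E.g. a length-8 real signal yields
// 8/2 + 1 = 5 one-sided complex coefficients either way.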

// Complex to complex FFT
Tensor fft_c2c(c10::string_view function_name,
               Tensor out, Tensor input, std::optional<SymInt> n_opt,
               int64_t unwrapped_dim, std::optional<c10::string_view> norm_str,
               bool forward) {
  TORCH_CHECK(input.is_complex(), function_name,
              " expects a complex input tensor, but got ", input.scalar_type());
  const auto input_dim = input.dim();
  const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false);
  const auto n = n_opt.value_or(input.sym_sizes()[dim]);
  TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified");
  if (n_opt) {
    input = resize_fft_input(input, dim, n);
  }
  const auto norm = static_cast<int64_t>(norm_from_string(norm_str, forward));
  return fft_c2c_maybe_out(function_name, out, input, dim, norm, forward);
}

// Dimensions to transform, and the signal shape in those dimensions
struct ShapeAndDims {
  SymDimVector shape;
  DimVector dim;
};

// Pre-process n-dimensional fft's `s` and `dim` arguments.
// Wraps dimensions and applies defaulting behavior.
// Also checks transform dims are unique and transform shape is non-empty.
ShapeAndDims canonicalize_fft_shape_and_dim_args(
    Tensor input, at::OptionalSymIntArrayRef shape, at::OptionalIntArrayRef dim) {
  const int64_t input_dim = input.dim();
  const SymIntArrayRef input_sizes = input.sym_sizes();
  ShapeAndDims ret;

  if (dim) {
    ret.dim.resize(dim->size());
    std::copy(dim->begin(), dim->end(), ret.dim.begin());
    maybe_wrap_dims(ret.dim, input_dim, /*wrap_scalars=*/false);

    // Check dims are unique
    DimVector copy = ret.dim;
    std::sort(copy.begin(), copy.end());
    auto duplicate = std::adjacent_find(copy.begin(), copy.end());
    TORCH_CHECK(duplicate == copy.end(), "FFT dims must be unique");
  }

  if (shape) {
    // Has shape, may have dim
    TORCH_CHECK(!dim ||
                dim->size() == shape->size(),
                "When given, dim and shape arguments must have the same length");
    TORCH_CHECK(static_cast<int64_t>(shape->size()) <= input_dim,
                "Got shape with ", shape->size(), " values but input tensor "
                "only has ", input_dim, " dimensions.");
    const int64_t transform_ndim = shape->size();
    // If shape is given, dims defaults to the last shape.size() dimensions
    if (!dim) {
      ret.dim.resize(transform_ndim);
      std::iota(ret.dim.begin(), ret.dim.end(), input_dim - transform_ndim);
    }

    // Translate shape of -1 to the default length
    ret.shape.resize(transform_ndim);
    for (const auto i : c10::irange(transform_ndim)) {
      const auto n = (*shape)[i];
      ret.shape[i] = n == -1 ? input_sizes[ret.dim[i]] : n;
    }
  } else if (!dim) {
    // No shape, no dim
    ret.dim.resize(input_dim);
    std::iota(ret.dim.begin(), ret.dim.end(), int64_t{0});
    ret.shape.resize(input_dim);
    std::copy(input_sizes.begin(), input_sizes.end(), ret.shape.begin());
  } else {
    // No shape, has dim
    ret.shape.resize(ret.dim.size());
    for (const auto i : c10::irange(ret.dim.size())) {
      ret.shape[i] = input_sizes[ret.dim[i]];
    }
  }

  for (const auto & shape : ret.shape) {
    TORCH_CHECK(shape > 0,
                "Invalid number of data points (", shape, ") specified");
  }

  return ret;
}
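
// Defaulting rules in action (sketch), for an input of shape [2, 3, 4]:
//   s=nullopt, dim=nullopt -> dim={0,1,2}, shape={2,3,4}  (transform everything)
//   s={8,-1},  dim=nullopt -> dim={1,2},   shape={8,4}    (last s.size() dims)
//   s=nullopt, dim={0,2}   -> dim={0,2},   shape={2,4}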

// Complex to complex n-dimensional fft
Tensor fftn_c2c(
    c10::string_view function_name,
    Tensor out, const Tensor& input, SymIntArrayRef shape,
    IntArrayRef dim, std::optional<c10::string_view> norm_str, bool forward) {
  TORCH_CHECK(input.is_complex(), function_name, " expects a complex input tensor, but got ", input.scalar_type());
  Tensor x = resize_fft_input(input, dim, shape);
  const auto norm = static_cast<int64_t>(norm_from_string(norm_str, forward));
  constexpr c10::string_view fname = "fftn";
  return fft_c2c_maybe_out(fname, out, x, dim, norm, forward);
}

}  // namespace (anonymous)

// torch.fft.fft, analogous to NumPy's numpy.fft.fft
Tensor fft_fft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
               std::optional<c10::string_view> norm) {
  return self.is_complex() ?
    fft_c2c("fft", {}, self, n, dim, norm, /*forward=*/true) :
    fft_r2c("fft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/false);
}

Tensor& fft_fft_symint_out(const Tensor& self, std::optional<SymInt> n,
                    int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
  if (self.is_complex()) {
    fft_c2c("fft", out, self, n, dim, norm, /*forward=*/true);
  } else {
    fft_r2c("fft", out, self, n, dim, norm, /*forward=*/true, /*onesided=*/false);
  }
  return out;
}

Tensor fft_ifft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
                std::optional<c10::string_view> norm) {
  return self.is_complex() ?
    fft_c2c("ifft", {}, self, n, dim, norm, /*forward=*/false) :
    fft_r2c("ifft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/false);
}

Tensor& fft_ifft_symint_out(const Tensor& self, std::optional<SymInt> n,
                     int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
  if (self.is_complex()) {
    fft_c2c("ifft", out, self, n, dim, norm, /*forward=*/false);
  } else {
    fft_r2c("ifft", out, self, n, dim, norm, /*forward=*/false, /*onesided=*/false);
  }
  return out;
}

Tensor fft_rfft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
                std::optional<c10::string_view> norm) {
  return fft_r2c("rfft", {}, self, n, dim, norm, /*forward=*/true, /*onesided=*/true);
}

Tensor& fft_rfft_symint_out(const Tensor& self, std::optional<SymInt> n,
                     int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
  fft_r2c("rfft", out, self, n, dim, norm, /*forward=*/true, /*onesided=*/true);
  return out;
}

Tensor fft_irfft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
                 std::optional<c10::string_view> norm) {
  return fft_c2r("irfft", {}, self, n, dim, norm, /*forward=*/false);
}

Tensor& fft_irfft_symint_out(const Tensor& self, std::optional<SymInt> n,
                  int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
  fft_c2r("irfft", out, self, n, dim, norm, /*forward=*/false);
  return out;
}

Tensor fft_hfft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
                std::optional<c10::string_view> norm) {
  return fft_c2r("hfft", {}, self, n, dim, norm, /*forward=*/true);
}

Tensor& fft_hfft_symint_out(const Tensor& self, std::optional<SymInt> n,
                     int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
  fft_c2r("hfft", out, self, n, dim, norm, /*forward=*/true);
  return out;
}

Tensor fft_ihfft_symint(const Tensor& self, std::optional<SymInt> n, int64_t dim,
                 std::optional<c10::string_view> norm) {
  return fft_r2c("ihfft", {}, self, n, dim, norm, /*forward=*/false, /*onesided=*/true);
}

Tensor& fft_ihfft_symint_out(const Tensor& self, std::optional<SymInt> n,
                     int64_t dim, std::optional<c10::string_view> norm, Tensor& out) {
  fft_r2c("ihfft", out, self, n, dim, norm, /*forward=*/false, /*onesided=*/true);
  return out;
}

Tensor fft_fftn_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                at::OptionalIntArrayRef dim,
                std::optional<c10::string_view> norm) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  // TODO: For real input, perform rfftn then mirror with conjugate symmetry
  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
  return fftn_c2c("fftn", {}, input, desc.shape, desc.dim, norm, /*forward=*/true);
}

Tensor& fft_fftn_symint_out(const Tensor& self,
                     at::OptionalSymIntArrayRef s,
                     at::OptionalIntArrayRef dim,
                     std::optional<c10::string_view> norm, Tensor& out) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  // TODO: For real input, perform rfftn then mirror with conjugate symmetry
  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
  fftn_c2c("fftn", out, input, desc.shape, desc.dim, norm, /*forward=*/true);
  return out;
}

Tensor fft_ifftn_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                at::OptionalIntArrayRef dim,
                std::optional<c10::string_view> norm) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
  return fftn_c2c("ifftn", {}, input, desc.shape, desc.dim, norm, /*forward=*/false);
}

Tensor& fft_ifftn_symint_out(const Tensor& self,
                      at::OptionalSymIntArrayRef s,
                      at::OptionalIntArrayRef dim,
                      std::optional<c10::string_view> norm, Tensor& out) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
  fftn_c2c("ifftn", out, input, desc.shape, desc.dim, norm, /*forward=*/false);
  return out;
}

static Tensor fft_rfftn_impl(Tensor out, const Tensor& self,
                             at::OptionalSymIntArrayRef s,
                             at::OptionalIntArrayRef dim,
                             const std::optional<c10::string_view>& norm_str) {
  TORCH_CHECK(!self.is_complex(), "rfftn expects a real-valued input tensor, but got ", self.scalar_type());
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  TORCH_CHECK(!desc.shape.empty(), "rfftn must transform at least one axis");
  Tensor input = promote_tensor_fft(self, /*require_complex=*/false);
  Tensor x = resize_fft_input(input, desc.dim, desc.shape);
  const auto norm = static_cast<int64_t>(norm_from_string(norm_str, /*forward=*/true));
  constexpr c10::string_view fname = "rfftn";
  return fft_r2c_maybe_out(fname, out, x, desc.dim, norm, /*onesided=*/true);
}

Tensor fft_rfftn_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                at::OptionalIntArrayRef dim,
                std::optional<c10::string_view> norm_str) {
  return fft_rfftn_impl({}, self, s, dim, norm_str);
}

Tensor& fft_rfftn_symint_out(const Tensor& self,
                      at::OptionalSymIntArrayRef s,
                      at::OptionalIntArrayRef dim,
                      std::optional<c10::string_view> norm_str, Tensor& out) {
  fft_rfftn_impl(out, self, s, dim, norm_str);
  return out;
}

static ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
    c10::string_view fname, const Tensor& self,
    const at::OptionalSymIntArrayRef& s,
    const at::OptionalIntArrayRef& dims,
    SymInt& last_dim_size) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dims);
  TORCH_CHECK(!desc.shape.empty(), fname, " must transform at least one axis");

  // Expected output size of the Hermitian-symmetric dimension
  last_dim_size = [&] {
    // Fix up default shape handling in the last dimension
    if (!s.has_value() || (s->back() == -1)) {
      const auto last_dim = desc.dim.back();
      return 2 * (self.sym_sizes()[last_dim] - 1);
    }
    return desc.shape.back();
  }();
  TORCH_CHECK(last_dim_size >= 1, "Invalid number of data points (", last_dim_size, ") specified");

  // Expected input size of the complex-Hermitian data
  desc.shape.back() = last_dim_size / 2 + 1;
  return desc;
}
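
// E.g. (sketch) for irfftn with s=nullopt and a last transform dim of size 5:
// last_dim_size defaults to 2*(5 - 1) = 8, and the Hermitian input is then
// expected to hold 8/2 + 1 = 5 complex coefficients in that dimension.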

static Tensor fft_irfftn_impl(Tensor out, const Tensor& self,
                              at::OptionalSymIntArrayRef s,
                              at::OptionalIntArrayRef dim,
                              const std::optional<c10::string_view>& norm_str) {
  SymInt last_dim_size = 0;
  auto desc = canonicalize_fft_c2r_shape_and_dim_args(
      "irfftn", self, s, dim, last_dim_size);
  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
  Tensor x = resize_fft_input(input, desc.dim, desc.shape);
  const auto norm = static_cast<int64_t>(norm_from_string(norm_str, /*forward=*/false));
  constexpr c10::string_view fname = "irfftn";
  return fft_c2r_maybe_out(fname, out, x, desc.dim, norm, last_dim_size);
}

Tensor fft_irfftn_symint(const Tensor& self,
                  at::OptionalSymIntArrayRef s,
                  at::OptionalIntArrayRef dim,
                  std::optional<c10::string_view> norm_str) {
  return fft_irfftn_impl({}, self, s, dim, norm_str);
}

Tensor& fft_irfftn_symint_out(const Tensor& self,
                       at::OptionalSymIntArrayRef s,
                       at::OptionalIntArrayRef dim,
                       std::optional<c10::string_view> norm_str, Tensor& out) {
  fft_irfftn_impl(out, self, s, dim, norm_str);
  return out;
}

static Tensor fft_hfftn_impl(
    const Tensor& self,
    at::OptionalSymIntArrayRef s,
    at::OptionalIntArrayRef dim,
    std::optional<c10::string_view> norm_str,
    const Tensor& out) {
  constexpr c10::string_view fname = "hfftn";
  SymInt last_dim_size = 0;
  auto desc = canonicalize_fft_c2r_shape_and_dim_args(
      fname, self, s, dim, last_dim_size);
  auto input = promote_tensor_fft(self, /*require_complex=*/true);
  auto x = resize_fft_input(input, desc.dim, desc.shape);
  const auto norm = static_cast<int64_t>(
      norm_from_string(norm_str, /*forward=*/true));

  Tensor tmp;
  if (desc.dim.size() > 1) {
    auto c2c_dims = IntArrayRef(desc.dim).slice(0, desc.dim.size() - 1);
    tmp = at::_fft_c2c(x, c2c_dims, norm, /*forward=*/true);
  } else {
    tmp = x;
  }

  const auto last_dim = desc.dim.back();
  tmp = tmp.conj();
  return fft_c2r_maybe_out(fname, out, tmp, last_dim, norm, last_dim_size);
}

Tensor fft_hfftn_symint(
    const Tensor& self,
    at::OptionalSymIntArrayRef s,
    at::OptionalIntArrayRef dim,
    std::optional<c10::string_view> norm) {
  return fft_hfftn_impl(self, s, dim, norm, {});
}

const Tensor& fft_hfftn_symint_out(
    const Tensor& self,
    at::OptionalSymIntArrayRef s,
    at::OptionalIntArrayRef dim, std::optional<c10::string_view> norm,
    const Tensor& out) {
  fft_hfftn_impl(self, s, dim, norm, out);
  return out;
}

static Tensor fft_ihfftn_impl(
    const Tensor& self,
    const at::OptionalSymIntArrayRef& s,
    const at::OptionalIntArrayRef& dim,
    const std::optional<c10::string_view>& norm_str,
    const Tensor& out) {
  constexpr c10::string_view fname = "ihfftn";
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  TORCH_CHECK(!desc.shape.empty(), "ihfftn must transform at least one axis");
  auto input = promote_tensor_fft(self, /*require_complex=*/false);
  auto x = resize_fft_input(input, desc.dim, desc.shape);
  const auto norm = static_cast<int64_t>(
      norm_from_string(norm_str, /*forward=*/false));

  const auto last_dim = desc.dim.back();
  auto tmp = at::_fft_r2c(x, last_dim, norm, /*onesided=*/true);
  if (desc.dim.size() == 1) {
    return out.defined() ? at::conj_physical_out(tmp, out) : tmp.conj();
  }

  tmp = at::conj_physical(tmp);
  auto c2c_dims = IntArrayRef(desc.dim).slice(0, desc.dim.size() - 1);
  return fft_c2c_maybe_out(fname, out, tmp, c2c_dims, norm, /*forward=*/false);
}

Tensor fft_ihfftn_symint(
    const Tensor& self,
    at::OptionalSymIntArrayRef s,
    at::OptionalIntArrayRef dim,
    std::optional<c10::string_view> norm) {
  return fft_ihfftn_impl(self, s, dim, norm, {});
}

const Tensor& fft_ihfftn_symint_out(
    const Tensor& self,
    at::OptionalSymIntArrayRef s,
    at::OptionalIntArrayRef dim,
    std::optional<c10::string_view> norm,
    const Tensor& out) {
  fft_ihfftn_impl(self, s, dim, norm, out);
  return out;
}

Tensor fft_fft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                IntArrayRef dim, std::optional<c10::string_view> norm) {
  return native::fft_fftn_symint(self, s, dim, std::move(norm));
}

Tensor& fft_fft2_symint_out(const Tensor& self, at::OptionalSymIntArrayRef s,
                     IntArrayRef dim, std::optional<c10::string_view> norm, Tensor& out) {
  return native::fft_fftn_symint_out(self, s, dim, std::move(norm), out);
}

Tensor fft_ifft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                IntArrayRef dim, std::optional<c10::string_view> norm) {
  return native::fft_ifftn_symint(self, s, dim, std::move(norm));
}

Tensor& fft_ifft2_symint_out(const Tensor& self, at::OptionalSymIntArrayRef s,
                      IntArrayRef dim, std::optional<c10::string_view> norm, Tensor& out) {
  return native::fft_ifftn_symint_out(self, s, dim, std::move(norm), out);
}

Tensor fft_rfft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                IntArrayRef dim, std::optional<c10::string_view> norm) {
  return native::fft_rfftn_symint(self, s, dim, std::move(norm));
}

Tensor& fft_rfft2_symint_out(const Tensor& self, at::OptionalSymIntArrayRef s,
                      IntArrayRef dim, std::optional<c10::string_view> norm, Tensor& out) {
  return native::fft_rfftn_symint_out(self, s, dim, std::move(norm), out);
}

Tensor fft_irfft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                  IntArrayRef dim, std::optional<c10::string_view> norm) {
  return native::fft_irfftn_symint(self, s, dim, std::move(norm));
}

Tensor& fft_irfft2_symint_out(const Tensor& self, at::OptionalSymIntArrayRef s,
                       IntArrayRef dim, std::optional<c10::string_view> norm, Tensor& out) {
  return native::fft_irfftn_symint_out(self, s, dim, std::move(norm), out);
}

const Tensor& fft_hfft2_symint_out(
    const Tensor& self, at::OptionalSymIntArrayRef s, IntArrayRef dim,
    std::optional<c10::string_view> norm, const Tensor& out) {
  return native::fft_hfftn_symint_out(self, s, dim, std::move(norm), out);
}

Tensor fft_hfft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                 IntArrayRef dim, std::optional<c10::string_view> norm) {
  return native::fft_hfftn_symint(self, s, dim, std::move(norm));
}

const Tensor& fft_ihfft2_symint_out(
    const Tensor& self, at::OptionalSymIntArrayRef s, IntArrayRef dim,
    std::optional<c10::string_view> norm, const Tensor& out) {
  return native::fft_ihfftn_symint_out(self, s, dim, std::move(norm), out);
}

Tensor fft_ihfft2_symint(const Tensor& self, at::OptionalSymIntArrayRef s,
                  IntArrayRef dim, std::optional<c10::string_view> norm) {
  return native::fft_ihfftn_symint(self, s, dim, std::move(norm));
}

Tensor& fft_fftfreq_out(int64_t n, double d, Tensor& out) {
  ScalarType dtype = out.scalar_type();
  TORCH_CHECK(at::isFloatingType(dtype) || at::isComplexType(dtype),
              "fftfreq requires a floating point or complex dtype");
  // TODO: arange doesn't have complex support
  at::arange_out(out, n);
  auto right_slice = out.slice(0, (n + 1) / 2, 0);
  at::arange_out(right_slice, -(n/2), 0, 1);
  return out.mul_(1.0 / (n * d));  // Slightly faster than div_(n*d)
}
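
// Worked example (sketch): fft_fftfreq(8, /*d=*/1.0) fills out with
// [0, 1, 2, 3, -4, -3, -2, -1] / 8, i.e. the ascending non-negative bins
// followed by the negative bins, matching numpy.fft.fftfreq.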

Tensor fft_fftfreq(int64_t n, double d,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

  auto out = at::empty({n}, options);
  return native::fft_fftfreq_out(n, d, out);
}

Tensor& fft_rfftfreq_out(int64_t n, double d, Tensor& out) {
  ScalarType dtype = out.scalar_type();
  TORCH_CHECK(at::isFloatingType(dtype) || at::isComplexType(dtype),
              "rfftfreq requires a floating point or complex dtype");
  // TODO: arange doesn't have complex support
  native::arange_out(n/2 + 1, out);
  return out.mul_(1.0 / (n * d));  // Slightly faster than div_(n*d)
}

Tensor fft_rfftfreq(int64_t n, double d,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

  auto out = at::empty({n/2 + 1}, options);
  return native::fft_rfftfreq_out(n, d, out);
}
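
// E.g. (sketch) fft_rfftfreq(8, /*d=*/1.0) produces the 8/2 + 1 = 5
// non-negative bins [0, 1, 2, 3, 4] / 8, the frequencies that pair with
// rfft's one-sided output.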

// If a dims array is specified, wraps each dim according to self.dim().
// Otherwise returns a vector of all dims.
static DimVector default_alldims(const Tensor& self, at::OptionalIntArrayRef dim_opt) {
  DimVector dim;
  if (dim_opt) {
    IntArrayRef dim_unwrapped = *dim_opt;
    dim.resize(dim_unwrapped.size());
    for (const auto i : c10::irange(dim.size())) {
      dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalars=*/false);
    }
  } else {
    dim.resize(self.dim());
    std::iota(dim.begin(), dim.end(), 0);
  }
  return dim;
}

Tensor fft_fftshift(const Tensor& x, at::OptionalIntArrayRef dim_opt) {
  auto dim = default_alldims(x, dim_opt);

  SymIntArrayRef x_sizes = x.sym_sizes();
  SymDimVector shift(dim.size());
  for (const auto i : c10::irange(dim.size())) {
    shift[i] = x_sizes[dim[i]] / 2;
  }

  return at::roll_symint(x, shift, dim);
}

Tensor fft_ifftshift(const Tensor& x, at::OptionalIntArrayRef dim_opt) {
  auto dim = default_alldims(x, dim_opt);

  SymIntArrayRef x_sizes = x.sym_sizes();
  SymDimVector shift(dim.size());
  for (const auto i : c10::irange(dim.size())) {
    shift[i] = (x_sizes[dim[i]] + 1) / 2;
  }

  return at::roll_symint(x, shift, dim);
}
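
// Worked example (sketch): for a length-5 dim, fftshift rolls by 5/2 = 2,
// mapping frequencies [0, 1, 2, -2, -1] to [-2, -1, 0, 1, 2]; ifftshift
// rolls by (5+1)/2 = 3 and undoes it. For even lengths the two shifts
// coincide.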


// We call the following methods via CUDA hooks because they are really only
// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
int64_t _cufft_get_plan_cache_max_size(DeviceIndex device_index) {
  return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize(device_index);
}

void _cufft_set_plan_cache_max_size(DeviceIndex device_index, int64_t max_size) {
  detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(device_index, max_size);
}

int64_t _cufft_get_plan_cache_size(DeviceIndex device_index) {
  return detail::getCUDAHooks().cuFFTGetPlanCacheSize(device_index);
}

void _cufft_clear_plan_cache(DeviceIndex device_index) {
  detail::getCUDAHooks().cuFFTClearPlanCache(device_index);
}

template <typename Stream, typename T>
static Stream& write_opt(Stream& SS, const std::optional<T>& value) {
  if (value) {
    SS << *value;
  } else {
    SS << "None";
  }
  return SS;
}

/* Short-time Fourier Transform, for signal analysis.
 *
 * This is modeled after librosa but with support for complex time-domain
 * signals and complex windows.
 */
Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t> hop_lengthOpt,
            const std::optional<int64_t> win_lengthOpt, const std::optional<Tensor>& window_opt,
            const bool center, c10::string_view mode, const bool normalized,
            const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt);
  const Tensor& window = *window_maybe_owned;

  // Warn if window is not provided
  if (!window.defined()) {
    TORCH_WARN_ONCE(
        "A window was not provided. A rectangular window will be applied, "
        "which is known to cause spectral leakage. "
        "Other windows such as torch.hann_window or torch.hamming_window "
        "are recommended to reduce spectral leakage. "
        "To suppress this warning and use a rectangular window, explicitly set "
        "`window=torch.ones(n_fft, device=<device>)`.");
  }

  #define REPR(SS) \
    SS << "stft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \
       << ", hop_length=" << hop_length << ", win_length=" << win_length \
       << ", window="; \
    if (window.defined()) { \
      SS << window.toString() << "{" << window.sizes() << "}"; \
    } else { \
      SS << "None"; \
    } \
    SS << ", normalized=" << normalized << ", onesided="; \
    write_opt(SS, onesidedOpt) << ", return_complex="; \
    write_opt(SS, return_complexOpt) << ") "

  TORCH_CHECK(!window.defined() || window.device() == self.device(),
              "stft input and window must be on the same device but got self on ",
              self.device(), " and window on ", window.device());

  // default-initialize hop_length and win_length
  auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
  auto win_length = win_lengthOpt.value_or(n_fft);
  const bool return_complex = return_complexOpt.value_or(
      self.is_complex() || (window.defined() && window.is_complex()));
  if (!return_complex) {
    TORCH_CHECK(return_complexOpt.has_value(),
        "stft requires the return_complex parameter be given for real inputs, "
        "and will further require that return_complex=True in a future PyTorch release.");

    TORCH_WARN_ONCE(
        "stft with return_complex=False is deprecated. In a future PyTorch "
        "release, stft will return complex tensors for all inputs, and "
        "return_complex=False will raise an error.\n"
        "Note: you can still call torch.view_as_real on the complex output to "
        "recover the old return format.");
  }

  if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
    std::ostringstream ss;
    REPR(ss) << ": expected a tensor of floating point or complex values";
    AT_ERROR(ss.str());
  }
  if (self.dim() > 2 || self.dim() < 1) {
    std::ostringstream ss;
    REPR(ss) << ": expected a 1D or 2D tensor";
    AT_ERROR(ss.str());
  }
  Tensor input = self;
  if (self.dim() == 1) {
    input = input.unsqueeze(0);
  }

  if (center) {
    const auto input_shape = input.sizes();
    const auto input_dim = input_shape.size();
    const auto extra_dims = std::max(size_t{3}, input_dim) - input_dim;
    const auto pad_amount = n_fft / 2;

    DimVector extended_shape(extra_dims, 1);
    extended_shape.append(input_shape.begin(), input_shape.end());
    input = at::pad(input.view(extended_shape), {pad_amount, pad_amount}, mode);
    input = input.view(IntArrayRef(input.sizes()).slice(extra_dims));
  }

  int64_t batch = input.size(0);
  int64_t len = input.size(1);
  if (n_fft <= 0 || n_fft > len) {
    std::ostringstream ss;
    REPR(ss) << ": expected 0 < n_fft <= " << len
             << ", but got n_fft=" << n_fft;
    AT_ERROR(ss.str());
  }
  if (hop_length <= 0) {
    std::ostringstream ss;
    REPR(ss) << ": expected hop_length > 0, but got hop_length=" << hop_length;
    AT_ERROR(ss.str());
  }
  if (win_length <= 0 || win_length > n_fft) {
    std::ostringstream ss;
    REPR(ss) << ": expected 0 < win_length <= n_fft, but got win_length="
             << win_length;
    AT_ERROR(ss.str());
  }
  if (window.defined() && (window.dim() != 1 || window.size(0) != win_length)) {
    std::ostringstream ss;
    REPR(ss) << ": expected a 1D window tensor of size equal to win_length="
             << win_length << ", but got window with size " << window.sizes();
    AT_ERROR(ss.str());
  }
  #undef REPR
  auto window_ = window;
  if (win_length < n_fft) {
    // pad center
    auto left = (n_fft - win_length) / 2;
    if (window.defined()) {
      window_ = at::zeros({n_fft}, window.options());
      window_.narrow(0, left, win_length).copy_(window);
    } else {
      window_ = at::zeros({n_fft}, self.options());
      window_.narrow(0, left, win_length).fill_(1);
    }
  }
  int64_t n_frames = 1 + (len - n_fft) / hop_length;
  // time2col
  input = input.as_strided(
    {batch, n_frames, n_fft},
    {input.stride(0), hop_length * input.stride(1), input.stride(1)}
  );
  if (window_.defined()) {
    input = input.mul(window_);
  }

  // FFT and transpose to get (batch x fft_size x num_frames)
  const bool complex_fft = input.is_complex();
  const auto onesided = onesidedOpt.value_or(!complex_fft);

  const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none;
  Tensor out;
  if (complex_fft) {
    TORCH_CHECK(!onesided, "Cannot have onesided output if window or input is complex");
    out = at::_fft_c2c(input, input.dim() - 1, static_cast<int64_t>(norm), /*forward=*/true);
  } else {
    out = at::_fft_r2c(input, input.dim() - 1, static_cast<int64_t>(norm), onesided);
  }
  out.transpose_(1, 2);

  if (self.dim() == 1) {
    out.squeeze_(0);
  }

  if (return_complex) {
    return out;
  } else {
    return at::view_as_real(out);
  }
}
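
// Shape summary (sketch, ignoring the center padding): for a 1D real signal
// of length len, the frames are strided as (n_frames, n_fft) with
// n_frames = 1 + (len - n_fft) / hop_length, so the one-sided output holds
// (n_fft/2 + 1, n_frames) complex values. E.g. len=16, n_fft=8, hop_length=2
// gives 5 frames and a (5, 5) spectrogram.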

Tensor stft(
    const Tensor& self, const int64_t n_fft, const std::optional<int64_t> hop_lengthOpt,
    const std::optional<int64_t> win_lengthOpt, const std::optional<Tensor>& window_opt,
    const bool normalized,
    const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt) {
  return at::stft(
      self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
      /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
      return_complexOpt);
}

// Create complex tensor from the old style of real tensor with size=(..., 2)
// This is to support istft in the transition to requiring complex input.
// NOTE: This may return a view of the input tensor, or might clone if necessary
static Tensor as_complex(const Tensor& self) {
  const bool can_view_as_complex = [&]{
    auto strides = self.strides();
    for (const auto i : c10::irange(static_cast<int64_t>(strides.size()) - 1)) {
      if (strides[i] % 2 != 0) {
        return false;
      }
    }
    return strides.back() == 1 && self.storage_offset() % 2 == 0;
  }();
  return at::view_as_complex(can_view_as_complex ? self : self.clone(MemoryFormat::Contiguous));
}
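
// view_as_complex requires the last dim to have size 2 and stride 1, and all
// other strides (plus the storage offset) to be even, since one complex
// element spans two real slots. E.g. a contiguous float tensor of shape
// (4, 3, 2) views directly as a (4, 3) complex tensor; an input failing
// these checks (say, an odd storage offset from slicing) is cloned first.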

/* Inverse Short-time Fourier Transform
 *
 * This is modeled after librosa but with support for complex time-domain
 * signals and complex windows.
 */
Tensor istft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t> hop_lengthOpt,
             const std::optional<int64_t> win_lengthOpt, const std::optional<Tensor>& window_opt,
             const bool center, const bool normalized, const std::optional<bool> onesidedOpt,
             const std::optional<int64_t> lengthOpt, const bool return_complex) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt);
  const Tensor& window = *window_maybe_owned;

  // Warn if window is not provided
  if (!window.defined()) {
    TORCH_WARN_ONCE(
        "A window was not provided. A rectangular window will be applied. "
        "Please provide the same window used by stft to make the inversion "
        "lossless. "
        "To suppress this warning and use a rectangular window, explicitly set "
        "`window=torch.ones(n_fft, device=<device>)`.");
  }

  #define REPR(SS) \
    SS << "istft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \
       << ", hop_length=" << hop_length << ", win_length=" << win_length \
       << ", window="; \
    if (window.defined()) { \
      SS << window.toString() << "{" << window.sizes() << "}"; \
    } else { \
      SS << "None"; \
    } \
    SS << ", center=" << center << ", normalized=" << normalized << ", onesided="; \
    write_opt(SS, onesidedOpt) << ", length="; \
    write_opt(SS, lengthOpt) << ", return_complex=" << return_complex << ") "

  TORCH_CHECK(!window.defined() || window.device() == self.device(),
              "istft input and window must be on the same device but got self on ",
              self.device(), " and window on ", window.device());

  // default-initialize hop_length and win_length
1049   const auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
1050   const auto win_length = win_lengthOpt.value_or(n_fft);
1051 
1052   TORCH_CHECK(self.is_complex(),
1053               "istft requires a complex-valued input tensor matching the "
1054               "output from stft with return_complex=True.");
1055   Tensor input = at::view_as_real(self.resolve_conj());
1056   const auto input_dim = input.dim();
1057   const auto n_frames = input.size(-2);
1058   const auto fft_size = input.size(-3);
1059 
1060   const auto expected_output_signal_len = n_fft + hop_length * (n_frames - 1);
1061 
1062   const auto options = at::device(input.device()).dtype(input.dtype());
1063   if (input.numel() == 0) {
1064     std::ostringstream ss;
1065     REPR(ss) << ": input tensor cannot be empty.";
1066     AT_ERROR(ss.str());
1067   }
1068   if (input_dim != 3 && input_dim != 4) {
1069     std::ostringstream ss;
1070     REPR(ss) << ": expected a tensor with 3 or 4 dimensions, but got " << input_dim;
1071     AT_ERROR(ss.str());
1072   }
1073   if (input.size(-1) != 2) {
1074     std::ostringstream ss;
1075     REPR(ss) << ": expected the last dimension to be 2 (corresponding to real and imaginary parts), but got " << self.size(-1);
1076     AT_ERROR(ss.str());
1077   }
1078 
1079   const bool onesided = onesidedOpt.value_or(fft_size != n_fft);
1080   if (onesided) {
1081     if (n_fft / 2 + 1 != fft_size) {
1082       std::ostringstream ss;
1083       REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onesided=True, but got " << fft_size;
1084       AT_ERROR(ss.str());
1085     }
1086   } else {
1087     if (n_fft != fft_size) {
1088       std::ostringstream ss;
1089       REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onesided=False, but got " << fft_size;
1090       AT_ERROR(ss.str());
1091     }
1092   }
1093 
1094   if (!(0 < hop_length && hop_length <= win_length)) {
1095     std::ostringstream ss;
1096     REPR(ss) << ": expected 0 < hop_length <= win_length";
1097     AT_ERROR(ss.str());
1098   }
1099 
1100   if (!(0 < win_length && win_length <= n_fft)) {
1101     std::ostringstream ss;
1102     REPR(ss) << ": expected 0 < win_length <= n_fft";
1103     AT_ERROR(ss.str());
1104   }
1105   if (window.defined()) {
1106     if (window.dim() != 1 || window.size(0) != win_length) {
1107       std::ostringstream ss;
1108       REPR(ss) << ": Invalid window shape. window has to be 1D and length of `win_length`";
1109       AT_ERROR(ss.str());
1110     }
1111   }
1112 
1113   Tensor window_tmp = window.defined() ? window : at::ones({win_length,}, options);
1114   if (win_length != n_fft) {
1115     // center window by padding zeros on right and left side
1116     int64_t left = (n_fft - win_length) / 2;
1117     window_tmp = at::constant_pad_nd(window_tmp, {left, n_fft - win_length - left}, 0);
1118     TORCH_INTERNAL_ASSERT(window_tmp.size(0) == n_fft);
1119   }
1120 
1121   if (input_dim == 3) {
1122     input = input.unsqueeze(0);
1123   }
1124 
1125   input = as_complex(input.transpose(1, 2));  // size: (channel, n_frames, fft_size)
1126 
  const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n;
  if (return_complex) {
    TORCH_CHECK(!onesided, "Cannot have onesided output if window or input is complex");
    input = at::_fft_c2c(input, input.dim() - 1, static_cast<int64_t>(norm), /*forward=*/false);  // size: (channel, n_frames, n_fft)
  } else {
    TORCH_CHECK(!window.defined() || !window.is_complex(),
                "Complex windows are incompatible with return_complex=False");
    if (!onesided) {
      // Keep only the non-negative frequencies; _fft_c2r expects a onesided
      // (Hermitian half) input.
      input = input.slice(-1, 0, n_fft / 2 + 1);
    }
    input = at::_fft_c2r(input, input.dim() - 1, static_cast<int64_t>(norm), n_fft);  // size: (channel, n_frames, n_fft)
  }
  TORCH_INTERNAL_ASSERT(input.size(2) == n_fft);

  Tensor y_tmp = input * window_tmp.view({1, 1, n_fft});  // size: (channel, n_frames, n_fft)

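  // unfold_backward is the adjoint of Tensor::unfold: it scatters each
  // length-n_fft frame to its offset (stride hop_length) in the output and
  // sums wherever frames overlap, i.e. it performs the overlap-add step.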
  Tensor y = at::unfold_backward(
    y_tmp,
    /*input_sizes=*/{y_tmp.size(0), expected_output_signal_len},
    /*dim=*/1,
    /*size=*/n_fft,
    /*step=*/hop_length);
  window_tmp = window_tmp.pow(2).expand({1, n_frames, n_fft});  // size: (1, n_frames, n_fft)
  Tensor window_envelop = at::unfold_backward(
    window_tmp,
    /*input_sizes=*/{1, expected_output_signal_len},
    /*dim=*/1,
    /*size=*/n_fft,
    /*step=*/hop_length);  // size: (1, expected_output_signal_len)

  TORCH_INTERNAL_ASSERT(expected_output_signal_len == y.size(1));
  TORCH_INTERNAL_ASSERT(expected_output_signal_len == window_envelop.size(1));

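  // With center=True, stft pads n_fft / 2 samples on each side of the signal
  // (reflection padding by default in the Python API), so the same amount is
  // trimmed from each end here; an explicit lengthOpt overrides the end point.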
  // We need to trim the front padding away if centered
  const auto start = center ? n_fft / 2 : 0;
  const auto end = [&] () -> int64_t {
    if (lengthOpt.has_value()) {
      return start + *lengthOpt;
    }
    if (center) {
      return -(n_fft / 2);
    }
    return expected_output_signal_len;
  }();

  y = y.slice(1, start, end, 1);
  window_envelop = window_envelop.slice(1, start, end, 1);
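  // The squared-window envelope must stay bounded away from zero over the
  // whole output range (essentially the nonzero overlap-add, or NOLA,
  // condition); otherwise the division below is ill-conditioned and the
  // stft is not invertible with this window and hop_length.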
  const auto window_envelop_lowest = window_envelop.abs().min();
  if (at::is_scalar_tensor_true(window_envelop_lowest.lt(1e-11))) {
    std::ostringstream ss;
    REPR(ss) << "window overlap add min: " << window_envelop_lowest;
    AT_ERROR(ss.str());
  }

  y = (y / window_envelop);  // size: (channel, expected_output_signal_len)
  if (input_dim == 3) {
    y = y.squeeze(0);
  }
  // zero padding if the given lengthOpt is longer than expected
  if (end > expected_output_signal_len) {
    TORCH_WARN_ONCE(
      "The length of signal is shorter than the length parameter. Result is being padded with zeros in the tail. "
      "Please check your center and hop_length settings."
    );
    y = at::constant_pad_nd(y, {0, end - expected_output_signal_len}, 0);
  }
  return y;

#undef REPR
}

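// Writes the redundant negative-frequency entries of a full spectrum in-place
// from its Hermitian half: for a real signal, X[n - i] = conj(X[i]).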
void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
  const auto input_sizes = input.sizes();
  const auto input_strides = input.strides();
  TORCH_CHECK(!dim_.empty());
  DimVector dim(dim_.begin(), dim_.end());
  at::maybe_wrap_dims(dim, input_strides.size(), /*wrap_scalars=*/false);

  if (input.numel() == 0 || input_sizes[dim.back()] <= 2) {
    // With <= 2 entries in the last transformed dim, only the DC (and
    // possibly Nyquist) bins exist, and both are their own conjugates.
    return;  // No elements need writing
  }

  // Small dimensions may be treated as batch dims since they don't get mirrored
  dim.erase(
      std::remove_if(dim.begin(), dim.end(), [&](int64_t dim) {
        return (input_sizes[dim] <= 2);
      }),
      dim.end());

  // Use TensorIterator to coalesce batch dimensions
  // NOTE: Can't use TensorIterator loops because we need negative strides.
  // declare_static_shape squashes out the transformed dims, so the iterator
  // only sees (and coalesces) the batch dims.
  auto iter = TensorIteratorConfig()
      .add_output(input)
      .add_input(input)
      .resize_outputs(false)
      .declare_static_shape(input_sizes, dim)
      .build();

  const auto iter_strides = iter.strides(0);
  const auto iter_sizes = iter.shape();
  const auto ndim = static_cast<int64_t>(iter_strides.size() + dim.size());
  DimVector in_strides(ndim), signal_half_sizes(ndim);
  // Take coalesced batch dimensions from TensorIterator
  std::copy(iter_strides.begin(), iter_strides.end(), in_strides.begin());
  std::copy(iter_sizes.begin(), iter_sizes.end(), signal_half_sizes.begin());

  // Take transformed dimensions directly from the input
  const auto element_size = iter.element_size(0);
  for (const auto i : c10::irange(dim.size())) {
    // Convert to byte strides to match TensorIterator
    in_strides[iter_strides.size() + i] = input_strides[dim[i]] * element_size;
    signal_half_sizes[iter_strides.size() + i] = input_sizes[dim[i]];
  }

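  // Hermitian mirroring writes conj(X[i]) into X[n - i] for i = 1..(n-1)/2.
  // E.g. for n = 8 (hypothetical), bins 1..3 are conjugated into bins 7..5;
  // bin 0 (DC) and bin 4 (Nyquist) are self-conjugate and left untouched.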
  // For the last dimension, use negative strides to perform the mirroring
  signal_half_sizes.back() = (input_sizes[dim.back()] - 1) / 2;
  auto out_strides = in_strides;
  out_strides.back() *= -1;

  auto* data_ptr = static_cast<char*>(input.data_ptr());
  const auto* in_data = data_ptr + input_strides[dim.back()] * element_size;
  auto* out_data = data_ptr + (
      input_strides[dim.back()] * (input_sizes[dim.back()] - 1) * element_size);

  // Reorder dimensions by stride to maximize data locality
  DimVector dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), 0);
  std::sort(dim_permute.begin(), dim_permute.end(),
      [&](auto dim1, auto dim2) {
        return in_strides[dim1] < in_strides[dim2];
      });

  DimVector temp(ndim);
  auto apply_permutation = [&] (DimVector& vec) {
    // Do permuted index copy into a temporary, then copy back
    for (const auto i : c10::irange(ndim)) {
      temp[i] = vec[dim_permute[i]];
    }
    vec = temp;
  };
  apply_permutation(in_strides);
  apply_permutation(out_strides);
  apply_permutation(signal_half_sizes);

  // Find the transformed dims, other than the last, in the new permuted order.
  // These are the dimensions that need explicit Hermitian mirroring
  DimVector mirror_dims;
  mirror_dims.reserve(dim.size() - 1);
  for (const auto i : c10::irange(ndim)) {
    if (dim_permute[i] >= static_cast<int64_t>(iter_strides.size()) &&  // Not a batch dimension
        dim_permute[i] != ndim - 1) {  // Not the last dim, which is mirrored separately with negative strides
      mirror_dims.push_back(i);
    }
  }
  TORCH_INTERNAL_ASSERT(mirror_dims.size() == dim.size() - 1);

  // Dispatch to CPU or CUDA kernel to do the actual conjugate mirroring
  fft_fill_with_conjugate_symmetry_stub(
      input.device().type(), input.scalar_type(),
      mirror_dims, signal_half_sizes, in_strides, in_data, out_strides, out_data);
}

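// Define the dispatch stub (declared via DECLARE_DISPATCH in
// SpectralOpsUtils.h); the CPU and CUDA kernels register themselves against
// it with REGISTER_DISPATCH.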
DEFINE_DISPATCH(fft_fill_with_conjugate_symmetry_stub);

} // namespace at::native