#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/SpectralOpsUtils.h>

#include <cmath>
#include <vector>


namespace at::native {

// Offset calculator for indexing in Hermitian mirrored order.
// In mirrored dims, maps linear index i to (n - i) % n
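// e.g. for n = 5 the mapping is 0 -> 0, 1 -> 4, 2 -> 3, 3 -> 2, 4 -> 1.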
template <typename index_t>
struct HermitianSymmetryOffsetCalculator {
  using offset_type = at::detail::Array<index_t, 1>;
  using dim_type = std::remove_cv_t<decltype(MAX_DIMS)>;
  dim_type dims;
  at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];
  index_t strides_[MAX_DIMS];
  uint32_t mirror_dim_;  // bit mask
  static_assert(MAX_DIMS < 32, "Need a bigger mask type");

  HermitianSymmetryOffsetCalculator(
      IntArrayRef sizes, IntArrayRef strides, IntArrayRef dim,
      const int64_t element_size){
    TORCH_INTERNAL_ASSERT(sizes.size() == strides.size());
    TORCH_INTERNAL_ASSERT(sizes.size() <= MAX_DIMS);
    dims = sizes.size();

    using at::cuda::detail::IntDivider;
    for (dim_type i = 0; i < MAX_DIMS; ++i) {
      if (i < dims) {
        sizes_[i] = IntDivider<index_t>(sizes[i]);
        strides_[i] = strides[i] / element_size;
      } else {
        sizes_[i] = IntDivider<index_t>(1);
        strides_[i] = 0;
      }
    }

    mirror_dim_ = 0;
    for (const auto i: c10::irange(dim.size())) {
      mirror_dim_ |= (uint32_t{1} << dim[i]);
    }
  }

  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
    index_t offset = 0;

    for (dim_type dim = 0; dim < dims; ++dim) {
      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;

      if ((mirror_dim_ & (uint32_t{1} << dim)) == 0) {
        offset += divmod.mod * strides_[dim];
      } else if (divmod.mod != 0) {
        // Mirrored dim: i maps to (n - i); i == 0 falls through and
        // contributes 0, since (n - 0) % n == 0.
        offset += (sizes_[dim].divisor - divmod.mod) * strides_[dim];
      }
    }
    offset_type offsets;
    offsets[0] = offset;
    return offsets;
  }
};
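
// Illustrative sketch (hypothetical values, not part of the original file):
// for a complex<double> tensor of size {8} with a contiguous byte stride of
// {16} and dim 0 mirrored, the calculator divides strides by element_size
// and mirrors indices:
//
//   HermitianSymmetryOffsetCalculator<int64_t> calc(
//       /*sizes=*/{8}, /*strides=*/{16}, /*dim=*/{0}, /*element_size=*/16);
//   // calc.get(0)[0] == 0   (index 0 is its own mirror)
//   // calc.get(3)[0] == 5   ((8 - 3) % 8 == 5, element stride 16/16 == 1)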

// out[:] = conj(in[:]) where in and out ordering is generalized by offset calculators
template <typename scalar_t, typename inp_calc_t, typename out_calc_t>
C10_LAUNCH_BOUNDS_1(cuda::detail::CUDA_NUM_THREADS)
__global__ void _fft_conjugate_copy_kernel(
    int64_t numel, scalar_t * out_data, const scalar_t * in_data,
    inp_calc_t ic, out_calc_t oc) {
  CUDA_KERNEL_LOOP_TYPE(index, numel, int64_t) {
    auto in_offset = ic.get(index)[0];
    auto out_offset = oc.get(index)[0];
    out_data[out_offset] = std::conj(in_data[in_offset]);
  }
}
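
// Ignoring the calculators (i.e. with identity input and output mappings,
// shown only for intuition), the kernel reduces to:
//
//   for (int64_t i = 0; i < numel; ++i)
//     out_data[i] = std::conj(in_data[i]);
//
// The calculators generalize this loop to strided layouts and to the
// Hermitian-mirrored output order computed above.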

// In a real-to-complex transform, cuFFT only fills half of the values due to
// conjugate symmetry. See native/SpectralOpsUtils.h for more details.
// The following function fills in the other half with symmetry in
// the case of a real-to-complex transform with the onesided=False flag.
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.

// input should be a tensor of the same size as the full (twosided)
// signal, but only containing half (onesided) of the values.
// This function modifies the tensor in-place.
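// Worked example (illustrative): for a 1-D signal of even length n = 6, the
// onesided transform stores bins 0..3; conjugate symmetry X[k] = conj(X[n - k])
// fills the remaining bins as out[4] = conj(out[2]) and out[5] = conj(out[1]).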
void _fft_fill_with_conjugate_symmetry_cuda_(
    ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides, const void * in_data,
    IntArrayRef out_strides, void * out_data) {
  // Do the actual conjugate mirroring.
  // TODO: consider adding a 32bit indexed kernel for improved performance
  auto* in_strides_ptr = in_strides.data();
  const int ndim = in_strides.size();
  const int64_t element_size = scalarTypeToTypeMeta(dtype).itemsize();
  OffsetCalculator<1, int64_t> input_offset_calculator(
      ndim, signal_half_sizes.data(), &in_strides_ptr, &element_size);
  HermitianSymmetryOffsetCalculator<int64_t> output_offset_calculator(
      signal_half_sizes, out_strides, mirror_dims, element_size);

  const auto numel = c10::multiply_integers(signal_half_sizes);
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "_fft_fill_with_conjugate_symmetry", [&] {
      using namespace cuda::detail;
      _fft_conjugate_copy_kernel<<<
        GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            numel,
            static_cast<scalar_t*>(out_data),
            static_cast<const scalar_t*>(in_data),
            input_offset_calculator,
            output_offset_calculator);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_);
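
// A minimal sketch of a call path that exercises this kernel (assuming the
// standard ATen FFT entry points; the stub is invoked internally by the C++
// FFT implementation when a full, twosided r2c output is requested):
//
//   at::Tensor x = at::randn({6}, at::kCUDA);
//   at::Tensor X = at::fft_fft(x);  // real input: r2c + conjugate-symmetry fill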

} // at::native