1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/CUDAContext.h>
3 #include <ATen/Config.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/cuda/detail/KernelUtils.h>
6 #include <ATen/cuda/detail/OffsetCalculator.cuh>
7 #include <ATen/detail/CUDAHooksInterface.h>
8 #include <ATen/native/SpectralOpsUtils.h>
9
10 #include <cmath>
11 #include <vector>
12
13
14 namespace at::native {
15
// Offset calculator for indexing in Hermitian mirrored order.
// For every dim selected in `dim`, a linear index i along that dim is
// remapped to (n - i) % n; all other dims are indexed normally.
template <typename index_t>
struct HermitianSymmetryOffsetCalculator {
  using offset_type = at::detail::Array<index_t, 1>;
  using dim_type = std::remove_cv_t<decltype(MAX_DIMS)>;
  dim_type dims;
  at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];
  index_t strides_[MAX_DIMS];
  uint32_t mirror_dim_;  // bit i set <=> dim i is traversed in mirrored order
  static_assert(MAX_DIMS < 32, "Need a bigger mask type");

  HermitianSymmetryOffsetCalculator(
      IntArrayRef sizes, IntArrayRef strides, IntArrayRef dim,
      const int64_t element_size) {
    TORCH_INTERNAL_ASSERT(sizes.size() == strides.size());
    TORCH_INTERNAL_ASSERT(sizes.size() <= MAX_DIMS);
    dims = sizes.size();

    using at::cuda::detail::IntDivider;
    // Fill the real dims, converting byte strides to element strides.
    dim_type d = 0;
    for (; d < dims; ++d) {
      sizes_[d] = IntDivider<index_t>(sizes[d]);
      strides_[d] = strides[d] / element_size;
    }
    // Pad the unused slots so get() could safely iterate past `dims`.
    for (; d < MAX_DIMS; ++d) {
      sizes_[d] = IntDivider<index_t>(1);
      strides_[d] = 0;
    }

    // Build the bit mask of mirrored dims.
    mirror_dim_ = 0;
    for (const auto i : c10::irange(dim.size())) {
      mirror_dim_ |= (uint32_t{1} << dim[i]);
    }
  }

  // Decompose a flat linear index into per-dim coordinates and accumulate
  // the (possibly mirrored) element offset.
  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
    index_t offset = 0;

    for (dim_type d = 0; d < dims; ++d) {
      const auto divmod = sizes_[d].divmod(linear_idx);
      linear_idx = divmod.div;

      if ((mirror_dim_ & (uint32_t{1} << d)) == 0) {
        // Normal dim: use the coordinate directly.
        offset += divmod.mod * strides_[d];
      } else if (divmod.mod != 0) {
        // Mirrored dim: (n - i) % n; i != 0 here, so no extra modulo needed.
        // (i == 0 maps to itself and contributes no offset.)
        offset += (sizes_[d].divisor - divmod.mod) * strides_[d];
      }
    }
    offset_type offsets;
    offsets[0] = offset;
    return offsets;
  }
};
70
71
// out[:] = conj(in[:]), where the input and output index orderings are
// generalized by the given offset calculators. Launched over `numel`
// flat indices; CUDA_KERNEL_LOOP_TYPE handles any grid size.
template <typename scalar_t, typename inp_calc_t, typename out_calc_t>
C10_LAUNCH_BOUNDS_1(cuda::detail::CUDA_NUM_THREADS)
__global__ void _fft_conjugate_copy_kernel(
    int64_t numel, scalar_t * out_data, const scalar_t * in_data,
    inp_calc_t ic, out_calc_t oc) {
  CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) {
    const auto src_offset = ic.get(i)[0];
    const auto dst_offset = oc.get(i)[0];
    out_data[dst_offset] = std::conj(in_data[src_offset]);
  }
}
84
// In real-to-complex transform, cuFFT only fills half of the values due to
// conjugate symmetry. See native/SpectralUtils.h for more details.
// The following function fills in the other half with symmetry in
// case of real-to-complex transform with onesided=False flag.
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.

// `in_data`/`out_data` describe a tensor of full (twosided) signal size, of
// which only the onesided half holds valid values. Modifies `out_data`
// in place by mirror-conjugating the onesided half into the other half.
void _fft_fill_with_conjugate_symmetry_cuda_(
    ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides, const void * in_data,
    IntArrayRef out_strides, void * out_data) {
  // Do the actual conjugate mirroring.
  // TODO: consider adding a 32bit indexed kernel for improved performance
  const int ndim = in_strides.size();
  const int64_t element_size = scalarTypeToTypeMeta(dtype).itemsize();

  // Reads walk the input in plain (unmirrored) order...
  auto* in_strides_ptr = in_strides.data();
  OffsetCalculator<1, int64_t> input_offset_calculator(
      ndim, signal_half_sizes.data(), &in_strides_ptr, &element_size);
  // ...while writes traverse `mirror_dims` in reversed Hermitian order.
  HermitianSymmetryOffsetCalculator<int64_t> output_offset_calculator(
      signal_half_sizes, out_strides, mirror_dims, element_size);

  const auto numel = c10::multiply_integers(signal_half_sizes);
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "_fft_fill_with_conjugate_symmetry", [&] {
    using namespace cuda::detail;
    _fft_conjugate_copy_kernel<<<
        GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
        numel,
        static_cast<scalar_t*>(out_data),
        static_cast<const scalar_t*>(in_data),
        input_offset_calculator,
        output_offset_calculator);
    // Surface launch-configuration errors immediately.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

REGISTER_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cuda_);
123
124 } // at::native
125