xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/fft.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Exposes the family of FFT routines as pre-canned high performance calls for
// use in conjunction with the StreamExecutor abstraction.
//
// Note that this interface is optionally supported by platforms; see
// StreamExecutor::SupportsFft() for details.
//
// This abstraction makes it simple to entrain FFT operations on GPU data into
// a Stream -- users typically will not use this API directly, but will use the
// Stream builder methods to entrain these operations "under the hood". For
// example:
//
//  DeviceMemory<std::complex<float>> x =
//    stream_exec->AllocateArray<std::complex<float>>(1024);
//  DeviceMemory<std::complex<float>> y =
//    stream_exec->AllocateArray<std::complex<float>>(1024);
//  // ... populate x and y ...
//  Stream stream{stream_exec};
//  std::unique_ptr<Plan> plan =
//     stream_exec->AsFft()->Create1dPlan(&stream, 1024, Type::kC2CForward,
//                                        /*in_place_fft=*/false);
//  stream
//    .Init()
//    .ThenFft(plan.get(), x, &y);
//  SE_CHECK_OK(stream.BlockHostUntilDone());
//
// By using stream operations in this manner the user can easily intermix custom
// kernel launches (via Stream::ThenLaunch()) with these pre-canned FFT
// routines.
43 
44 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_FFT_H_
45 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_FFT_H_
46 
47 #include <complex>
48 #include <memory>
49 
50 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
51 
52 namespace stream_executor {
53 
// Forward declarations of StreamExecutor types. This header only refers to
// them through pointers and references, so their definitions are not needed.
class ScratchAllocator;
class Stream;
template <typename T>
class DeviceMemory;
58 
59 namespace fft {
60 
61 // Specifies FFT input and output types, and the direction.
62 // R, D, C, and Z stand for SP real, DP real, SP complex, and DP complex.
63 enum class Type {
64   kInvalid,
65   kC2CForward,
66   kC2CInverse,
67   kC2R,
68   kR2C,
69   kZ2ZForward,
70   kZ2ZInverse,
71   kZ2D,
72   kD2Z
73 };
74 
// Common base type for FFT plans.
//
// Each FFT implementation derives its own plan class from this one. The base
// exposes no operations; it only provides a shared type so that plans from any
// backend can be handed back to callers (as std::unique_ptr<Plan>) and later
// passed in to execute the transform.
class Plan {
 public:
  // Virtual so that destroying a backend-specific plan through a Plan pointer
  // (e.g. when its std::unique_ptr<Plan> goes out of scope) runs the derived
  // class's destructor.
  virtual ~Plan() = default;
};
82 
83 // FFT support interface -- this can be derived from a GPU executor when the
84 // underlying platform has an FFT library implementation available. See
85 // StreamExecutor::AsFft().
86 //
87 // This support interface is not generally thread-safe; it is only thread-safe
88 // for the CUDA platform (cuFFT) usage; host side FFT support is known
89 // thread-compatible, but not thread-safe.
90 class FftSupport {
91  public:
~FftSupport()92   virtual ~FftSupport() {}
93 
94   // Creates a 1d FFT plan.
95   virtual std::unique_ptr<Plan> Create1dPlan(Stream *stream, uint64_t num_x,
96                                              Type type, bool in_place_fft) = 0;
97 
98   // Creates a 2d FFT plan.
99   virtual std::unique_ptr<Plan> Create2dPlan(Stream *stream, uint64_t num_x,
100                                              uint64_t num_y, Type type,
101                                              bool in_place_fft) = 0;
102 
103   // Creates a 3d FFT plan.
104   virtual std::unique_ptr<Plan> Create3dPlan(Stream *stream, uint64_t num_x,
105                                              uint64_t num_y, uint64 num_z,
106                                              Type type, bool in_place_fft) = 0;
107 
108   // Creates a 1d FFT plan with scratch allocator.
109   virtual std::unique_ptr<Plan> Create1dPlanWithScratchAllocator(
110       Stream *stream, uint64_t num_x, Type type, bool in_place_fft,
111       ScratchAllocator *scratch_allocator) = 0;
112 
113   // Creates a 2d FFT plan with scratch allocator.
114   virtual std::unique_ptr<Plan> Create2dPlanWithScratchAllocator(
115       Stream *stream, uint64_t num_x, uint64 num_y, Type type,
116       bool in_place_fft, ScratchAllocator *scratch_allocator) = 0;
117 
118   // Creates a 3d FFT plan with scratch allocator.
119   virtual std::unique_ptr<Plan> Create3dPlanWithScratchAllocator(
120       Stream *stream, uint64_t num_x, uint64 num_y, uint64 num_z, Type type,
121       bool in_place_fft, ScratchAllocator *scratch_allocator) = 0;
122 
123   // Creates a batched FFT plan.
124   //
125   // stream:          The GPU stream in which the FFT runs.
126   // rank:            Dimensionality of the transform (1, 2, or 3).
127   // elem_count:      Array of size rank, describing the size of each dimension.
128   // input_embed, output_embed:
129   //                  Pointer of size rank that indicates the storage dimensions
130   //                  of the input/output data in memory. If set to null_ptr all
131   //                  other advanced data layout parameters are ignored.
132   // input_stride:    Indicates the distance (number of elements; same below)
133   //                  between two successive input elements.
134   // input_distance:  Indicates the distance between the first element of two
135   //                  consecutive signals in a batch of the input data.
136   // output_stride:   Indicates the distance between two successive output
137   //                  elements.
138   // output_distance: Indicates the distance between the first element of two
139   //                  consecutive signals in a batch of the output data.
140   virtual std::unique_ptr<Plan> CreateBatchedPlan(
141       Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed,
142       uint64_t input_stride, uint64 input_distance, uint64 *output_embed,
143       uint64_t output_stride, uint64 output_distance, Type type,
144       bool in_place_fft, int batch_count) = 0;
145 
146   // Creates a batched FFT plan with scratch allocator.
147   //
148   // stream:          The GPU stream in which the FFT runs.
149   // rank:            Dimensionality of the transform (1, 2, or 3).
150   // elem_count:      Array of size rank, describing the size of each dimension.
151   // input_embed, output_embed:
152   //                  Pointer of size rank that indicates the storage dimensions
153   //                  of the input/output data in memory. If set to null_ptr all
154   //                  other advanced data layout parameters are ignored.
155   // input_stride:    Indicates the distance (number of elements; same below)
156   //                  between two successive input elements.
157   // input_distance:  Indicates the distance between the first element of two
158   //                  consecutive signals in a batch of the input data.
159   // output_stride:   Indicates the distance between two successive output
160   //                  elements.
161   // output_distance: Indicates the distance between the first element of two
162   //                  consecutive signals in a batch of the output data.
163   virtual std::unique_ptr<Plan> CreateBatchedPlanWithScratchAllocator(
164       Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed,
165       uint64_t input_stride, uint64 input_distance, uint64 *output_embed,
166       uint64_t output_stride, uint64 output_distance, Type type,
167       bool in_place_fft, int batch_count,
168       ScratchAllocator *scratch_allocator) = 0;
169 
170   // Updates the plan's work area with space allocated by a new scratch
171   // allocator. This facilitates plan reuse with scratch allocators.
172   //
173   // This requires that the plan was originally created using a scratch
174   // allocator, as otherwise scratch space will have been allocated internally
175   // by cuFFT.
176   virtual void UpdatePlanWithScratchAllocator(
177       Stream *stream, Plan *plan, ScratchAllocator *scratch_allocator) = 0;
178 
179   // Computes complex-to-complex FFT in the transform direction as specified
180   // by direction parameter.
181   virtual bool DoFft(Stream *stream, Plan *plan,
182                      const DeviceMemory<std::complex<float>> &input,
183                      DeviceMemory<std::complex<float>> *output) = 0;
184   virtual bool DoFft(Stream *stream, Plan *plan,
185                      const DeviceMemory<std::complex<double>> &input,
186                      DeviceMemory<std::complex<double>> *output) = 0;
187 
188   // Computes real-to-complex FFT in forward direction.
189   virtual bool DoFft(Stream *stream, Plan *plan,
190                      const DeviceMemory<float> &input,
191                      DeviceMemory<std::complex<float>> *output) = 0;
192   virtual bool DoFft(Stream *stream, Plan *plan,
193                      const DeviceMemory<double> &input,
194                      DeviceMemory<std::complex<double>> *output) = 0;
195 
196   // Computes complex-to-real FFT in inverse direction.
197   virtual bool DoFft(Stream *stream, Plan *plan,
198                      const DeviceMemory<std::complex<float>> &input,
199                      DeviceMemory<float> *output) = 0;
200   virtual bool DoFft(Stream *stream, Plan *plan,
201                      const DeviceMemory<std::complex<double>> &input,
202                      DeviceMemory<double> *output) = 0;
203 
204  protected:
FftSupport()205   FftSupport() {}
206 
207  private:
208   SE_DISALLOW_COPY_AND_ASSIGN(FftSupport);
209 };
210 
// Macro used to quickly declare overrides for abstract virtuals in the
// fft::FftSupport base class. Assumes that it's emitted somewhere inside the
// ::stream_executor namespace: it names fft::Plan and fft::Type relative to
// that namespace, and Stream, DeviceMemory, and ScratchAllocator unqualified.
//
// Every declaration here must match the corresponding pure virtual in
// fft::FftSupport exactly. All sizes, strides, and distances use uint64_t.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES                   \
  std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64_t num_x,      \
                                          fft::Type type, bool in_place_fft)   \
      override;                                                                \
  std::unique_ptr<fft::Plan> Create2dPlan(Stream *stream, uint64_t num_x,      \
                                          uint64_t num_y, fft::Type type,      \
                                          bool in_place_fft) override;         \
  std::unique_ptr<fft::Plan> Create3dPlan(                                     \
      Stream *stream, uint64_t num_x, uint64_t num_y, uint64_t num_z,          \
      fft::Type type, bool in_place_fft) override;                             \
  std::unique_ptr<fft::Plan> Create1dPlanWithScratchAllocator(                 \
      Stream *stream, uint64_t num_x, fft::Type type, bool in_place_fft,       \
      ScratchAllocator *scratch_allocator) override;                           \
  std::unique_ptr<fft::Plan> Create2dPlanWithScratchAllocator(                 \
      Stream *stream, uint64_t num_x, uint64_t num_y, fft::Type type,          \
      bool in_place_fft, ScratchAllocator *scratch_allocator) override;        \
  std::unique_ptr<fft::Plan> Create3dPlanWithScratchAllocator(                 \
      Stream *stream, uint64_t num_x, uint64_t num_y, uint64_t num_z,          \
      fft::Type type, bool in_place_fft, ScratchAllocator *scratch_allocator)  \
      override;                                                                \
  std::unique_ptr<fft::Plan> CreateBatchedPlan(                                \
      Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed,   \
      uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed,  \
      uint64_t output_stride, uint64_t output_distance, fft::Type type,        \
      bool in_place_fft, int batch_count) override;                            \
  std::unique_ptr<fft::Plan> CreateBatchedPlanWithScratchAllocator(            \
      Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed,   \
      uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed,  \
      uint64_t output_stride, uint64_t output_distance, fft::Type type,        \
      bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) \
      override;                                                                \
  void UpdatePlanWithScratchAllocator(Stream *stream, fft::Plan *plan,         \
                                      ScratchAllocator *scratch_allocator)     \
      override;                                                                \
  bool DoFft(Stream *stream, fft::Plan *plan,                                  \
             const DeviceMemory<std::complex<float>> &input,                   \
             DeviceMemory<std::complex<float>> *output) override;              \
  bool DoFft(Stream *stream, fft::Plan *plan,                                  \
             const DeviceMemory<std::complex<double>> &input,                  \
             DeviceMemory<std::complex<double>> *output) override;             \
  bool DoFft(Stream *stream, fft::Plan *plan,                                  \
             const DeviceMemory<float> &input,                                 \
             DeviceMemory<std::complex<float>> *output) override;              \
  bool DoFft(Stream *stream, fft::Plan *plan,                                  \
             const DeviceMemory<double> &input,                                \
             DeviceMemory<std::complex<double>> *output) override;             \
  bool DoFft(Stream *stream, fft::Plan *plan,                                  \
             const DeviceMemory<std::complex<float>> &input,                   \
             DeviceMemory<float> *output) override;                            \
  bool DoFft(Stream *stream, fft::Plan *plan,                                  \
             const DeviceMemory<std::complex<double>> &input,                  \
             DeviceMemory<double> *output) override;
266 
267 }  // namespace fft
268 }  // namespace stream_executor
269 
270 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_FFT_H_
271