xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/fft_thunk.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
17 
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
26 #include "tensorflow/stream_executor/scratch_allocator.h"
27 
28 namespace xla {
29 namespace gpu {
30 namespace {
31 
FftTypeToSeType(FftType type,bool double_precision)32 se::fft::Type FftTypeToSeType(FftType type, bool double_precision) {
33   switch (type) {
34     case FftType::FFT:
35       return double_precision ? se::fft::Type::kZ2ZForward
36                               : se::fft::Type::kC2CForward;
37     case FftType::IFFT:
38       return double_precision ? se::fft::Type::kZ2ZInverse
39                               : se::fft::Type::kC2CInverse;
40     case FftType::IRFFT:
41       return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
42     case FftType::RFFT:
43       return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
44     default:
45       LOG(FATAL) << "unsupported fft type";
46   }
47 }
48 
FftTypeToString(se::fft::Type type)49 std::string FftTypeToString(se::fft::Type type) {
50   switch (type) {
51     case se::fft::Type::kC2CForward:
52     case se::fft::Type::kZ2ZForward:
53       return "FFT";
54     case se::fft::Type::kC2CInverse:
55     case se::fft::Type::kZ2ZInverse:
56       return "IFFT";
57     case se::fft::Type::kC2R:
58     case se::fft::Type::kZ2D:
59       return "IRFFT";
60     case se::fft::Type::kR2C:
61     case se::fft::Type::kD2Z:
62       return "RFFT";
63     default:
64       LOG(FATAL) << "unknown fft type";
65   }
66 }
67 
68 }  // namespace
69 
FftThunk(ThunkInfo thunk_info,FftType fft_type,absl::Span<const int64_t> fft_length,const BufferAllocation::Slice & input_buffer,const BufferAllocation::Slice & output_buffer,const Shape & input_shape,const Shape & output_shape)70 FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type,
71                    absl::Span<const int64_t> fft_length,
72                    const BufferAllocation::Slice& input_buffer,
73                    const BufferAllocation::Slice& output_buffer,
74                    const Shape& input_shape, const Shape& output_shape)
75     : Thunk(Kind::kFft, thunk_info),
76       fft_type_(
77           FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
78                                         input_shape.element_type() == C128)),
79       fft_length_(fft_length.begin(), fft_length.end()),
80       input_buffer_(input_buffer),
81       output_buffer_(output_buffer),
82       input_shape_(input_shape),
83       output_shape_(output_shape) {}
84 
ExecuteOnStream(const ExecuteParams & params)85 Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
86   auto& buffer_allocations = *params.buffer_allocations;
87 
88   return RunFft(
89       buffer_allocations.GetDeviceAddress(input_buffer_), input_shape_,
90       buffer_allocations.GetDeviceAddress(output_buffer_), output_shape_,
91       fft_type_, fft_length_, buffer_allocations.device_ordinal(),
92       &fft_plan_cache_, params.stream, buffer_allocations.memory_allocator());
93 }
94 
RunFft(se::DeviceMemoryBase input,const Shape & input_shape,se::DeviceMemoryBase output,const Shape & output_shape,se::fft::Type fft_type,absl::Span<const int64_t> fft_len,int device_ordinal,FftPlanCache * fft_plan_cache,se::Stream * stream,se::DeviceMemoryAllocator * memory_allocator)95 Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape,
96               se::DeviceMemoryBase output, const Shape& output_shape,
97               se::fft::Type fft_type, absl::Span<const int64_t> fft_len,
98               int device_ordinal, FftPlanCache* fft_plan_cache,
99               se::Stream* stream, se::DeviceMemoryAllocator* memory_allocator) {
100   VLOG(3) << "FFT type: " << FftTypeToString(fft_type);
101   VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape);
102   VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape);
103 
104   se::OwningScratchAllocator<2> scratch_allocator(device_ordinal,
105                                                   memory_allocator);
106 
107   // Get the Fft plan for the given device ordinal.
108   FftPlan* fft_plan_ptr = fft_plan_cache->GetOrCreate(device_ordinal);
109 
110   // CuFFT thread-safety requires that separate host threads not share plans;
111   // protect each plan with a mutex.
112   absl::MutexLock lock(&fft_plan_ptr->mu);
113   std::unique_ptr<se::fft::Plan>& fft_plan = fft_plan_ptr->plan;
114   if (fft_plan == nullptr) {
115     const int64_t fft_rank = fft_len.size();
116     CHECK_LE(fft_rank, 3);
117     int batch_size = 1;
118     for (int i = 0; i < input_shape.dimensions_size() - fft_rank; ++i) {
119       batch_size *= input_shape.dimensions(i);
120     }
121     uint64_t fft_length[3];
122     uint64_t input_embed[3];
123     const uint64_t input_stride = 1;
124     uint64_t input_distance = 1;
125     uint64_t output_embed[3];
126     const uint64_t output_stride = 1;
127     uint64_t output_distance = 1;
128 
129     for (int i = 0; i < fft_rank; ++i) {
130       auto dim_offset = input_shape.dimensions_size() - fft_rank + i;
131       fft_length[i] = static_cast<uint64_t>(fft_len[i]);
132       input_embed[i] = input_shape.dimensions(dim_offset);
133       input_distance *= input_shape.dimensions(dim_offset);
134       output_embed[i] = output_shape.dimensions(dim_offset);
135       output_distance *= output_shape.dimensions(dim_offset);
136     }
137 
138     constexpr bool kInPlaceFft = false;
139     fft_plan = stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
140         stream, fft_rank, fft_length, input_embed, input_stride, input_distance,
141         output_embed, output_stride, output_distance, fft_type, kInPlaceFft,
142         batch_size, &scratch_allocator);
143     TF_RET_CHECK(fft_plan != nullptr)
144         << "Failed to create cuFFT batched plan with scratch allocator";
145     fft_plan_ptr->scale_factor = 1.0f / output_distance;
146   } else {
147     stream->parent()->AsFft()->UpdatePlanWithScratchAllocator(
148         stream, fft_plan.get(), &scratch_allocator);
149   }
150 
151   float scale_factor = fft_plan_ptr->scale_factor;
152 
153   bool launch_ok;
154   switch (fft_type) {
155     case se::fft::Type::kC2CForward: {
156       se::DeviceMemory<complex64> input_data(input);
157       se::DeviceMemory<complex64> output_data(output);
158       launch_ok =
159           stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
160       break;
161     }
162     case se::fft::Type::kZ2ZForward: {
163       se::DeviceMemory<complex128> input_data(input);
164       se::DeviceMemory<complex128> output_data(output);
165       launch_ok =
166           stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
167       break;
168     }
169     case se::fft::Type::kC2CInverse: {
170       se::DeviceMemory<complex64> input_data(input);
171       se::DeviceMemory<complex64> output_data(output);
172       launch_ok =
173           stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
174       if (launch_ok) {
175         launch_ok = stream
176                         ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape),
177                                        complex64(scale_factor), &output_data, 1)
178                         .ok();
179       }
180       break;
181     }
182     case se::fft::Type::kZ2ZInverse: {
183       se::DeviceMemory<complex128> input_data(input);
184       se::DeviceMemory<complex128> output_data(output);
185       launch_ok =
186           stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
187       if (launch_ok) {
188         launch_ok =
189             stream
190                 ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape),
191                                complex128(scale_factor), &output_data, 1)
192                 .ok();
193       }
194       break;
195     }
196     case se::fft::Type::kR2C: {
197       se::DeviceMemory<float> input_data(input);
198       se::DeviceMemory<complex64> output_data(output);
199       launch_ok =
200           stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
201       break;
202     }
203     case se::fft::Type::kD2Z: {
204       se::DeviceMemory<double> input_data(input);
205       se::DeviceMemory<complex128> output_data(output);
206       launch_ok =
207           stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
208       break;
209     }
210     case se::fft::Type::kC2R: {
211       se::DeviceMemory<complex64> input_data(input);
212       se::DeviceMemory<float> output_data(output);
213       launch_ok =
214           stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
215       if (launch_ok) {
216         launch_ok = stream
217                         ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape),
218                                        scale_factor, &output_data, 1)
219                         .ok();
220       }
221       break;
222     }
223     case se::fft::Type::kZ2D: {
224       se::DeviceMemory<complex128> input_data(input);
225       se::DeviceMemory<double> output_data(output);
226       launch_ok =
227           stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
228       if (launch_ok) {
229         launch_ok = stream
230                         ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape),
231                                        scale_factor, &output_data, 1)
232                         .ok();
233       }
234       break;
235     }
236     default:
237       LOG(FATAL) << "unsupported fft type";
238   }
239   if (launch_ok) {
240     return OkStatus();
241   }
242   return InternalError("Unable to launch fft with type %s",
243                        FftTypeToString(fft_type));
244 }
245 
246 }  // namespace gpu
247 }  // namespace xla
248