1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
17
18 #include <string>
19
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
26 #include "tensorflow/stream_executor/scratch_allocator.h"
27
28 namespace xla {
29 namespace gpu {
30 namespace {
31
FftTypeToSeType(FftType type,bool double_precision)32 se::fft::Type FftTypeToSeType(FftType type, bool double_precision) {
33 switch (type) {
34 case FftType::FFT:
35 return double_precision ? se::fft::Type::kZ2ZForward
36 : se::fft::Type::kC2CForward;
37 case FftType::IFFT:
38 return double_precision ? se::fft::Type::kZ2ZInverse
39 : se::fft::Type::kC2CInverse;
40 case FftType::IRFFT:
41 return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
42 case FftType::RFFT:
43 return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
44 default:
45 LOG(FATAL) << "unsupported fft type";
46 }
47 }
48
FftTypeToString(se::fft::Type type)49 std::string FftTypeToString(se::fft::Type type) {
50 switch (type) {
51 case se::fft::Type::kC2CForward:
52 case se::fft::Type::kZ2ZForward:
53 return "FFT";
54 case se::fft::Type::kC2CInverse:
55 case se::fft::Type::kZ2ZInverse:
56 return "IFFT";
57 case se::fft::Type::kC2R:
58 case se::fft::Type::kZ2D:
59 return "IRFFT";
60 case se::fft::Type::kR2C:
61 case se::fft::Type::kD2Z:
62 return "RFFT";
63 default:
64 LOG(FATAL) << "unknown fft type";
65 }
66 }
67
68 } // namespace
69
// Constructs the thunk that will run an FFT of `fft_type` over the data in
// `input_buffer`, writing the result to `output_buffer`. Double-precision
// plan variants are selected when the input element type is F64 or C128.
FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type,
                   absl::Span<const int64_t> fft_length,
                   const BufferAllocation::Slice& input_buffer,
                   const BufferAllocation::Slice& output_buffer,
                   const Shape& input_shape, const Shape& output_shape)
    : Thunk(Kind::kFft, thunk_info),
      fft_type_(
          // Fold the XLA fft kind and the element precision into a single
          // StreamExecutor plan type up front.
          FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
                                        input_shape.element_type() == C128)),
      // Copy the fft dimension lengths; the Span's backing storage may not
      // outlive this thunk.
      fft_length_(fft_length.begin(), fft_length.end()),
      input_buffer_(input_buffer),
      output_buffer_(output_buffer),
      input_shape_(input_shape),
      output_shape_(output_shape) {}
84
ExecuteOnStream(const ExecuteParams & params)85 Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
86 auto& buffer_allocations = *params.buffer_allocations;
87
88 return RunFft(
89 buffer_allocations.GetDeviceAddress(input_buffer_), input_shape_,
90 buffer_allocations.GetDeviceAddress(output_buffer_), output_shape_,
91 fft_type_, fft_length_, buffer_allocations.device_ordinal(),
92 &fft_plan_cache_, params.stream, buffer_allocations.memory_allocator());
93 }
94
RunFft(se::DeviceMemoryBase input,const Shape & input_shape,se::DeviceMemoryBase output,const Shape & output_shape,se::fft::Type fft_type,absl::Span<const int64_t> fft_len,int device_ordinal,FftPlanCache * fft_plan_cache,se::Stream * stream,se::DeviceMemoryAllocator * memory_allocator)95 Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape,
96 se::DeviceMemoryBase output, const Shape& output_shape,
97 se::fft::Type fft_type, absl::Span<const int64_t> fft_len,
98 int device_ordinal, FftPlanCache* fft_plan_cache,
99 se::Stream* stream, se::DeviceMemoryAllocator* memory_allocator) {
100 VLOG(3) << "FFT type: " << FftTypeToString(fft_type);
101 VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape);
102 VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape);
103
104 se::OwningScratchAllocator<2> scratch_allocator(device_ordinal,
105 memory_allocator);
106
107 // Get the Fft plan for the given device ordinal.
108 FftPlan* fft_plan_ptr = fft_plan_cache->GetOrCreate(device_ordinal);
109
110 // CuFFT thread-safety requires that separate host threads not share plans;
111 // protect each plan with a mutex.
112 absl::MutexLock lock(&fft_plan_ptr->mu);
113 std::unique_ptr<se::fft::Plan>& fft_plan = fft_plan_ptr->plan;
114 if (fft_plan == nullptr) {
115 const int64_t fft_rank = fft_len.size();
116 CHECK_LE(fft_rank, 3);
117 int batch_size = 1;
118 for (int i = 0; i < input_shape.dimensions_size() - fft_rank; ++i) {
119 batch_size *= input_shape.dimensions(i);
120 }
121 uint64_t fft_length[3];
122 uint64_t input_embed[3];
123 const uint64_t input_stride = 1;
124 uint64_t input_distance = 1;
125 uint64_t output_embed[3];
126 const uint64_t output_stride = 1;
127 uint64_t output_distance = 1;
128
129 for (int i = 0; i < fft_rank; ++i) {
130 auto dim_offset = input_shape.dimensions_size() - fft_rank + i;
131 fft_length[i] = static_cast<uint64_t>(fft_len[i]);
132 input_embed[i] = input_shape.dimensions(dim_offset);
133 input_distance *= input_shape.dimensions(dim_offset);
134 output_embed[i] = output_shape.dimensions(dim_offset);
135 output_distance *= output_shape.dimensions(dim_offset);
136 }
137
138 constexpr bool kInPlaceFft = false;
139 fft_plan = stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
140 stream, fft_rank, fft_length, input_embed, input_stride, input_distance,
141 output_embed, output_stride, output_distance, fft_type, kInPlaceFft,
142 batch_size, &scratch_allocator);
143 TF_RET_CHECK(fft_plan != nullptr)
144 << "Failed to create cuFFT batched plan with scratch allocator";
145 fft_plan_ptr->scale_factor = 1.0f / output_distance;
146 } else {
147 stream->parent()->AsFft()->UpdatePlanWithScratchAllocator(
148 stream, fft_plan.get(), &scratch_allocator);
149 }
150
151 float scale_factor = fft_plan_ptr->scale_factor;
152
153 bool launch_ok;
154 switch (fft_type) {
155 case se::fft::Type::kC2CForward: {
156 se::DeviceMemory<complex64> input_data(input);
157 se::DeviceMemory<complex64> output_data(output);
158 launch_ok =
159 stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
160 break;
161 }
162 case se::fft::Type::kZ2ZForward: {
163 se::DeviceMemory<complex128> input_data(input);
164 se::DeviceMemory<complex128> output_data(output);
165 launch_ok =
166 stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
167 break;
168 }
169 case se::fft::Type::kC2CInverse: {
170 se::DeviceMemory<complex64> input_data(input);
171 se::DeviceMemory<complex64> output_data(output);
172 launch_ok =
173 stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
174 if (launch_ok) {
175 launch_ok = stream
176 ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape),
177 complex64(scale_factor), &output_data, 1)
178 .ok();
179 }
180 break;
181 }
182 case se::fft::Type::kZ2ZInverse: {
183 se::DeviceMemory<complex128> input_data(input);
184 se::DeviceMemory<complex128> output_data(output);
185 launch_ok =
186 stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
187 if (launch_ok) {
188 launch_ok =
189 stream
190 ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape),
191 complex128(scale_factor), &output_data, 1)
192 .ok();
193 }
194 break;
195 }
196 case se::fft::Type::kR2C: {
197 se::DeviceMemory<float> input_data(input);
198 se::DeviceMemory<complex64> output_data(output);
199 launch_ok =
200 stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
201 break;
202 }
203 case se::fft::Type::kD2Z: {
204 se::DeviceMemory<double> input_data(input);
205 se::DeviceMemory<complex128> output_data(output);
206 launch_ok =
207 stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
208 break;
209 }
210 case se::fft::Type::kC2R: {
211 se::DeviceMemory<complex64> input_data(input);
212 se::DeviceMemory<float> output_data(output);
213 launch_ok =
214 stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
215 if (launch_ok) {
216 launch_ok = stream
217 ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape),
218 scale_factor, &output_data, 1)
219 .ok();
220 }
221 break;
222 }
223 case se::fft::Type::kZ2D: {
224 se::DeviceMemory<complex128> input_data(input);
225 se::DeviceMemory<double> output_data(output);
226 launch_ok =
227 stream->ThenFft(fft_plan.get(), input_data, &output_data).ok();
228 if (launch_ok) {
229 launch_ok = stream
230 ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape),
231 scale_factor, &output_data, 1)
232 .ok();
233 }
234 break;
235 }
236 default:
237 LOG(FATAL) << "unsupported fft type";
238 }
239 if (launch_ok) {
240 return OkStatus();
241 }
242 return InternalError("Unable to launch fft with type %s",
243 FftTypeToString(fft_type));
244 }
245
246 } // namespace gpu
247 } // namespace xla
248