/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"

#include <memory>
#include <random>
#include <utility>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/stream_executor/kernel_spec.h"
namespace xla {
namespace gpu {

namespace {

using se::dnn::DataLayout;
using se::dnn::DataLayoutString;
using se::dnn::FilterLayout;
using se::dnn::FilterLayoutString;
using tensorflow::AutotuneResult;

// Returns the smallest integer >= 0 that's not in the given set of numbers.
//
// For example, FindMissingDnum({1, 0, 3, 4}) returns 2.
//
// This is useful for handling DataLayout::kBatchDepthYX4, which represents a
// layout [N, C/k, H, W, k] for some constant k, usually 4 or 32.
// ConvolutionDimensionNumbers doesn't explicitly say which dimension is `k`,
// but we can infer it by finding the first dnum that isn't otherwise mentioned
// in the dnums.
int64_t FindMissingDnum(absl::Span<const int64_t> vals) {
  for (int i = 0; i < vals.size(); i++) {
    if (!absl::c_linear_search(vals, i)) {
      return i;
    }
  }
  return vals.size();
}

}  // anonymous namespace

StatusOr<std::tuple<Layout, Layout, Layout>>
StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
                                      DataLayout input, FilterLayout filter,
                                      DataLayout output) {
  std::vector<int64_t> input_layout;
  switch (input) {
    case DataLayout::kBatchDepthYX:  // NCHW
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.push_back(dnums.input_feature_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      break;
    case DataLayout::kBatchDepthYX4:   // NCHW_VECT_C
    case DataLayout::kBatchDepthYX32:  // NCHW_VECT_C
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.push_back(dnums.input_feature_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      input_layout.push_back(FindMissingDnum(input_layout));
      break;
    case DataLayout::kBatchYXDepth:  // NHWC
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      input_layout.push_back(dnums.input_feature_dimension());
      break;
    default:
      return InternalError("Invalid input layout %s for conv with dnums %s",
                           DataLayoutString(input),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64_t> filter_layout;
  switch (filter) {
    case FilterLayout::kOutputInputYX:  // OIHW
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      break;
    case FilterLayout::kOutputInputYX4:   // OIHW_VECT_C
    case FilterLayout::kOutputInputYX32:  // OIHW_VECT_C
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      filter_layout.push_back(FindMissingDnum(filter_layout));
      break;
    case FilterLayout::kOutputYXInput:  // OHWI
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      break;
    default:
      return InternalError("Invalid filter layout %s for conv with dnums %s",
                           FilterLayoutString(filter),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64_t> output_layout;
  switch (output) {
    case DataLayout::kBatchDepthYX:  // NCHW
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.push_back(dnums.output_feature_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      break;
    case DataLayout::kBatchDepthYX4:   // NCHW_VECT_C
    case DataLayout::kBatchDepthYX32:  // NCHW_VECT_C
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.push_back(dnums.output_feature_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      output_layout.push_back(FindMissingDnum(output_layout));
      break;
    case DataLayout::kBatchYXDepth:  // NHWC
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      output_layout.push_back(dnums.output_feature_dimension());
      break;
    default:
      return InternalError("Invalid output layout %s for conv with dnums %s",
                           DataLayoutString(output),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
}
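
// Illustrative example (dnums hypothetical): for a 2D conv with
// {batch=0, feature=1, spatial={2, 3}} and NCHW layouts, input_layout above
// is {0, 1, 2, 3} in major-to-minor order, i.e. the resulting XLA layout has
// minor_to_major = {3, 2, 1, 0}.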

StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
                                     const Shape& input, const Shape& filter,
                                     const Shape& output) {
  CHECK(input.has_layout());
  CHECK(filter.has_layout());
  CHECK(output.has_layout());

  Layout nchw_input, nchw_filter, nchw_output;
  std::tie(nchw_input, nchw_filter, nchw_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
                                            FilterLayout::kOutputInputYX,
                                            DataLayout::kBatchDepthYX)
          .value();

  // NCHW4 and NCHW32 have the same Layout; we disambiguate them below.
  Layout nchw_vect_input, nchw_vect_filter, nchw_vect_output;
  std::tie(nchw_vect_input, nchw_vect_filter, nchw_vect_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX4,
                                            FilterLayout::kOutputInputYX4,
                                            DataLayout::kBatchDepthYX4)
          .value();

  Layout nhwc_input, nhwc_filter, nhwc_output;
  std::tie(nhwc_input, nhwc_filter, nhwc_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
                                            FilterLayout::kOutputYXInput,
                                            DataLayout::kBatchYXDepth)
          .value();

  DataLayout input_layout;
  if (LayoutUtil::Equal(input.layout(), nchw_input)) {
    input_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(input.layout(), nchw_vect_input)) {
    // Differentiate between VECT_4 and VECT_32 by looking at the input shape.
    int64_t vect_size = input.dimensions(input.layout().minor_to_major(0));
    if (vect_size == 4) {
      input_layout = DataLayout::kBatchDepthYX4;
    } else if (vect_size == 32) {
      input_layout = DataLayout::kBatchDepthYX32;
    } else {
      return InternalError(
          "Invalid input shape %s for conv with dnums %s.  Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(input),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(input.layout(), nhwc_input)) {
    input_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid input layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(input.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  FilterLayout filter_layout;
  if (LayoutUtil::Equal(filter.layout(), nchw_filter)) {
    filter_layout = FilterLayout::kOutputInputYX;
  } else if (LayoutUtil::Equal(filter.layout(), nchw_vect_filter)) {
    int64_t vect_size = filter.dimensions(filter.layout().minor_to_major(0));
    if (vect_size == 4) {
      filter_layout = FilterLayout::kOutputInputYX4;
    } else if (vect_size == 32) {
      filter_layout = FilterLayout::kOutputInputYX32;
    } else {
      return InternalError(
          "Invalid filter shape %s for conv with dnums %s.  Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(filter),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(filter.layout(), nhwc_filter)) {
    filter_layout = FilterLayout::kOutputYXInput;
  } else {
    return InternalError("Invalid filter layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(filter.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  DataLayout output_layout;
  if (LayoutUtil::Equal(output.layout(), nchw_output)) {
    output_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(output.layout(), nchw_vect_output)) {
    int64_t vect_size = output.dimensions(output.layout().minor_to_major(0));
    if (vect_size == 4) {
      output_layout = DataLayout::kBatchDepthYX4;
    } else if (vect_size == 32) {
      output_layout = DataLayout::kBatchDepthYX32;
    } else {
      return InternalError(
          "Invalid output shape %s for conv with dnums %s.  Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(output),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(output.layout(), nhwc_output)) {
    output_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid output layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(output.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(input_layout, filter_layout, output_layout);
}
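
// Illustrative example (shape hypothetical): an NCHW_VECT_C input such as
// s8[32, 16, 14, 14, 4] (logical [N, C/4, H, W, 4]) matches nchw_vect_input
// above, and its most-minor dimension has size 4, so it is classified as
// DataLayout::kBatchDepthYX4; a most-minor dimension of size 32 would yield
// kBatchDepthYX32 instead.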

// Given unique integers D = {d0, d1, ds...}, finds the first integer less than
// `rank` which is not in D.  If there is no such number (because all the values
// in [0, rank) appear), returns nullopt.
//
// When D is the set of dimensions in a ConvolutionDimensionNumbers, this finds
// the dimension number that corresponds to the vectorized-features dimension in
// the convolution.
static std::optional<int64_t> FindVectorizedDim(int64_t rank, int64_t d0,
                                                int64_t d1,
                                                absl::Span<const int64_t> ds) {
  for (int64_t i = 0; i < rank; i++) {
    if (i == d0 || i == d1 || absl::c_linear_search(ds, i)) {
      continue;
    }
    return i;
  }
  return std::nullopt;
}

std::tuple<std::optional<int64_t>, std::optional<int64_t>,
           std::optional<int64_t>>
FindVectorizedFeatureDims(const ConvolutionDimensionNumbers& dnums,
                          const Shape& input, const Shape& filter,
                          const Shape& output) {
  return {
      FindVectorizedDim(input.dimensions_size(), dnums.input_batch_dimension(),
                        dnums.input_feature_dimension(),
                        dnums.input_spatial_dimensions()),
      FindVectorizedDim(filter.dimensions_size(),
                        dnums.kernel_input_feature_dimension(),
                        dnums.kernel_output_feature_dimension(),
                        dnums.kernel_spatial_dimensions()),
      FindVectorizedDim(
          output.dimensions_size(), dnums.output_batch_dimension(),
          dnums.output_feature_dimension(), dnums.output_spatial_dimensions()),
  };
}
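
// Illustrative example (values hypothetical): for a rank-5 input with
// batch=0, feature=1, and spatial dims {2, 3},
// FindVectorizedDim(5, 0, 1, {2, 3}) returns 4; for a rank-4 (unvectorized)
// input with the same dnums, it returns nullopt.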

// Returns a mutex that can be used to lock the given stream executor.
absl::Mutex& GetGpuMutex(const se::StreamExecutor* stream_exec) {
  static absl::Mutex mu(absl::kConstInit);
  // se::Platform*s are global singletons guaranteed to live forever.
  static auto* mutexes =
      new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64_t>,
                   absl::Mutex>();

  absl::MutexLock global_lock(&mu);
  auto it = mutexes
                ->emplace(std::piecewise_construct,
                          std::make_tuple(stream_exec->platform(),
                                          stream_exec->device_ordinal()),
                          std::make_tuple())
                .first;

  return it->second;
}
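
// A sketch of typical caller usage (not taken from this file):
//   absl::MutexLock lock(&GetGpuMutex(stream_exec));
//   // ... run work (e.g. autotuning) that must have the device to itself ...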

StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
    absl::string_view kernel_name, uint64_t num_args, absl::string_view ptx,
    absl::Span<const uint8_t> cubin_data, se::StreamExecutor* stream_exec) {
  se::MultiKernelLoaderSpec loader_spec(num_args);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
  }

  auto kernel_base = std::make_unique<se::KernelBase>(stream_exec);
  TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}
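
// Illustrative usage (names hypothetical):
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<se::KernelBase> kernel,
//       CreateKernel("my_kernel", /*num_args=*/3, ptx_string,
//                    /*cubin_data=*/{}, stream_exec));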

template <int n>
static std::unique_ptr<se::KernelArgsArrayBase> MakeKernelArgs(
    absl::Span<const se::DeviceMemoryBase> args) {
  auto kernel_args = std::make_unique<se::KernelArgsArray<n>>();
  for (const se::DeviceMemoryBase& buf : args) {
    kernel_args->add_device_memory_argument(buf);
  }
  return kernel_args;
}

Status ExecuteKernelOnStream(const se::KernelBase& kernel,
                             absl::Span<const se::DeviceMemoryBase> args,
                             const LaunchDimensions& dims, se::Stream* stream) {
  static constexpr int kKernelArgsLimit = 1024;
  std::unique_ptr<se::KernelArgsArrayBase> kernel_args;
  // The KernelArgsArray structure requires at a minimum 48 * args.size()
  // bytes. It can be expensive to allocate, say, 48KiB, so we add
  // specializations for smaller sizes. 64 arguments are likely to fit in a
  // 4KiB page.
  if (args.size() <= 64) {
    kernel_args = MakeKernelArgs<64>(args);
  } else if (args.size() <= 256) {
    kernel_args = MakeKernelArgs<256>(args);
  } else {
    kernel_args = MakeKernelArgs<kKernelArgsLimit>(args);
  }

  LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block();
  LaunchDimensions::Dim3D block_counts = dims.block_counts();
  return stream->parent()->Launch(
      stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
      se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel,
      *kernel_args);
}
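
// Working through the estimate above: at ~48 bytes per slot, the <64>
// specialization costs about 64 * 48 B = 3 KiB (within one 4 KiB page), the
// <256> one about 12 KiB, and the full kKernelArgsLimit (1024-slot) array
// about 48 KiB.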

// Not yet implemented for integral types; the integral overload is deleted.
template <typename T, typename Generator>
typename std::enable_if<std::is_integral<T>::value,
                        T>::type static UniformDistribution(T lhs, T rhs,
                                                            Generator* gen) =
    delete;

template <typename T, typename Generator>
typename std::enable_if<std::is_floating_point<T>::value,
                        T>::type static UniformDistribution(T lhs, T rhs,
                                                            Generator* gen) {
  return std::uniform_real_distribution<T>(lhs, rhs)(*gen);
}

template <typename T>
static void InitializeTypedBuffer(se::Stream* stream,
                                  se::DeviceMemoryBase buffer,
                                  int64_t* rng_state) {
  // Accesses to static variables are not locked, since the caller is already
  // in a critical section.
  static std::vector<T>* host_buffer = [] {
    // Use a large prime number to fragment the accesses.
    auto* ret = new std::vector<T>(10069);
    // Default-seeded random numbers.
    std::mt19937 gen;
    for (auto& element : *ret) {
      // Only double gets random values in double precision.  Other data types
      // get random values in float, which are then cast to the target type.
      using RandomFloatingPointType =
          typename std::conditional<std::is_same<T, Eigen::half>::value, float,
                                    T>::type;
      using RandomType =
          typename std::conditional<std::is_integral<T>::value, float,
                                    RandomFloatingPointType>::type;
      // Scale down the values for fp16 to reduce the chance of overflow.
      auto upper_bound =
          RandomType(std::is_same<T, Eigen::half>::value ? 0.1 : 1.0);
      auto rand_val = UniformDistribution(RandomType(0), upper_bound, &gen);
      // For float or double, the value is in [0, 1].
      // For fp16, it is in [0, 0.1].
      // For integral types, the element is either 0 or 1, which limits
      // overflow, especially for int8_t.
      element = T(std::is_integral<T>::value ? rand_val + 0.5 : rand_val);
    }
    return ret;
  }();

  int64_t& host_index = *rng_state;

  char* current_addr = static_cast<char*>(buffer.opaque());
  CHECK_EQ(0, buffer.size() % sizeof(T));
  int64_t elements_left = buffer.size() / sizeof(T);
  while (elements_left > 0) {
    CHECK_LE(host_index, host_buffer->size());
    if (host_buffer->size() == host_index) {
      host_index = 0;
    }
    int64_t elements_copied =
        std::min<int64_t>(host_buffer->size() - host_index, elements_left);
    se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T));
    stream->ThenMemcpy(&mem, host_buffer->data() + host_index,
                       elements_copied * sizeof(T));
    current_addr += elements_copied * sizeof(T);
    elements_left -= elements_copied;
    host_index += elements_copied;
  }
}
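
// Illustrative example (sizes hypothetical): starting with *rng_state == 0,
// filling a buffer of 30000 floats consumes the 10069-element host vector
// twice plus 9862 more elements, leaving *rng_state == 9862 so the next
// buffer continues from there instead of repeating the same prefix.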

void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
                      int64_t* rng_state, se::DeviceMemoryBase buffer) {
  switch (buffer_type) {
    case xla::F16:
    case xla::BF16:
      // Using F16 for BF16 initialization: this is fine since we only need
      // some random numbers there, and the random generator does not work for
      // BF16 (not all required overloads are present).
      return InitializeTypedBuffer<Eigen::half>(stream, buffer, rng_state);
    case xla::F32:
    case xla::C64:
      return InitializeTypedBuffer<float>(stream, buffer, rng_state);
    case xla::F64:
    case xla::C128:
      return InitializeTypedBuffer<double>(stream, buffer, rng_state);
    case xla::PRED:
      // Using S8 for PRED initialization, as vector<bool> has different
      // semantics and cannot be used as a buffer.
    case xla::S8:
      return InitializeTypedBuffer<int8_t>(stream, buffer, rng_state);
    case xla::S32:
      return InitializeTypedBuffer<int32_t>(stream, buffer, rng_state);
    default:
      LOG(FATAL) << "Unexpected type: "
                 << primitive_util::LowercasePrimitiveTypeName(buffer_type);
  }
}

StatusOr<se::dnn::ConvolutionKind> GetDNNConvKindFromCudnnConvKind(
    CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kBackwardFilter:
      return se::dnn::BACKWARD_FILTER;
    case CudnnConvKind::kBackwardInput:
      return se::dnn::BACKWARD_DATA;
    case CudnnConvKind::kForward:
      return se::dnn::FORWARD;
    case CudnnConvKind::kForwardActivation:
      return se::dnn::FORWARD_BIAS_ACTIVATION;
    default:
      break;
  }
  return InternalError("Unexpected convolution kind");
}

StatusOr<se::dnn::DataType> GetDNNDataTypeFromPrimitiveType(
    PrimitiveType type) {
  switch (type) {
    case F16:
      return se::dnn::ToDataType<Eigen::half>::value;
    case F32:
      return se::dnn::ToDataType<float>::value;
    case F64:
      return se::dnn::ToDataType<double>::value;
    case S8:
      return se::dnn::ToDataType<int8_t>::value;
    case S32:
      return se::dnn::ToDataType<int32_t>::value;
    case BF16:
      return se::dnn::ToDataType<Eigen::bfloat16>::value;
    default:
      break;
  }
  return InternalError("Unsupported convolution datatype");
}

bool RequireDeterminism(const HloModuleConfig& config) {
  static bool require_cudnn_determinism = [] {
    // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var.
    bool cudnn_deterministic = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
                                               /*default_val=*/false,
                                               &cudnn_deterministic));
    return cudnn_deterministic;
  }();
  return tensorflow::OpDeterminismRequired() || require_cudnn_determinism ||
         config.debug_options().xla_gpu_deterministic_ops();
}
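
// Determinism can therefore be requested in three ways: via
// tensorflow::OpDeterminismRequired(), via the legacy TF_CUDNN_DETERMINISTIC
// environment variable (e.g. TF_CUDNN_DETERMINISTIC=1), or via the
// xla_gpu_deterministic_ops debug option.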

StatusOr<AutotuneResult> PickBestResult(
    absl::Span<AutotuneResult const> profile_results,
    const HloInstruction& instr) {
  std::vector<AutotuneResult> filtered_results;

  // For now, we ignore WRONG_RESULT failures because false-positives are
  // possible (e.g. perhaps the reference algorithm is the one that's
  // incorrect!).  But we don't ignore REDZONE_MODIFIED failures because they're
  // quite severe and can be detected with high accuracy.
  absl::c_copy_if(
      profile_results, std::back_inserter(filtered_results),
      [](const AutotuneResult& r) {
        return !(r.has_failure() &&
                 r.failure().kind() != AutotuneResult::WRONG_RESULT);
      });

  if (filtered_results.empty()) {
    std::ostringstream msg;
    msg << "All algorithms tried for " << instr.ToString()
        << " failed. Falling back to default algorithm.  Per-algorithm errors:";
    for (const auto& result : profile_results) {
      msg << "\n  " << result.failure().msg();
    }
    return InternalError("%s", msg.str());
  }

  auto selected_result = filtered_results.begin();
  if (!RequireDeterminism(instr.parent()->parent()->config())) {
    selected_result = absl::c_min_element(
        filtered_results,
        [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
          return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
                 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
        });
  }
  return *selected_result;
}
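
// Illustrative example (results hypothetical): given three profile results
// where #0 failed with REDZONE_MODIFIED, #1 failed with WRONG_RESULT, and #2
// succeeded, #0 is filtered out; normally the faster of {#1, #2} is returned,
// but when determinism is required the first survivor (#1) is returned so the
// choice does not depend on measured timings.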

}  // namespace gpu
}  // namespace xla