1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
17
18 #include <memory>
19 #include <random>
20 #include <utility>
21
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/platform/regexp.h"
29 #include "tensorflow/core/profiler/lib/traceme.h"
30 #include "tensorflow/core/util/determinism.h"
31 #include "tensorflow/core/util/env_var.h"
32 #include "tensorflow/core/util/proto/proto_utils.h"
33 #include "tensorflow/stream_executor/kernel_spec.h"
34
35 namespace xla {
36 namespace gpu {
37
38 namespace {
39
40 using se::dnn::DataLayout;
41 using se::dnn::DataLayoutString;
42 using se::dnn::FilterLayout;
43 using se::dnn::FilterLayoutString;
44 using tensorflow::AutotuneResult;
45
46 // Returns the smallest integer >= 0 that's not in the given set of numbers.
47 //
48 // For example, FindMissingDnum({1, 0, 3, 4}) returns 2.
49 //
50 // This is useful for handling DataLayout::kBatchDepthYX4, which repesents a
51 // layout [N, C/k, H, W, k] for some constant k, usually 4 or 32.
52 // ConvolutionDimensionNumbers doesn't explicitly say which dimension is `k`,
53 // but we can infer it by finding the first dnum that isn't otherwise mentioned
54 // in the dnums.
FindMissingDnum(absl::Span<const int64_t> vals)55 int64_t FindMissingDnum(absl::Span<const int64_t> vals) {
56 for (int i = 0; i < vals.size(); i++) {
57 if (!absl::c_linear_search(vals, i)) {
58 return i;
59 }
60 }
61 return vals.size();
62 }
63
64 } // anonymous namespace
65
66 StatusOr<std::tuple<Layout, Layout, Layout>>
StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers & dnums,DataLayout input,FilterLayout filter,DataLayout output)67 StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
68 DataLayout input, FilterLayout filter,
69 DataLayout output) {
70 std::vector<int64_t> input_layout;
71 switch (input) {
72 case DataLayout::kBatchDepthYX: // NCHW
73 input_layout.push_back(dnums.input_batch_dimension());
74 input_layout.push_back(dnums.input_feature_dimension());
75 input_layout.insert(input_layout.end(),
76 dnums.input_spatial_dimensions().begin(),
77 dnums.input_spatial_dimensions().end());
78 break;
79 case DataLayout::kBatchDepthYX4: // NCHW_VECT_C
80 case DataLayout::kBatchDepthYX32: // NCHW_VECT_C
81 input_layout.push_back(dnums.input_batch_dimension());
82 input_layout.push_back(dnums.input_feature_dimension());
83 input_layout.insert(input_layout.end(),
84 dnums.input_spatial_dimensions().begin(),
85 dnums.input_spatial_dimensions().end());
86 input_layout.push_back(FindMissingDnum(input_layout));
87 break;
88 case DataLayout::kBatchYXDepth: // NHWC
89 input_layout.push_back(dnums.input_batch_dimension());
90 input_layout.insert(input_layout.end(),
91 dnums.input_spatial_dimensions().begin(),
92 dnums.input_spatial_dimensions().end());
93 input_layout.push_back(dnums.input_feature_dimension());
94 break;
95 default:
96 return InternalError("Invalid input layout %s for conv with dnums %s",
97 DataLayoutString(input),
98 ConvolutionDimensionNumbersToString(dnums));
99 }
100
101 std::vector<int64_t> filter_layout;
102 switch (filter) {
103 case FilterLayout::kOutputInputYX: // OIHW
104 filter_layout.push_back(dnums.kernel_output_feature_dimension());
105 filter_layout.push_back(dnums.kernel_input_feature_dimension());
106 filter_layout.insert(filter_layout.end(),
107 dnums.kernel_spatial_dimensions().begin(),
108 dnums.kernel_spatial_dimensions().end());
109 break;
110 case FilterLayout::kOutputInputYX4: // OIHW_VECT_C
111 case FilterLayout::kOutputInputYX32: // OIHW_VECT_C
112 filter_layout.push_back(dnums.kernel_output_feature_dimension());
113 filter_layout.push_back(dnums.kernel_input_feature_dimension());
114 filter_layout.insert(filter_layout.end(),
115 dnums.kernel_spatial_dimensions().begin(),
116 dnums.kernel_spatial_dimensions().end());
117 filter_layout.push_back(FindMissingDnum(filter_layout));
118 break;
119 case FilterLayout::kOutputYXInput: // OHWI
120 filter_layout.push_back(dnums.kernel_output_feature_dimension());
121 filter_layout.insert(filter_layout.end(),
122 dnums.kernel_spatial_dimensions().begin(),
123 dnums.kernel_spatial_dimensions().end());
124 filter_layout.push_back(dnums.kernel_input_feature_dimension());
125 break;
126 default:
127 return InternalError("Invalid filter layout %s for conv with dnums %s",
128 FilterLayoutString(filter),
129 ConvolutionDimensionNumbersToString(dnums));
130 }
131
132 std::vector<int64_t> output_layout;
133 switch (output) {
134 case DataLayout::kBatchDepthYX: // NCHW
135 output_layout.push_back(dnums.output_batch_dimension());
136 output_layout.push_back(dnums.output_feature_dimension());
137 output_layout.insert(output_layout.end(),
138 dnums.output_spatial_dimensions().begin(),
139 dnums.output_spatial_dimensions().end());
140 break;
141 case DataLayout::kBatchDepthYX4: // NCHW_VECT_C
142 case DataLayout::kBatchDepthYX32: // NCHW_VECT_C
143 output_layout.push_back(dnums.output_batch_dimension());
144 output_layout.push_back(dnums.output_feature_dimension());
145 output_layout.insert(output_layout.end(),
146 dnums.output_spatial_dimensions().begin(),
147 dnums.output_spatial_dimensions().end());
148 output_layout.push_back(FindMissingDnum(output_layout));
149 break;
150 case DataLayout::kBatchYXDepth: // NHWC
151 output_layout.push_back(dnums.output_batch_dimension());
152 output_layout.insert(output_layout.end(),
153 dnums.output_spatial_dimensions().begin(),
154 dnums.output_spatial_dimensions().end());
155 output_layout.push_back(dnums.output_feature_dimension());
156 break;
157 default:
158 return InternalError("Invalid output layout %s for conv with dnums %s",
159 DataLayoutString(output),
160 ConvolutionDimensionNumbersToString(dnums));
161 }
162
163 return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
164 LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
165 LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
166 }
167
168 StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers & dnums,const Shape & input,const Shape & filter,const Shape & output)169 XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
170 const Shape& input, const Shape& filter,
171 const Shape& output) {
172 CHECK(input.has_layout());
173 CHECK(filter.has_layout());
174 CHECK(output.has_layout());
175
176 Layout nchw_input, nchw_filter, nchw_output;
177 std::tie(nchw_input, nchw_filter, nchw_output) =
178 StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
179 FilterLayout::kOutputInputYX,
180 DataLayout::kBatchDepthYX)
181 .value();
182
183 // NCHW4 and NCHW32 have the same Layout; we disambiguate them below.
184 Layout nchw_vect_input, nchw_vect_filter, nchw_vect_output;
185 std::tie(nchw_vect_input, nchw_vect_filter, nchw_vect_output) =
186 StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX4,
187 FilterLayout::kOutputInputYX4,
188 DataLayout::kBatchDepthYX4)
189 .value();
190
191 Layout nhwc_input, nhwc_filter, nhwc_output;
192 std::tie(nhwc_input, nhwc_filter, nhwc_output) =
193 StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
194 FilterLayout::kOutputYXInput,
195 DataLayout::kBatchYXDepth)
196 .value();
197
198 DataLayout input_layout;
199 if (LayoutUtil::Equal(input.layout(), nchw_input)) {
200 input_layout = DataLayout::kBatchDepthYX;
201 } else if (LayoutUtil::Equal(input.layout(), nchw_vect_input)) {
202 // Differentiate between VECT_4 and VECT_32 by looking at the input shape.
203 int64_t vect_size = input.dimensions(input.layout().minor_to_major(0));
204 if (vect_size == 4) {
205 input_layout = DataLayout::kBatchDepthYX4;
206 } else if (vect_size == 32) {
207 input_layout = DataLayout::kBatchDepthYX32;
208 } else {
209 return InternalError(
210 "Invalid input shape %s for conv with dnums %s. Most-minor dim "
211 "should be 4 or 32, but was %d.",
212 ShapeUtil::HumanStringWithLayout(input),
213 ConvolutionDimensionNumbersToString(dnums), vect_size);
214 }
215 } else if (LayoutUtil::Equal(input.layout(), nhwc_input)) {
216 input_layout = DataLayout::kBatchYXDepth;
217 } else {
218 return InternalError("Invalid input layout %s for conv with dnums %s",
219 LayoutUtil::HumanString(input.layout()),
220 ConvolutionDimensionNumbersToString(dnums));
221 }
222
223 FilterLayout filter_layout;
224 if (LayoutUtil::Equal(filter.layout(), nchw_filter)) {
225 filter_layout = FilterLayout::kOutputInputYX;
226 } else if (LayoutUtil::Equal(filter.layout(), nchw_vect_filter)) {
227 int64_t vect_size = filter.dimensions(filter.layout().minor_to_major(0));
228 if (vect_size == 4) {
229 filter_layout = FilterLayout::kOutputInputYX4;
230 } else if (vect_size == 32) {
231 filter_layout = FilterLayout::kOutputInputYX32;
232 } else {
233 return InternalError(
234 "Invalid filter shape %s for conv with dnums %s. Most-minor dim "
235 "should be 4 or 32, but was %d.",
236 ShapeUtil::HumanStringWithLayout(filter),
237 ConvolutionDimensionNumbersToString(dnums), vect_size);
238 }
239 } else if (LayoutUtil::Equal(filter.layout(), nhwc_filter)) {
240 filter_layout = FilterLayout::kOutputYXInput;
241 } else {
242 return InternalError("Invalid filter layout %s for conv with dnums %s",
243 LayoutUtil::HumanString(filter.layout()),
244 ConvolutionDimensionNumbersToString(dnums));
245 }
246
247 DataLayout output_layout;
248 if (LayoutUtil::Equal(output.layout(), nchw_output)) {
249 output_layout = DataLayout::kBatchDepthYX;
250 } else if (LayoutUtil::Equal(output.layout(), nchw_vect_output)) {
251 int64_t vect_size = output.dimensions(output.layout().minor_to_major(0));
252 if (vect_size == 4) {
253 output_layout = DataLayout::kBatchDepthYX4;
254 } else if (vect_size == 32) {
255 output_layout = DataLayout::kBatchDepthYX32;
256 } else {
257 return InternalError(
258 "Invalid output shape %s for conv with dnums %s. Most-minor dim "
259 "should be 4 or 32, but was %d.",
260 ShapeUtil::HumanStringWithLayout(output),
261 ConvolutionDimensionNumbersToString(dnums), vect_size);
262 }
263 } else if (LayoutUtil::Equal(output.layout(), nhwc_output)) {
264 output_layout = DataLayout::kBatchYXDepth;
265 } else {
266 return InternalError("Invalid output layout %s for conv with dnums %s",
267 LayoutUtil::HumanString(output.layout()),
268 ConvolutionDimensionNumbersToString(dnums));
269 }
270
271 return std::make_tuple(input_layout, filter_layout, output_layout);
272 }
273
274 // Given unique integers D = {d0, d1, ds...}, finds the first integer less than
275 // `rank` which is not in D. If there is no such number (because all the values
276 // in [0, rank) appear), returns nullopt.
277 //
278 // When D is the set of dimensions in a ConvolutionDimensionNumbers, this finds
279 // the dimension number that corresponds to the vectorized-features dimension in
280 // the convolution.
FindVectorizedDim(int64_t rank,int64_t d0,int64_t d1,absl::Span<const int64_t> ds)281 static std::optional<int64_t> FindVectorizedDim(int64_t rank, int64_t d0,
282 int64_t d1,
283 absl::Span<const int64_t> ds) {
284 for (int64_t i = 0; i < rank; i++) {
285 if (i == d0 || i == d1 || absl::c_linear_search(ds, i)) {
286 continue;
287 }
288 return i;
289 }
290 return std::nullopt;
291 }
292
293 std::tuple<std::optional<int64_t>, std::optional<int64_t>,
294 std::optional<int64_t>>
FindVectorizedFeatureDims(const ConvolutionDimensionNumbers & dnums,const Shape & input,const Shape & filter,const Shape & output)295 FindVectorizedFeatureDims(const ConvolutionDimensionNumbers& dnums,
296 const Shape& input, const Shape& filter,
297 const Shape& output) {
298 return {
299 FindVectorizedDim(input.dimensions_size(), dnums.input_batch_dimension(),
300 dnums.input_feature_dimension(),
301 dnums.input_spatial_dimensions()),
302 FindVectorizedDim(filter.dimensions_size(),
303 dnums.kernel_input_feature_dimension(),
304 dnums.kernel_output_feature_dimension(),
305 dnums.kernel_spatial_dimensions()),
306 FindVectorizedDim(
307 output.dimensions_size(), dnums.output_batch_dimension(),
308 dnums.output_feature_dimension(), dnums.output_spatial_dimensions()),
309 };
310 }
311
312 // Returns a mutex that can be used to lock the given stream executor.
GetGpuMutex(const se::StreamExecutor * stream_exec)313 absl::Mutex& GetGpuMutex(const se::StreamExecutor* stream_exec) {
314 static absl::Mutex mu(absl::kConstInit);
315 // se::Platform*s are global singletons guaranteed to live forever.
316 static auto* mutexes =
317 new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64_t>,
318 absl::Mutex>();
319
320 absl::MutexLock global_lock(&mu);
321 auto it = mutexes
322 ->emplace(std::piecewise_construct,
323 std::make_tuple(stream_exec->platform(),
324 stream_exec->device_ordinal()),
325 std::make_tuple())
326 .first;
327
328 return it->second;
329 }
330
CreateKernel(absl::string_view kernel_name,uint64_t num_args,absl::string_view ptx,absl::Span<const uint8_t> cubin_data,se::StreamExecutor * stream_exec)331 StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
332 absl::string_view kernel_name, uint64_t num_args, absl::string_view ptx,
333 absl::Span<const uint8_t> cubin_data, se::StreamExecutor* stream_exec) {
334 se::MultiKernelLoaderSpec loader_spec(num_args);
335 loader_spec.AddCudaPtxInMemory(ptx, kernel_name);
336
337 if (!cubin_data.empty()) {
338 loader_spec.AddCudaCubinInMemory(
339 reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
340 }
341
342 auto kernel_base = std::make_unique<se::KernelBase>(stream_exec);
343 TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get()));
344 return std::move(kernel_base);
345 }
346
347 template <int n>
MakeKernelArgs(absl::Span<const se::DeviceMemoryBase> args)348 static std::unique_ptr<se::KernelArgsArrayBase> MakeKernelArgs(
349 absl::Span<const se::DeviceMemoryBase> args) {
350 auto kernel_args = std::make_unique<se::KernelArgsArray<n>>();
351 for (const se::DeviceMemoryBase& buf : args) {
352 kernel_args->add_device_memory_argument(buf);
353 }
354 return kernel_args;
355 }
356
ExecuteKernelOnStream(const se::KernelBase & kernel,absl::Span<const se::DeviceMemoryBase> args,const LaunchDimensions & dims,se::Stream * stream)357 Status ExecuteKernelOnStream(const se::KernelBase& kernel,
358 absl::Span<const se::DeviceMemoryBase> args,
359 const LaunchDimensions& dims, se::Stream* stream) {
360 static constexpr int kKernelArgsLimit = 1024;
361 std::unique_ptr<se::KernelArgsArrayBase> kernel_args;
362 // The KernelArgsArray structure requires at a minimum 48 * args.size()
363 // bytes. It can be expensive to allocate, say, 48KiB, so we add
364 // specializations for smaller sizes. 64 arguments are likely to fit in a
365 // 4KiB page.
366 if (args.size() <= 64) {
367 kernel_args = MakeKernelArgs<64>(args);
368 } else if (args.size() <= 256) {
369 kernel_args = MakeKernelArgs<256>(args);
370 } else {
371 kernel_args = MakeKernelArgs<kKernelArgsLimit>(args);
372 }
373
374 LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block();
375 LaunchDimensions::Dim3D block_counts = dims.block_counts();
376 return stream->parent()->Launch(
377 stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
378 se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel,
379 *kernel_args);
380 }
381
// Draws one sample from a std::uniform_real_distribution<T> constructed with
// bounds (lhs, rhs), using `*gen` as the source of randomness.
//
// Unimplemented for integers yet: the integral overload is explicitly deleted,
// so calling it with an integral T is a compile-time error.
template <typename T, typename Generator>
static std::enable_if_t<std::is_integral<T>::value, T> UniformDistribution(
    T lhs, T rhs, Generator* gen) = delete;

template <typename T, typename Generator>
static std::enable_if_t<std::is_floating_point<T>::value, T>
UniformDistribution(T lhs, T rhs, Generator* gen) {
  std::uniform_real_distribution<T> distribution(lhs, rhs);
  return distribution(*gen);
}
395
396 template <typename T>
InitializeTypedBuffer(se::Stream * stream,se::DeviceMemoryBase buffer,int64_t * rng_state)397 static void InitializeTypedBuffer(se::Stream* stream,
398 se::DeviceMemoryBase buffer,
399 int64_t* rng_state) {
400 // Accesses to static variables are not locked, since the caller is already
401 // in a critical section.
402 static std::vector<T>* host_buffer = [] {
403 // Use a large prime number to fragment the accesses.
404 auto* ret = new std::vector<T>(10069);
405 // Default-seeded random numbers.
406 std::mt19937 gen;
407 for (auto& element : *ret) {
408 // Only double gets random values in double. Other data types get random
409 // values in float then cast them to the target data types.
410 using RandomFloatingPointType =
411 typename std::conditional<std::is_same<T, Eigen::half>::value, float,
412 T>::type;
413 using RandomType =
414 typename std::conditional<std::is_integral<T>::value, float,
415 RandomFloatingPointType>::type;
416 // Scale down the values for fp16 to have less overflows.
417 auto upper_bound =
418 RandomType(std::is_same<T, Eigen::half>::value ? 0.1 : 1.0);
419 auto rand_val = UniformDistribution(RandomType(0), upper_bound, &gen);
420 // For float or double, it is between [0,1].
421 // For fp16, it ranges between [0, 0.1].
422 // For integer types, element is either 0 or 1 for less overflows
423 // especially for int8_t.
424 element = T(std::is_integral<T>::value ? rand_val + 0.5 : rand_val);
425 }
426 return ret;
427 }();
428
429 int64_t& host_index = *rng_state;
430
431 char* current_addr = static_cast<char*>(buffer.opaque());
432 CHECK_EQ(0, buffer.size() % sizeof(T));
433 int64_t elements_left = buffer.size() / sizeof(T);
434 while (elements_left > 0) {
435 CHECK_LE(host_index, host_buffer->size());
436 if (host_buffer->size() == host_index) {
437 host_index = 0;
438 }
439 int64_t elements_copied =
440 std::min<int64_t>(host_buffer->size() - host_index, elements_left);
441 se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T));
442 stream->ThenMemcpy(&mem, host_buffer->data() + host_index,
443 elements_copied * sizeof(T));
444 current_addr += elements_copied * sizeof(T);
445 elements_left -= elements_copied;
446 host_index += elements_copied;
447 }
448 }
449
InitializeBuffer(se::Stream * stream,PrimitiveType buffer_type,int64_t * rng_state,se::DeviceMemoryBase buffer)450 void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
451 int64_t* rng_state, se::DeviceMemoryBase buffer) {
452 switch (buffer_type) {
453 case xla::F16:
454 case xla::BF16:
455 // Using F16 for BF16 initialization: it's fine since we only need some
456 // random number there, and random generator is not working for BF16 (not
457 // all required overloads are there).
458 return InitializeTypedBuffer<Eigen::half>(stream, buffer, rng_state);
459 case xla::F32:
460 case xla::C64:
461 return InitializeTypedBuffer<float>(stream, buffer, rng_state);
462 case xla::F64:
463 case xla::C128:
464 return InitializeTypedBuffer<double>(stream, buffer, rng_state);
465 case xla::PRED:
466 // Using S8 for PRED initialization, as vector<bool> has different
467 // semantics and cannot be used as a buffer.
468 case xla::S8:
469 return InitializeTypedBuffer<int8_t>(stream, buffer, rng_state);
470 case xla::S32:
471 return InitializeTypedBuffer<int32_t>(stream, buffer, rng_state);
472 default:
473 LOG(FATAL) << "Unexpected type: "
474 << primitive_util::LowercasePrimitiveTypeName(buffer_type);
475 }
476 }
477
GetDNNConvKindFromCudnnConvKind(CudnnConvKind kind)478 StatusOr<se::dnn::ConvolutionKind> GetDNNConvKindFromCudnnConvKind(
479 CudnnConvKind kind) {
480 switch (kind) {
481 case CudnnConvKind::kBackwardFilter:
482 return se::dnn::BACKWARD_FILTER;
483 case CudnnConvKind::kBackwardInput:
484 return se::dnn::BACKWARD_DATA;
485 case CudnnConvKind::kForward:
486 return se::dnn::FORWARD;
487 case CudnnConvKind::kForwardActivation:
488 return se::dnn::FORWARD_BIAS_ACTIVATION;
489 default:
490 break;
491 }
492 return InternalError("Unexpected convolution kind");
493 }
494
GetDNNDataTypeFromPrimitiveType(PrimitiveType type)495 StatusOr<se::dnn::DataType> GetDNNDataTypeFromPrimitiveType(
496 PrimitiveType type) {
497 switch (type) {
498 case F16:
499 return se::dnn::ToDataType<Eigen::half>::value;
500 case F32:
501 return se::dnn::ToDataType<float>::value;
502 case F64:
503 return se::dnn::ToDataType<double>::value;
504 case S8:
505 return se::dnn::ToDataType<int8_t>::value;
506 case S32:
507 return se::dnn::ToDataType<int32_t>::value;
508 case BF16:
509 return se::dnn::ToDataType<Eigen::bfloat16>::value;
510 default:
511 break;
512 }
513 return InternalError("Unsupported convolution datatype");
514 }
515
RequireDeterminism(const HloModuleConfig & config)516 bool RequireDeterminism(const HloModuleConfig& config) {
517 static bool require_cudnn_determinism = [] {
518 // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var.
519 bool cudnn_deterministic = false;
520 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
521 /*default_val=*/false,
522 &cudnn_deterministic));
523 return cudnn_deterministic;
524 }();
525 return tensorflow::OpDeterminismRequired() || require_cudnn_determinism ||
526 config.debug_options().xla_gpu_deterministic_ops();
527 }
528
PickBestResult(absl::Span<AutotuneResult const> profile_results,const HloInstruction & instr)529 StatusOr<AutotuneResult> PickBestResult(
530 absl::Span<AutotuneResult const> profile_results,
531 const HloInstruction& instr) {
532 std::vector<AutotuneResult> filtered_results;
533
534 // For now, we ignore WRONG_RESULT failures because false-positives are
535 // possible (e.g. perhaps the reference algorithm is the one that's
536 // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
537 // quite severe and can be detected with high accuracy.
538 absl::c_copy_if(
539 profile_results, std::back_inserter(filtered_results),
540 [](const AutotuneResult& r) {
541 return !(r.has_failure() &&
542 r.failure().kind() != AutotuneResult::WRONG_RESULT);
543 });
544
545 if (filtered_results.empty()) {
546 std::ostringstream msg;
547 msg << "All algorithms tried for " << instr.ToString()
548 << " failed. Falling back to default algorithm. Per-algorithm errors:";
549 for (const auto& result : profile_results) {
550 msg << "\n " << result.failure().msg();
551 }
552 return InternalError("%s", msg.str());
553 }
554
555 auto selected_result = filtered_results.begin();
556 if (!RequireDeterminism(instr.parent()->parent()->config())) {
557 selected_result = absl::c_min_element(
558 filtered_results,
559 [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
560 return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
561 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
562 });
563 }
564 return *selected_result;
565 }
566
567 } // namespace gpu
568 } // namespace xla
569