/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/debug_options_flags.h"

#include <vector>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/debug_options_parsers.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"

namespace xla {

DebugOptions DefaultDebugOptionsIgnoringFlags() {
  DebugOptions opts;
  opts.set_xla_llvm_enable_alias_scope_metadata(true);
  opts.set_xla_llvm_enable_noalias_metadata(true);
  opts.set_xla_llvm_enable_invariant_load_metadata(true);
  opts.set_xla_llvm_disable_expensive_passes(false);
  opts.set_xla_backend_optimization_level(3);
  opts.set_xla_gpu_autotune_level(4);
  opts.set_xla_cpu_multi_thread_eigen(true);
  opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
  opts.set_xla_gpu_asm_extra_flags("");
  opts.set_xla_eliminate_hlo_implicit_broadcast(true);
  opts.set_xla_dump_hlo_as_html(false);
  opts.set_xla_dump_fusion_visualization(false);
  opts.set_xla_dump_include_timestamp(false);
  opts.set_xla_dump_max_hlo_modules(-1);
  opts.set_xla_dump_module_metadata(false);
  opts.set_xla_dump_hlo_as_long_text(false);
#ifdef ENABLE_MKL
  opts.set_xla_cpu_use_mkl_dnn(true);
#endif  // ENABLE_MKL
#ifdef XLA_CPU_USE_ACL
  opts.set_xla_cpu_use_acl(true);
#endif
  opts.set_xla_cpu_use_jitrt(false);
  opts.set_xla_gpu_max_kernel_unroll_factor(4);

  // Run all GPU work on one stream by default. Multi-streaming support has been
  // removed, so setting this to false has no effect.
  // TODO(reedwm): Remove this option.
  opts.set_xla_gpu_disable_multi_streaming(true);

  opts.set_xla_cpu_enable_fast_math(false);
  // Disable forms of fast math that have caused users problems in the past.
  opts.set_xla_cpu_fast_math_honor_nans(true);
  opts.set_xla_cpu_fast_math_honor_infs(true);
  opts.set_xla_cpu_fast_math_honor_functions(true);
  opts.set_xla_cpu_fast_math_honor_division(true);

  // By default, copy TF's Eigen style min_max behavior with nans.
  opts.set_xla_cpu_enable_fast_min_max(true);

  opts.set_xla_gpu_enable_cudnn_frontend(true);

  opts.set_xla_gpu_enable_cublaslt(false);

  // Despite the name, fast min/max on GPUs does not seem to be any faster, and
  // adds very counter-intuitive "NaN-swallowing" behavior.
  opts.set_xla_gpu_enable_fast_min_max(false);
  opts.set_xla_gpu_strict_conv_algorithm_picker(true);

  opts.set_xla_allow_excess_precision(true);
  opts.set_xla_force_host_platform_device_count(1);
  opts.set_xla_gpu_all_reduce_combine_threshold_bytes(30 * 1024 * 1024);
  opts.set_xla_gpu_enable_async_all_reduce(true);
  opts.set_xla_cpu_enable_xprof_traceme(false);
  opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(false);
  opts.set_xla_multiheap_size_constraint_per_heap(-1);
  opts.set_xla_detailed_logging_and_dumping(true);

  opts.set_xla_gpu_jitrt_executable(false);
  opts.set_xla_gpu_nccl_termination_timeout_seconds(-1);
  opts.set_xla_gpu_enable_shared_constants(true);

  // Set 4GB space limit for redzone scratch allocator.
  opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12);
  opts.set_xla_gpu_shape_checks(DebugOptions::RUNTIME);
  opts.set_xla_cpu_enable_mlir_lowering(false);
  opts.set_xla_gpu_enable_mlir_lowering(true);
  opts.set_xla_gpu_normalize_layouts(false);
  return opts;
}

static absl::once_flag flags_init;
static DebugOptions* flag_values;
static std::vector<tensorflow::Flag>* flag_objects;

// Maps pass -> initial fuel values (parsed when AllocateFlags was run).
static absl::flat_hash_map<std::string, int64_t>* initial_fuel;

// Maps pass -> whether fuel was ever consumed for that pass.
static absl::node_hash_map<std::string, std::atomic<bool>>* fuel_ever_consumed;

// Maps pass -> remaining fuel.
//
// All threads start off using this global fuel pool, but ResetThreadLocalFuel()
// switches them to a thread-local fuel pool.
static absl::node_hash_map<std::string, std::atomic<int64_t>>* global_fuel;

// If we're using thread-local fuel, this stores it.
static thread_local std::unique_ptr<
    absl::node_hash_map<std::string, std::atomic<int64_t>>>
    thread_fuel;  // NOLINT (global variable with nontrivial destructor)
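// Illustrative example (pass name hypothetical): with --xla_fuel=fusion=5, the
// first five ConsumeFuel("fusion") calls below succeed and later ones return
// false, so only the first five fuel-guarded rewrites run.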

// Logs a warning if a pass's fuel was never consumed, on the theory that this
// may be a typo in the flag value.  Called atexit.
static void WarnIfFuelWasNeverConsumed() {
  CHECK(fuel_ever_consumed != nullptr);
  for (const auto& kv : *fuel_ever_consumed) {
    absl::string_view pass = kv.first;
    bool was_consumed = kv.second;
    if (!was_consumed) {
      LOG(ERROR) << absl::StreamFormat(
          "Compiler fuel for \"%s\" was never consumed. This may be a typo in "
          "the --xla_fuel flag you passed.",
          pass);
    }
  }
}

// Allocates flag_values and flag_objects; this function must not be called
// more than once - its single call site is guarded by absl::call_once.
static void AllocateFlags() {
  flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags());

  // Returns a lambda that calls "member_setter" on "flag_values" with the
  // argument passed in to the lambda.
  auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) {
    return [member_setter](bool value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  // Returns a lambda that calls "member_setter" on "flag_values" with the
  // argument passed in to the lambda.
  auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32_t)) {
    return [member_setter](int32_t value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  auto int64_setter_for = [](void (DebugOptions::*member_setter)(int64_t)) {
    return [member_setter](int64_t value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  auto string_setter_for =
      [](void (DebugOptions::*member_setter)(const std::string& value)) {
        return [member_setter](const std::string& value) {
          (flag_values->*member_setter)(value);
          return true;
        };
      };
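  // Usage sketch (illustrative): bool_setter_for(&DebugOptions::set_xla_gpu_ftz)
  // returns a callable taking a bool that writes through to *flag_values; the
  // tensorflow::Flag entries registered below use these callables as setters.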

  // Custom "sub-parser" lambda for xla_disable_hlo_passes.
  auto setter_for_xla_disable_hlo_passes =
      [](std::string comma_separated_values) {
        for (const auto& passname : std::vector<std::string>(
                 absl::StrSplit(comma_separated_values, ','))) {
          flag_values->add_xla_disable_hlo_passes(passname);
        }
        return true;
      };
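  // For example (pass names hypothetical), --xla_disable_hlo_passes=algsimp,fusion
  // records "algsimp" and "fusion" so those passes are skipped during compilation.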

  // Custom "sub-parser" lambda for xla_enable_hlo_passes_only.
  auto setter_for_xla_enable_hlo_passes_only =
      [](std::string comma_separated_values) {
        for (const auto& passname : std::vector<std::string>(
                 absl::StrSplit(comma_separated_values, ','))) {
          flag_values->add_xla_enable_hlo_passes_only(passname);
        }
        return true;
      };

  // Custom "sub-parser" lambda for xla_gpu_ptx_file.
  auto setter_for_xla_gpu_ptx_file = [](std::string value) {
    flag_values->add_xla_gpu_ptx_file(value);
    return true;
  };

  // Custom "sub-parser" lambda for xla_gpu_llvm_ir_file.
  auto setter_for_xla_gpu_llvm_ir_file = [](const std::string& value) {
    flag_values->add_xla_gpu_llvm_ir_file(value);
    return true;
  };

  // Custom "sub-parser" lambda for xla_backend_extra_options.
  auto setter_for_xla_backend_extra_options =
      [](std::string comma_separated_values) {
        auto* extra_options_map =
            flag_values->mutable_xla_backend_extra_options();
        parse_xla_backend_extra_options(extra_options_map,
                                        comma_separated_values);
        return true;
      };
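  // For example (keys hypothetical), --xla_backend_extra_options=key1=val1,key2
  // populates the extra-options map with "key1" -> "val1" plus the bare key
  // "key2"; see the flag's help text below for the exact format.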

  // Custom "sub-parser" for xla_fuel.  Note that ConsumeFuel does not do any
  // locking on the fuel global variables.  This means that it's
  // illegal/undefined behavior to modify this flag value while the compiler is
  // running.
  initial_fuel = new absl::flat_hash_map<std::string, int64_t>();
  fuel_ever_consumed =
      new absl::node_hash_map<std::string, std::atomic<bool>>();
  global_fuel = new absl::node_hash_map<std::string, std::atomic<int64_t>>();
  auto setter_for_xla_fuel = [](std::string xla_fuel_value) {
    initial_fuel->clear();
    global_fuel->clear();
    fuel_ever_consumed->clear();

    for (const auto& kv : absl::StrSplit(xla_fuel_value, ',')) {
      std::vector<std::string> pass_and_fuel = absl::StrSplit(kv, '=');
      if (pass_and_fuel.size() != 2) {
        LOG(ERROR) << absl::StreamFormat(
            "Illegal value for --xla_fuel. Saw %s, but expected token %s to "
            "have format X=INTEGER.",
            xla_fuel_value, kv);
        return false;
      }
      const auto& pass = pass_and_fuel[0];
      const auto& fuel_str = pass_and_fuel[1];
      int64_t fuel;
      if (!absl::SimpleAtoi(fuel_str, &fuel)) {
        LOG(ERROR) << absl::StreamFormat(
            "Illegal value for --xla_fuel. Saw %s, but expected token %s to be "
            "an integer.",
            xla_fuel_value, fuel_str);
        return false;
      }
      initial_fuel->emplace(pass, fuel);
      global_fuel->emplace(pass, fuel);
      fuel_ever_consumed->emplace(pass, false);
    }

    // If --xla_fuel was specified, register an atexit handler which logs a
    // warning if a pass was specified but never consumed any fuel, on the
    // theory that this may be a typo.
    if (!initial_fuel->empty()) {
      static absl::once_flag register_atexit_once;
      absl::call_once(
          register_atexit_once,
          +[] { std::atexit(WarnIfFuelWasNeverConsumed); });
    }
    return true;
  };
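  // For example (pass names hypothetical), --xla_fuel=fusion=10,algsimp=5 seeds
  // the initial and global fuel pools with 10 units for "fusion" and 5 units
  // for "algsimp"; the general format is --xla_fuel=PASS1=NUM1,PASS2=NUM2,...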

  flag_objects = new std::vector<tensorflow::Flag>();
  // Don't use an initializer list for initializing the vector; this would
  // create a temporary copy and exceed the stack space when compiling with
  // certain configurations.
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_fast_math",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
      flag_values->xla_cpu_enable_fast_math(),
      "Enable unsafe fast-math optimizations in the CPU compiler; this may "
      "produce faster code at the expense of some accuracy."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_nans",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans),
      flag_values->xla_cpu_fast_math_honor_nans(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "allow operations to produce NaNs.  Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_infs",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs),
      flag_values->xla_cpu_fast_math_honor_infs(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "allow operations to produce infinities.  Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_division",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division),
      flag_values->xla_cpu_fast_math_honor_division(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "forbid using multiplication by the reciprocal instead of division. "
      "Ignored when xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_functions",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions),
      flag_values->xla_cpu_fast_math_honor_functions(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "forbid approximating calculations for functions. Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_fast_min_max",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max),
      flag_values->xla_cpu_enable_fast_min_max(),
      "Enable fast floating point min/max lowering that always propagates "
      "NaNs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_fast_min_max",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
      flag_values->xla_gpu_enable_fast_min_max(),
      "Enable fast floating point min/max lowering that does not propagate "
      "NaNs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_alias_scope_metadata",
      bool_setter_for(&DebugOptions::set_xla_llvm_enable_alias_scope_metadata),
      flag_values->xla_llvm_enable_alias_scope_metadata(),
      "In LLVM-based backends, enable the emission of !alias.scope metadata in "
      "the generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_noalias_metadata",
      bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata),
      flag_values->xla_llvm_enable_noalias_metadata(),
      "In LLVM-based backends, enable the emission of !noalias metadata in the "
      "generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_invariant_load_metadata",
      bool_setter_for(
          &DebugOptions::set_xla_llvm_enable_invariant_load_metadata),
      flag_values->xla_llvm_enable_invariant_load_metadata(),
      "In LLVM-based backends, enable the emission of !invariant.load metadata "
      "in the generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_disable_expensive_passes",
      bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes),
      flag_values->xla_llvm_disable_expensive_passes(),
      "In LLVM-based backends, disable a custom set of expensive optimization "
      "passes."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_backend_optimization_level",
      int32_setter_for(&DebugOptions::set_xla_backend_optimization_level),
      flag_values->xla_backend_optimization_level(),
      "Numerical optimization level for the XLA compiler backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "",
      "Comma-separated list of hlo passes to be disabled. These names must "
      "exactly match the passes' names; no whitespace around commas."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only, "",
      "Comma-separated list of hlo passes to be enabled. These names must "
      "exactly match the passes' names; no whitespace around commas. The "
      "unspecified passes are all disabled."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_disable_all_hlo_passes",
      bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false,
      "Disables all HLO passes.  Note that some passes are necessary for "
      "correctness and the invariants that must be satisfied by 'fully "
      "optimized' HLO are different for different devices and may change "
      "over time.  The only 'guarantee', such as it is, is that if you compile "
      "XLA and dump the optimized HLO for some graph, you should be able to "
      "run it again on the same device with the same build of XLA."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_embed_ir_in_executable",
      bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable),
      flag_values->xla_embed_ir_in_executable(),
      "Embed the compiler IR as a string in the executable."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_eliminate_hlo_implicit_broadcast",
      bool_setter_for(&DebugOptions::set_xla_eliminate_hlo_implicit_broadcast),
      flag_values->xla_eliminate_hlo_implicit_broadcast(),
      "Eliminate implicit broadcasts when lowering user computations to HLO "
      "instructions; use explicit broadcast instead."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_multi_thread_eigen",
      bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen),
      flag_values->xla_cpu_multi_thread_eigen(),
      "When generating calls to Eigen in the CPU backend, use multi-threaded "
      "Eigen mode."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_cuda_data_dir", flag_values->mutable_xla_gpu_cuda_data_dir(),
      "If non-empty, specifies a local directory containing ptxas and nvvm "
      "libdevice files; otherwise we use those from runfile directories."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_ftz", bool_setter_for(&DebugOptions::set_xla_gpu_ftz),
      flag_values->xla_gpu_ftz(),
      "If true, flush-to-zero semantics are enabled in the code generated for "
      "GPUs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_disable_multi_streaming",
      bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming),
      flag_values->xla_gpu_disable_multi_streaming(),
      "Has no impact. Multi-streaming support has been removed from XLA GPU so "
      "it is always disabled."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_max_kernel_unroll_factor",
      int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor),
      flag_values->xla_gpu_max_kernel_unroll_factor(),
      "Specify the maximum kernel unroll factor for the GPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "",
      "If non-empty, specifies a file containing PTX to use. The filename "
      "prefix must have the same pattern as PTX dumped by XLA. This allows "
      "matching one specific module. General workflow: get the generated "
      "module PTX from XLA, modify it, then pass it back via this option."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_llvm_ir_file", setter_for_xla_gpu_llvm_ir_file, "",
      "If non-empty, specifies a file containing textual LLVM IR to use. The "
      "filename prefix must have the same pattern as LLVM dumped by XLA "
      "(i.e. module_0001.ir-no-opt.ll -> module_0001.MY_NEW_FILE.ll). This "
      "allows matching one specific module. General workflow: get the "
      "unoptimized LLVM IR from XLA, modify it, then pass it back via this "
      "option."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_test_all_output_layouts",
      bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
      flag_values->xla_test_all_output_layouts(),
      "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of "
      "output layouts. For example, with a 3D shape, all permutations of the "
      "set {0, 1, 2} are tried."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_test_all_input_layouts",
      bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts),
      flag_values->xla_test_all_input_layouts(),
      "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of "
      "*input* layouts. For example, for 2 input arguments with 2D shape and "
      "4D shape, the computation will run 2! * 4! times, once for each "
      "possible layout combination."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_profile", bool_setter_for(&DebugOptions::set_xla_hlo_profile),
      flag_values->xla_hlo_profile(),
      "Instrument the computation to collect per-HLO cycle counts"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_backend_extra_options", setter_for_xla_backend_extra_options, "",
      "Extra options to pass to a backend; comma-separated list of 'key=val' "
      "strings (=val may be omitted); no whitespace around commas."));
  flag_objects->push_back(
      tensorflow::Flag("xla_cpu_use_mkl_dnn",
                       bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
                       flag_values->xla_cpu_use_mkl_dnn(),
                       "Generate calls to MKL-DNN in the CPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_use_acl", bool_setter_for(&DebugOptions::set_xla_cpu_use_acl),
      flag_values->xla_cpu_use_acl(),
      "Generate calls to ACL (Arm Compute Library) in the CPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_use_jitrt",
      bool_setter_for(&DebugOptions::set_xla_cpu_use_jitrt),
      flag_values->xla_cpu_use_jitrt(), "Enable JitRt in the CPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_crash_on_verification_failures",
      bool_setter_for(
          &DebugOptions::set_xla_gpu_crash_on_verification_failures),
      flag_values->xla_gpu_crash_on_verification_failures(),
      "Crashes the program on extra verification failures, e.g. cuDNN cross "
      "checking failures"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_strict_conv_algorithm_picker",
      bool_setter_for(&DebugOptions::set_xla_gpu_strict_conv_algorithm_picker),
      flag_values->xla_gpu_strict_conv_algorithm_picker(),
      "Upgrades warnings to failures when all algorithms fail conv "
      "autotuning."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_autotune_level",
      int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
      flag_values->xla_gpu_autotune_level(),
      "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = "
      "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_force_host_platform_device_count",
      int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count),
      flag_values->xla_force_host_platform_device_count(),
      "Force the host platform to pretend that there are this many host "
      "\"devices\". All of these host devices are backed by the same "
      "threadpool. Setting this to anything other than 1 can increase overhead "
      "from context switching but we let the user override this behavior to "
      "help run tests on the host that run models in parallel across multiple "
      "devices."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_disable_gpuasm_optimizations",
      bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations),
      flag_values->xla_gpu_disable_gpuasm_optimizations(),
      "In XLA:GPU run ptxas in -O0 (default is -O3)."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_asm_extra_flags",
      string_setter_for(&DebugOptions::set_xla_gpu_asm_extra_flags), "",
      "Pass extra parameters to the GPU assembler tool (i.e., ptxas for CUDA). "
      "If multiple parameters, separate them by comma."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"",
      "Sets compiler fuel, useful for bisecting bugs in passes.  Format "
      "--xla_fuel=PASS1=NUM1,PASS2=NUM2,..."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to),
      flag_values->xla_dump_to(),
      "Directory into which debugging data is written. If not specified but "
      "another dumping flag is passed, data will be written to stdout. To "
      "explicitly write to stdout, set this to \"-\". The values \"sponge\" "
      "and \"test_undeclared_outputs_dir\" have a special meaning: They cause "
      "us to dump into the directory specified by the environment variable "
      "TEST_UNDECLARED_OUTPUTS_DIR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_text",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text),
      flag_values->xla_dump_hlo_as_text(),
      "Dumps HLO modules as text before and after optimizations. Results are "
      "written to the --xla_dump_to dir, or, if no dir is specified, to "
      "stdout."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_long_text",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_long_text),
      flag_values->xla_dump_hlo_as_long_text(),
      "Dumps HLO modules as long text before and after optimizations. Results "
      "are written to the --xla_dump_to dir, or, if no dir is specified, to "
      "stdout. Ignored unless xla_dump_hlo_as_text is true."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_proto",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto),
      flag_values->xla_dump_hlo_as_proto(),
      "Dumps HLO modules as HloProtos to the directory specified by "
      "--xla_dump_to."));
  flag_objects->push_back(
      tensorflow::Flag("xla_dump_hlo_as_dot",
                       bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot),
                       flag_values->xla_dump_hlo_as_dot(),
                       "Dumps HLO modules rendered as dot files to the "
                       "directory specified by --xla_dump_to."));
  flag_objects->push_back(
      tensorflow::Flag("xla_dump_hlo_as_html",
                       bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html),
                       flag_values->xla_dump_hlo_as_html(),
                       "Dumps HLO modules rendered as HTML files to the "
                       "directory specified by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_url",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url),
      flag_values->xla_dump_hlo_as_url(),
      "Tries to dump HLO modules rendered as URLs to stdout (and also to the "
      "directory specified by --xla_dump_to). This is not implemented by "
      "default; you need to add a plugin which calls "
      "RegisterGraphToURLRenderer()."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_fusion_visualization",
      bool_setter_for(&DebugOptions::set_xla_dump_fusion_visualization),
      flag_values->xla_dump_fusion_visualization(),
      "Tries to generate HLO fusion visualization as an HTML page written to "
      "the directory specified by --xla_dump_to. This is not implemented by "
      "default; you need to add a plugin which calls "
      "RegisterGraphToURLRenderer(). Generates a file per computation. "
      "Currently only implemented for the GPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_snapshots",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots),
      flag_values->xla_dump_hlo_snapshots(),
      "Every time an HLO module is run, dumps an HloSnapshot to the directory "
      "specified by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_module_re",
      string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re),
      flag_values->xla_dump_hlo_module_re(),
      "Limits dumping only to modules which match this regular expression. "
      "Default is to dump all modules."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_pass_re",
      string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re),
      flag_values->xla_dump_hlo_pass_re(),
      "If specified, dumps HLO before and after optimization passes which "
      "match this regular expression, in addition to dumping at the very "
      "beginning and end of compilation."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_include_timestamp",
      bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp),
      flag_values->xla_dump_include_timestamp(),
      "If specified, includes a timestamp in the dumped filenames."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_max_hlo_modules",
      int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules),
      flag_values->xla_dump_max_hlo_modules(),
      "Max number of hlo module dumps in a directory. Set to < 0 for "
      "unbounded."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_module_metadata",
      bool_setter_for(&DebugOptions::set_xla_dump_module_metadata),
      flag_values->xla_dump_module_metadata(),
      "Dumps HloModuleMetadata as text protos to the directory specified "
      "by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_compress_protos",
      bool_setter_for(&DebugOptions::set_xla_dump_compress_protos),
      flag_values->xla_dump_compress_protos(),
      "Gzip-compress protos dumped by --xla_dump_hlo_as_proto."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_graph_addresses",
      bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
      flag_values->xla_hlo_graph_addresses(),
      "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays "
      "the address in memory of each HloInstruction object."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_graph_sharding_color",
      bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color),
      flag_values->xla_hlo_graph_sharding_color(),
      "Assign colors based on sharding assignments when generating the HLO "
      "graphs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_allow_excess_precision",
      bool_setter_for(&DebugOptions::set_xla_allow_excess_precision),
      flag_values->xla_allow_excess_precision(),
      "Allow xla to increase the output precision of an instruction."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_conv_nchw",
      bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw),
      flag_values->xla_gpu_force_conv_nchw(),
      "For cuDNN convolutions, always use NCHW layouts."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_conv_nhwc",
      bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nhwc),
      flag_values->xla_gpu_force_conv_nhwc(),
      "For cuDNN convolutions, always use NHWC layouts."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_algorithm_denylist_path",
      string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path),
      flag_values->xla_gpu_algorithm_denylist_path(),
      "An AlgorithmDenylist text proto file as a denylist of convolutions to "
      "avoid using."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_tpu_detect_nan",
      bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan),
      flag_values->xla_tpu_detect_nan(),
      "Trigger error on execution on TPU if a NAN value is detected"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_tpu_detect_inf",
      bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf),
      flag_values->xla_tpu_detect_inf(),
      "Trigger error on execution on TPU if an INF value is detected"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_xprof_traceme",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_xprof_traceme),
      flag_values->xla_cpu_enable_xprof_traceme(),
      "If true, XLA CPU generates code to call "
      "TraceMe::Activity{Start|End} around HLO operations."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found",
      bool_setter_for(
          &DebugOptions::
              set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found),
      flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(),
      "If true, XLA GPU falls back to the driver if ptxas is not found. Note "
      "that falling back to the driver can have drawbacks like using more "
      "memory and/or other bugs during compilation, so we recommend setting "
      "this flag to false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_multiheap_size_constraint_per_heap",
      int32_setter_for(
          &DebugOptions::set_xla_multiheap_size_constraint_per_heap),
      flag_values->xla_multiheap_size_constraint_per_heap(),
      "Generates multiple heaps (i.e., temp buffers) with a size "
      "constraint on each heap to avoid Out-of-Memory due to memory "
      "fragmentation. The constraint is soft, so it works with tensors "
      "larger than the given constraint size. -1 corresponds to no "
      "constraints."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_compilation_parallelism",
      int32_setter_for(
          &DebugOptions::set_xla_gpu_force_compilation_parallelism),
      flag_values->xla_gpu_force_compilation_parallelism(),
      "Overrides normal multi-threaded compilation setting to use this many "
      "threads. Setting to 0 (the default value) means no enforcement."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_deterministic_ops",
      bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_ops),
      flag_values->xla_gpu_deterministic_ops(),
      "Guarantees run-to-run determinism on GPU."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_async_all_reduce",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_async_all_reduce),
      flag_values->xla_gpu_enable_async_all_reduce(),
      "Converts synchronous all-reduce ops into asynchronous."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_all_reduce_combine_threshold_bytes",
      int64_setter_for(
          &DebugOptions::set_xla_gpu_all_reduce_combine_threshold_bytes),
      flag_values->xla_gpu_all_reduce_combine_threshold_bytes(),
      "Size threshold (in bytes) for the GPU all-reduce combiner."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_all_reduce_contiguous",
      bool_setter_for(&DebugOptions::set_xla_gpu_all_reduce_contiguous),
      flag_values->xla_gpu_all_reduce_contiguous(),
      "Combine all-reduces into a single operation over a contiguous buffer."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_all_reduce_blueconnect_num_devices_per_host",
      int32_setter_for(
          &DebugOptions::
              set_xla_gpu_all_reduce_blueconnect_num_devices_per_host),
      flag_values->xla_gpu_all_reduce_blueconnect_num_devices_per_host(),
      "Number of devices per host for first stage of BlueConnect decomposition "
      "pass. The pass will attempt to decompose all-reduce ops into a "
      "ReduceScatter-AllReduce-AllGather sequence, with the initial "
      "ReduceScatter being performed over all of the devices in the same host. "
      "Set to < 1 to disable all-reduce decomposition."));
  flag_objects->push_back(
      tensorflow::Flag("xla_gpu_dump_llvmir",
                       bool_setter_for(&DebugOptions::set_xla_gpu_dump_llvmir),
                       flag_values->xla_gpu_dump_llvmir(), "Dump LLVM IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_cudnn_frontend",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_cudnn_frontend),
      flag_values->xla_gpu_enable_cudnn_frontend(),
      "Use the cuDNN frontend API for convolutions when possible."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_cublaslt",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_cublaslt),
      flag_values->xla_gpu_enable_cublaslt(),
      "Use cuBLASLt for GEMMs when possible."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_disable_metadata",
      bool_setter_for(&DebugOptions::set_xla_dump_disable_metadata),
      flag_values->xla_dump_disable_metadata(),
      "Disable dumping HLO metadata in HLO dumps."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_pipeline_re",
      string_setter_for(&DebugOptions::set_xla_dump_hlo_pipeline_re),
      flag_values->xla_dump_hlo_pipeline_re(),
      "If specified, dumps HLO before and after optimization passes in the "
      "pass pipelines that match this regular expression."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_jitrt_executable",
      bool_setter_for(&DebugOptions::set_xla_gpu_jitrt_executable),
      flag_values->xla_gpu_jitrt_executable(),
      "Whether to enable XLIR to compile gpu programs to JitRt."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_nccl_termination_timeout_seconds",
      int64_setter_for(
          &DebugOptions::set_xla_gpu_nccl_termination_timeout_seconds),
      flag_values->xla_gpu_nccl_termination_timeout_seconds(),
      "Timeout in seconds before terminating jobs stuck in NCCL Rendezvous."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_shared_constants",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_shared_constants),
      flag_values->xla_gpu_enable_shared_constants(),
      "Enable constant sharing between GPU executables"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_redzone_scratch_max_megabytes",
      int64_setter_for(
          &DebugOptions::set_xla_gpu_redzone_scratch_max_megabytes),
      flag_values->xla_gpu_redzone_scratch_max_megabytes(),
      "Max size (in megabytes) for the GPU redzone scratch allocator."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_simplify_all_fp_conversions",
      bool_setter_for(&DebugOptions::set_xla_gpu_simplify_all_fp_conversions),
      flag_values->xla_gpu_simplify_all_fp_conversions(),
      "Allows any chain of floating-point conversions to be simplified."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_mlir_lowering",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_mlir_lowering),
      flag_values->xla_cpu_enable_mlir_lowering(),
      "Enable MLIR-based lowering in XLA:CPU instead of LLVM emitters."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_mlir_lowering",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_mlir_lowering),
      flag_values->xla_gpu_enable_mlir_lowering(),
      "Enable MLIR-based lowering in XLA:GPU instead of LLVM emitters."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_normalize_layouts",
      bool_setter_for(&DebugOptions::set_xla_gpu_normalize_layouts),
      flag_values->xla_gpu_normalize_layouts(),
      "An experimental option to force all layouts present in the "
      "after-optimizations HLO to be descending"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_strict_dot_conv_math",
      bool_setter_for(&DebugOptions::set_xla_cpu_strict_dot_conv_math),
      flag_values->xla_cpu_strict_dot_conv_math(),
      "By default, XLA:CPU will run fp16 dot/conv as fp32, as this is "
      "generally (much) faster on our hardware.  Set this flag to true to "
      "disable this behavior."));

  ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
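  // Example (illustrative values): running with
  //   XLA_FLAGS="--xla_dump_to=/tmp/xla_dump --xla_gpu_autotune_level=2"
  // in the environment sets those two DebugOptions fields for the process;
  // unknown flags in XLA_FLAGS cause a fatal error, as the call above implies.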
}  // NOLINT(readability/fn_size)

void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateFlags);
  flag_list->insert(flag_list->end(), flag_objects->begin(),
                    flag_objects->end());
}

xla::DebugOptions GetDebugOptionsFromFlags() {
  absl::call_once(flags_init, &AllocateFlags);
  return *flag_values;
}

void ResetThreadLocalFuel() {
  absl::call_once(flags_init, &AllocateFlags);

  thread_fuel.reset(
      new absl::node_hash_map<std::string, std::atomic<int64_t>>());
  CHECK(initial_fuel != nullptr);
  for (const auto& kv : *initial_fuel) {
    thread_fuel->emplace(kv.first, kv.second);
  }
}
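
// Consumes one unit of fuel for `pass` from the active pool (the thread-local
// pool if ResetThreadLocalFuel() was called on this thread, otherwise the
// global pool) and returns true while fuel remains. Passes that were not named
// in --xla_fuel effectively have unlimited fuel. Illustrative use in a pass
// (pass name hypothetical):
//   if (!ConsumeFuel("my-pass")) return false;  // stop rewriting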
bool ConsumeFuel(absl::string_view pass, bool* just_ran_out) {
  absl::call_once(flags_init, &AllocateFlags);
  if (just_ran_out != nullptr) {
    *just_ran_out = false;
  }
  auto* fuel_pool = thread_fuel ? thread_fuel.get() : global_fuel;
  if (fuel_pool->empty()) {
    return true;
  }
  auto it = fuel_pool->find(pass);
  if (it == fuel_pool->end()) {
    return true;
  }
  std::atomic<int64_t>& remaining_fuel = it->second;
  std::atomic<bool>& fuel_has_been_consumed = fuel_ever_consumed->at(pass);
  fuel_has_been_consumed = true;

  int64_t remaining = remaining_fuel.fetch_sub(1);
  if (just_ran_out != nullptr) {
    *just_ran_out = remaining == 0;
  }
  return remaining > 0;
}

}  // namespace xla
830