xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_module_config.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
18 
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace xla {
33 
// Granularity at which custom fusion configuration is collected for a module.
enum class FusionConfigCollection {
  kOff,      // Do not collect configuration.
  kPerEdge,  // Collect per-edge configuration.
  kPerNode,  // Collect per-node configuration.
};
39 
40 // This class gathers all settings and values which affect the compiled
41 // executable outside of the HLO code itself. This include layouts of inputs and
42 // outputs to the module and settings such as HLO profiling. Together the
43 // HloModule and HloModuleConfig unambiguously determine a particular
44 // executable.
45 class HloModuleConfig {
46  public:
47   // Represents a pair of input and output of the entry computation that can be
48   // considered as the original and updated values of a variable maintained by
49   // the caller, and that can be transparently sharded by XLA as an internal
50   // optimization. If sharded, XLA will create separate sharding/unsharding
51   // programs, and the caller is responsible to call the XLA-generated
52   // sharding/unsharding programs before and after the sharded main program.
53   //
54   // If the variable is not updated and there is not a corresponding output, use
55   // {-1} as the output_shape_index.
56   //
57   // The sharding/unsharding programs will include all the input/output pairs in
58   // shardable_value_update_pairs() as a flat tuple in their inputs/outputs,
59   // sorted by (input_parameter_number, parameter_shape_index).
60   //
61   // A typical usage pattern is to shard the variables first, then repeatedly
62   // invoke the main program, and finally invoke the unsharding program before
63   // they are used in full-shape.
64   struct ShardableValueUpdatePair {
65     int64_t input_parameter_number;
66     ShapeIndex parameter_shape_index;
67     ShapeIndex output_shape_index;
68   };
69 
70   // A configuration can be created either with, or without an entry
71   // ComputationLayout. The default ctor creates it without -- in this case
72   // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
73   // ProgramShape creates a computation layout using this shape.
74   // The layouts in the ProgramShape will be reset to default unless
75   // ignore_layouts is set to false.
HloModuleConfig()76   HloModuleConfig() { debug_options_ = DefaultDebugOptionsIgnoringFlags(); }
77 
78   explicit HloModuleConfig(const ProgramShape& program_shape,
79                            bool ignore_layouts = true);
80 
81   explicit HloModuleConfig(ComputationLayout entry_computation_layout);
82 
83   // Checks if this config has an entry computation layout already.
has_entry_computation_layout()84   bool has_entry_computation_layout() const {
85     return entry_computation_layout_.has_value();
86   }
87 
88   // Sets the entry_computation_layout's parameter and result shapes for this
89   // config, according to the given program shape. The parameters and result
90   // are set to default layout.
91   void SetDefaultComputationLayout(const ProgramShape& program_shape);
92 
93   // Same as above but if the given program contains layout for parameters or
94   // result, the entry_computation_layout's layout is updated accordingly.
95   void SetComputationLayoutIfExists(const ProgramShape& program_shape);
96 
97   // Returns a constant reference to the layout of the entry computation.
98   // Assumes the layout was set.
entry_computation_layout()99   const ComputationLayout& entry_computation_layout() const {
100     CHECK(entry_computation_layout_.has_value());
101     return *entry_computation_layout_;
102   }
103 
104   // Returns a mutable pointer to the layout of the entry computation.
105   // Assumes the layout was set.
mutable_entry_computation_layout()106   ComputationLayout* mutable_entry_computation_layout() {
107     CHECK(entry_computation_layout_.has_value());
108     return &(*entry_computation_layout_);
109   }
110 
111   // Clears the entry computation layout.
clear_entry_computation_layout()112   void clear_entry_computation_layout() {
113     entry_computation_layout_ = std::nullopt;
114   }
115 
116   // Returns whether to enable HLO-level profiling.
hlo_profiling_enabled()117   bool hlo_profiling_enabled() const {
118     return debug_options_.xla_hlo_profile();
119   }
120 
cpu_traceme_enabled()121   bool cpu_traceme_enabled() const {
122     return debug_options_.xla_cpu_enable_xprof_traceme();
123   }
124 
125   // Sets/returns the module seed set during execution.
set_seed(uint64_t seed)126   void set_seed(uint64_t seed) { seed_ = seed; }
seed()127   uint64_t seed() const { return seed_; }
128 
129   // Set the launch id of the program. Launch id identifies a set of programs
130   // that should be launched together.
set_launch_id(uint64_t launch_id)131   void set_launch_id(uint64_t launch_id) { launch_id_ = launch_id; }
132 
launch_id()133   int32_t launch_id() const { return launch_id_; }
134 
set_replica_count(int64_t replica_count)135   void set_replica_count(int64_t replica_count) {
136     replica_count_ = replica_count;
137   }
replica_count()138   int64_t replica_count() const { return replica_count_; }
139 
set_num_partitions(int64_t num_partitions)140   void set_num_partitions(int64_t num_partitions) {
141     num_partitions_ = num_partitions;
142   }
num_partitions()143   int64_t num_partitions() const { return num_partitions_; }
144 
param_requires_broadcast_via_collectives()145   const std::vector<bool> param_requires_broadcast_via_collectives() const {
146     return param_requires_broadcast_via_collectives_;
147   }
set_param_requires_broadcast_via_collectives(const std::vector<bool> require_broadcast)148   void set_param_requires_broadcast_via_collectives(
149       const std::vector<bool> require_broadcast) {
150     param_requires_broadcast_via_collectives_ = std::move(require_broadcast);
151   }
152 
set_use_spmd_partitioning(bool use_spmd_partitioning)153   void set_use_spmd_partitioning(bool use_spmd_partitioning) {
154     use_spmd_partitioning_ = use_spmd_partitioning;
155   }
use_spmd_partitioning()156   bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
157 
set_use_auto_spmd_partitioning(bool use_auto_spmd_partitioning)158   void set_use_auto_spmd_partitioning(bool use_auto_spmd_partitioning) {
159     use_auto_spmd_partitioning_ = use_auto_spmd_partitioning;
160     if (use_auto_spmd_partitioning) {
161       // TODO(yuemmawang) Remove this warning once auto sharding is thoroughly
162       // tested with fleetwide models.
163       LOG(WARNING) << "Warning: Using auto_spmd_partitioning. It is "
164                       "experimental and may "
165                       "contain bugs!";
166       LOG(INFO) << "Overwriting use_spmd_partitioning to true, because "
167                    "use_auto_spmd_partitioning is true.";
168       set_use_spmd_partitioning(true);
169     }
170   }
use_auto_spmd_partitioning()171   bool use_auto_spmd_partitioning() const {
172     return use_auto_spmd_partitioning_;
173   }
174 
set_auto_spmd_partitioning_mesh_shape(std::vector<int64_t> mesh_shape)175   void set_auto_spmd_partitioning_mesh_shape(std::vector<int64_t> mesh_shape) {
176     auto_spmd_partitioning_mesh_shape_ = mesh_shape;
177   }
auto_spmd_partitioning_mesh_shape()178   std::vector<int64_t> auto_spmd_partitioning_mesh_shape() const {
179     return auto_spmd_partitioning_mesh_shape_;
180   }
181 
set_auto_spmd_partitioning_mesh_ids(std::vector<int64_t> mesh_ids)182   void set_auto_spmd_partitioning_mesh_ids(std::vector<int64_t> mesh_ids) {
183     auto_spmd_partitioning_mesh_ids_ = mesh_ids;
184   }
auto_spmd_partitioning_mesh_ids()185   std::vector<int64_t> auto_spmd_partitioning_mesh_ids() const {
186     return auto_spmd_partitioning_mesh_ids_;
187   }
188 
189   // If enabled, deduplicate equivalent hlos into function calls to reduce code
190   // size.
set_deduplicate_hlo(bool deduplicate_hlo)191   void set_deduplicate_hlo(bool deduplicate_hlo) {
192     deduplicate_hlo_ = deduplicate_hlo;
193   }
194 
set_device_type(const std::string & device_type)195   void set_device_type(const std::string& device_type) {
196     device_type_ = device_type;
197   }
198 
deduplicate_hlo()199   bool deduplicate_hlo() const { return deduplicate_hlo_; }
200 
201   // Return a string which unambiguously represents all the fields of this data
202   // structure. Used for generating a cache key for storing the compiled
203   // executable.
204   std::string compilation_cache_key() const;
205 
device_type()206   std::string device_type() const { return device_type_; }
207 
debug_options()208   const DebugOptions& debug_options() const { return debug_options_; }
209 
set_debug_options(const DebugOptions & debug_options)210   void set_debug_options(const DebugOptions& debug_options) {
211     debug_options_ = debug_options;
212   }
213 
214   // Sets/returns the number of intra op threads for this module.
set_intra_op_parallelism_threads(const int intra_op_parallelism_threads)215   void set_intra_op_parallelism_threads(
216       const int intra_op_parallelism_threads) {
217     intra_op_parallelism_threads_ = intra_op_parallelism_threads;
218   }
intra_op_parallelism_threads()219   int64_t intra_op_parallelism_threads() const {
220     return intra_op_parallelism_threads_;
221   }
222 
223   // Checks if this config has a static device assignment.
has_static_device_assignment()224   bool has_static_device_assignment() const {
225     return static_device_assignment_.has_value();
226   }
227 
228   // Getter and setter of the compile-time known device assignment.
static_device_assignment()229   const DeviceAssignment& static_device_assignment() const {
230     CHECK(static_device_assignment_.has_value());
231     return *static_device_assignment_;
232   }
set_static_device_assignment(const DeviceAssignment & device_assignment)233   void set_static_device_assignment(const DeviceAssignment& device_assignment) {
234     static_device_assignment_ = device_assignment;
235   }
236 
shardable_value_update_pairs()237   const std::vector<ShardableValueUpdatePair> shardable_value_update_pairs()
238       const {
239     return shardable_value_update_pairs_;
240   }
set_shardable_value_update_pairs(std::vector<ShardableValueUpdatePair> pairs)241   void set_shardable_value_update_pairs(
242       std::vector<ShardableValueUpdatePair> pairs) {
243     shardable_value_update_pairs_ = std::move(pairs);
244   }
245 
246   // Whether input and output buffers are aliased if the associated parameter is
247   // passed-through XLA modules without being changed.
alias_passthrough_params()248   bool alias_passthrough_params() const { return alias_passthrough_params_; }
set_alias_passthrough_params(bool alias_passthrough_params)249   void set_alias_passthrough_params(bool alias_passthrough_params) {
250     alias_passthrough_params_ = alias_passthrough_params;
251   }
252 
content_aware_computation_sorting()253   bool content_aware_computation_sorting() const {
254     return content_aware_computation_sorting_;
255   }
set_content_aware_computation_sorting(bool content_aware_computation_sorting)256   void set_content_aware_computation_sorting(
257       bool content_aware_computation_sorting) {
258     content_aware_computation_sorting_ = content_aware_computation_sorting;
259   }
260 
fusion_config_collection()261   FusionConfigCollection fusion_config_collection() const {
262     return fusion_config_collection_;
263   }
set_fusion_config_collection(FusionConfigCollection fusion_config_collection)264   void set_fusion_config_collection(
265       FusionConfigCollection fusion_config_collection) {
266     fusion_config_collection_ = fusion_config_collection;
267   }
268 
fusion_config()269   const std::vector<std::vector<bool>>& fusion_config() const {
270     return fusion_config_;
271   }
mutable_fusion_config()272   std::vector<std::vector<bool>>* mutable_fusion_config() {
273     return &fusion_config_;
274   }
275 
dot_config()276   const absl::flat_hash_map<std::string, std::vector<int64_t>>& dot_config()
277       const {
278     return dot_config_;
279   }
280 
mutable_dot_config()281   absl::flat_hash_map<std::string, std::vector<int64_t>>* mutable_dot_config() {
282     return &dot_config_;
283   }
284 
layout_config()285   const std::vector<std::vector<std::vector<int64_t>>>& layout_config() const {
286     return layout_config_;
287   }
288 
mutable_layout_config()289   std::vector<std::vector<std::vector<int64_t>>>* mutable_layout_config() {
290     return &layout_config_;
291   }
292 
phase_ordering_config()293   const std::vector<std::vector<bool>>& phase_ordering_config() const {
294     return phase_ordering_config_;
295   }
296 
mutable_phase_ordering_config()297   std::vector<std::vector<bool>>* mutable_phase_ordering_config() {
298     return &phase_ordering_config_;
299   }
300 
flag_config()301   const absl::flat_hash_map<std::string, std::string>& flag_config() const {
302     return flag_config_;
303   }
304 
mutable_flag_config()305   absl::flat_hash_map<std::string, std::string>* mutable_flag_config() {
306     return &flag_config_;
307   }
308 
phase_index()309   const int phase_index() const { return phase_index_; }
set_phase_index(const int phase_index)310   void set_phase_index(const int phase_index) { phase_index_ = phase_index; }
311 
set_allow_spmd_sharding_propagation_to_output(bool allow_spmd_sharding_propagation_to_output)312   void set_allow_spmd_sharding_propagation_to_output(
313       bool allow_spmd_sharding_propagation_to_output) {
314     allow_spmd_sharding_propagation_to_output_ =
315         allow_spmd_sharding_propagation_to_output;
316   }
allow_spmd_sharding_propagation_to_output()317   bool allow_spmd_sharding_propagation_to_output() const {
318     return allow_spmd_sharding_propagation_to_output_;
319   }
320 
memory_space_assignment_config()321   const std::vector<uint64_t>& memory_space_assignment_config() const {
322     return memory_space_assignment_config_;
323   }
324 
mutable_memory_space_assignment_config()325   std::vector<uint64_t>* mutable_memory_space_assignment_config() {
326     return &memory_space_assignment_config_;
327   }
328 
GetAnalysisAllowance(absl::string_view pass_name)329   int64_t GetAnalysisAllowance(absl::string_view pass_name) const {
330     auto it = analysis_allowance_map_.find(pass_name);
331     if (it == analysis_allowance_map_.end()) {
332       return -1;
333     }
334     return (*it).second;
335   }
336 
SetAnalysisAllowance(absl::string_view pass_name,int64_t allowance)337   void SetAnalysisAllowance(absl::string_view pass_name, int64_t allowance) {
338     analysis_allowance_map_[pass_name] = allowance;
339   }
340 
matrix_unit_operand_precision()341   PrecisionConfig::Precision matrix_unit_operand_precision() const {
342     return matrix_unit_operand_precision_;
343   }
set_matrix_unit_operand_precision(PrecisionConfig::Precision matrix_unit_operand_precision)344   void set_matrix_unit_operand_precision(
345       PrecisionConfig::Precision matrix_unit_operand_precision) {
346     matrix_unit_operand_precision_ = matrix_unit_operand_precision;
347   }
348 
349  private:
350   // If you add new members, be sure to update compilation_cache_key.
351 
352   std::optional<ComputationLayout> entry_computation_layout_;
353 
354   // Module/graph-level seed handle.
355   uint64_t seed_ = 0;
356 
357   // Program id that identifies a set of program to be launched together.
358   int32_t launch_id_ = 0;
359 
360   // The number of replicas (data parallelism) to compile this binary for.
361   int64_t replica_count_ = 1;
362 
363   // The number of partitions (model parallelism) to compile this binary for.
364   int64_t num_partitions_ = 1;
365 
366   // Whether to broadcast args across all replicas. One entry per arg.
367   std::vector<bool> param_requires_broadcast_via_collectives_;
368 
369   // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
370   // needs to partition the module.
371   bool use_spmd_partitioning_ = false;
372 
373   // Whether to automatically generate XLA shardings for SPMD partitioner.
374   bool use_auto_spmd_partitioning_ = false;
375 
376   // Mesh shape and mesh ids used by auto spmd partitioning.
377   std::vector<int64_t> auto_spmd_partitioning_mesh_shape_;
378 
379   std::vector<int64_t> auto_spmd_partitioning_mesh_ids_;
380 
381   // If enabled, deduplicate equivalent hlos into function calls to reduce code
382   // size.
383   bool deduplicate_hlo_ = false;
384 
385   // The target maximum parallelism at which to partition HLOs for parallel
386   // execution on the CPU backend.
387   int64_t intra_op_parallelism_threads_ = -1;
388 
389   std::string device_type_;
390 
391   DebugOptions debug_options_;
392 
393   // Compile-time known device assignment.
394   std::optional<DeviceAssignment> static_device_assignment_;
395 
396   std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;
397 
398   bool alias_passthrough_params_ = false;
399 
400   bool content_aware_computation_sorting_ = true;
401 
402   FusionConfigCollection fusion_config_collection_ =
403       FusionConfigCollection::kOff;
404 
405   // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto
406   // similar to backend config.
407 
408   // Custom fusion configuration, where fusion_config_[c][v] control if node v
409   // in computation c must be fused to all its consumers (true) or not (false).
410   std::vector<std::vector<bool>> fusion_config_;
411 
412   // Custom dot canonicalization configuration, where dot_config_[v] control
413   // how to convert dot operation named 'v' to convolution.
414   absl::flat_hash_map<std::string, std::vector<int64_t>> dot_config_;
415 
416   // Layout configuration, where layout_config_[v][i] controls the layout
417   // decision i of operation v.
418   std::vector<std::vector<std::vector<int64_t>>> layout_config_;
419 
420   // Memory Space Assignment configuration, where
421   // memory_space_assignment_config_ controls the order of buffer intervals
422   // of this hlo module.
423   std::vector<uint64_t> memory_space_assignment_config_;
424 
425   // Phase ordering configuration, where phase_ordering_config[v][i] controls
426   // whether a specific pass with index i (e.g. 0 = DCE, 1 = CSE, etc.) is
427   // inserted after pass v in pipeline. See tuning::PhaseOrderingConfig for
428   // details on what indices (i) correspond to which passes.
429   std::vector<std::vector<bool>> phase_ordering_config_;
430   // Index (v) corresponding to current passes being added for phase ordering.
431   // This is the variable that stores state to allow us to use the same
432   // config across functions during compilation.
433   int phase_index_ = 0;
434 
435   // Flag configuration to use instead of global flags. This allows multiple
436   // HLO modules to be compiled in parallel with different flag values.
437   absl::flat_hash_map<std::string, std::string> flag_config_;
438 
439   // Allows sharding propagation to propagate to the outputs. This changes the
440   // output shape of the computation (which is undesirable), but it can be used
441   // to allow to run partial compilation to determine what would be the output
442   // sharding of a computation if XLA would be allowed to propagate the sharding
443   // which can be used by higher level framework as a way to query intermediate
444   // sharding of operations when multiple computation would be chained and
445   // merged together.
446   bool allow_spmd_sharding_propagation_to_output_ = false;
447 
448   // Each Hlo analysis is allowed at least a constant number of
449   // abstract cost units, before it is considered for early termination.
450   absl::flat_hash_map<absl::string_view, int64_t> analysis_allowance_map_;
451 
452   PrecisionConfig::Precision matrix_unit_operand_precision_ =
453       PrecisionConfig::DEFAULT;
454 };
455 
456 }  // namespace xla
457 
458 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
459