/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_

#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// Granularity at which custom fusion configuration is collected (stored in
// HloModuleConfig::fusion_config()).
enum class FusionConfigCollection {
  kOff,      // Do not collect configuration.
  kPerEdge,  // Collect per-edge configuration.
  kPerNode,  // Collect per-node configuration.
};

// This class gathers all settings and values which affect the compiled
// executable outside of the HLO code itself. This includes layouts of inputs
// and outputs to the module and settings such as HLO profiling. Together the
// HloModule and HloModuleConfig unambiguously determine a particular
// executable.
45 class HloModuleConfig { 46 public: 47 // Represents a pair of input and output of the entry computation that can be 48 // considered as the original and updated values of a variable maintained by 49 // the caller, and that can be transparently sharded by XLA as an internal 50 // optimization. If sharded, XLA will create separate sharding/unsharding 51 // programs, and the caller is responsible to call the XLA-generated 52 // sharding/unsharding programs before and after the sharded main program. 53 // 54 // If the variable is not updated and there is not a corresponding output, use 55 // {-1} as the output_shape_index. 56 // 57 // The sharding/unsharding programs will include all the input/output pairs in 58 // shardable_value_update_pairs() as a flat tuple in their inputs/outputs, 59 // sorted by (input_parameter_number, parameter_shape_index). 60 // 61 // A typical usage pattern is to shard the variables first, then repeatedly 62 // invoke the main program, and finally invoke the unsharding program before 63 // they are used in full-shape. 64 struct ShardableValueUpdatePair { 65 int64_t input_parameter_number; 66 ShapeIndex parameter_shape_index; 67 ShapeIndex output_shape_index; 68 }; 69 70 // A configuration can be created either with, or without an entry 71 // ComputationLayout. The default ctor creates it without -- in this case 72 // accessing entry_computation_layout will CHECK-fail. The ctor accepting a 73 // ProgramShape creates a computation layout using this shape. 74 // The layouts in the ProgramShape will be reset to default unless 75 // ignore_layouts is set to false. HloModuleConfig()76 HloModuleConfig() { debug_options_ = DefaultDebugOptionsIgnoringFlags(); } 77 78 explicit HloModuleConfig(const ProgramShape& program_shape, 79 bool ignore_layouts = true); 80 81 explicit HloModuleConfig(ComputationLayout entry_computation_layout); 82 83 // Checks if this config has an entry computation layout already. 
has_entry_computation_layout()84 bool has_entry_computation_layout() const { 85 return entry_computation_layout_.has_value(); 86 } 87 88 // Sets the entry_computation_layout's parameter and result shapes for this 89 // config, according to the given program shape. The parameters and result 90 // are set to default layout. 91 void SetDefaultComputationLayout(const ProgramShape& program_shape); 92 93 // Same as above but if the given program contains layout for parameters or 94 // result, the entry_computation_layout's layout is updated accordingly. 95 void SetComputationLayoutIfExists(const ProgramShape& program_shape); 96 97 // Returns a constant reference to the layout of the entry computation. 98 // Assumes the layout was set. entry_computation_layout()99 const ComputationLayout& entry_computation_layout() const { 100 CHECK(entry_computation_layout_.has_value()); 101 return *entry_computation_layout_; 102 } 103 104 // Returns a mutable pointer to the layout of the entry computation. 105 // Assumes the layout was set. mutable_entry_computation_layout()106 ComputationLayout* mutable_entry_computation_layout() { 107 CHECK(entry_computation_layout_.has_value()); 108 return &(*entry_computation_layout_); 109 } 110 111 // Clears the entry computation layout. clear_entry_computation_layout()112 void clear_entry_computation_layout() { 113 entry_computation_layout_ = std::nullopt; 114 } 115 116 // Returns whether to enable HLO-level profiling. hlo_profiling_enabled()117 bool hlo_profiling_enabled() const { 118 return debug_options_.xla_hlo_profile(); 119 } 120 cpu_traceme_enabled()121 bool cpu_traceme_enabled() const { 122 return debug_options_.xla_cpu_enable_xprof_traceme(); 123 } 124 125 // Sets/returns the module seed set during execution. set_seed(uint64_t seed)126 void set_seed(uint64_t seed) { seed_ = seed; } seed()127 uint64_t seed() const { return seed_; } 128 129 // Set the launch id of the program. 
Launch id identifies a set of programs 130 // that should be launched together. set_launch_id(uint64_t launch_id)131 void set_launch_id(uint64_t launch_id) { launch_id_ = launch_id; } 132 launch_id()133 int32_t launch_id() const { return launch_id_; } 134 set_replica_count(int64_t replica_count)135 void set_replica_count(int64_t replica_count) { 136 replica_count_ = replica_count; 137 } replica_count()138 int64_t replica_count() const { return replica_count_; } 139 set_num_partitions(int64_t num_partitions)140 void set_num_partitions(int64_t num_partitions) { 141 num_partitions_ = num_partitions; 142 } num_partitions()143 int64_t num_partitions() const { return num_partitions_; } 144 param_requires_broadcast_via_collectives()145 const std::vector<bool> param_requires_broadcast_via_collectives() const { 146 return param_requires_broadcast_via_collectives_; 147 } set_param_requires_broadcast_via_collectives(const std::vector<bool> require_broadcast)148 void set_param_requires_broadcast_via_collectives( 149 const std::vector<bool> require_broadcast) { 150 param_requires_broadcast_via_collectives_ = std::move(require_broadcast); 151 } 152 set_use_spmd_partitioning(bool use_spmd_partitioning)153 void set_use_spmd_partitioning(bool use_spmd_partitioning) { 154 use_spmd_partitioning_ = use_spmd_partitioning; 155 } use_spmd_partitioning()156 bool use_spmd_partitioning() const { return use_spmd_partitioning_; } 157 set_use_auto_spmd_partitioning(bool use_auto_spmd_partitioning)158 void set_use_auto_spmd_partitioning(bool use_auto_spmd_partitioning) { 159 use_auto_spmd_partitioning_ = use_auto_spmd_partitioning; 160 if (use_auto_spmd_partitioning) { 161 // TODO(yuemmawang) Remove this warning once auto sharding is thoroughly 162 // tested with fleetwide models. 163 LOG(WARNING) << "Warning: Using auto_spmd_partitioning. 
It is " 164 "experimental and may " 165 "contain bugs!"; 166 LOG(INFO) << "Overwriting use_spmd_partitioning to true, because " 167 "use_auto_spmd_partitioning is true."; 168 set_use_spmd_partitioning(true); 169 } 170 } use_auto_spmd_partitioning()171 bool use_auto_spmd_partitioning() const { 172 return use_auto_spmd_partitioning_; 173 } 174 set_auto_spmd_partitioning_mesh_shape(std::vector<int64_t> mesh_shape)175 void set_auto_spmd_partitioning_mesh_shape(std::vector<int64_t> mesh_shape) { 176 auto_spmd_partitioning_mesh_shape_ = mesh_shape; 177 } auto_spmd_partitioning_mesh_shape()178 std::vector<int64_t> auto_spmd_partitioning_mesh_shape() const { 179 return auto_spmd_partitioning_mesh_shape_; 180 } 181 set_auto_spmd_partitioning_mesh_ids(std::vector<int64_t> mesh_ids)182 void set_auto_spmd_partitioning_mesh_ids(std::vector<int64_t> mesh_ids) { 183 auto_spmd_partitioning_mesh_ids_ = mesh_ids; 184 } auto_spmd_partitioning_mesh_ids()185 std::vector<int64_t> auto_spmd_partitioning_mesh_ids() const { 186 return auto_spmd_partitioning_mesh_ids_; 187 } 188 189 // If enabled, deduplicate equivalent hlos into function calls to reduce code 190 // size. set_deduplicate_hlo(bool deduplicate_hlo)191 void set_deduplicate_hlo(bool deduplicate_hlo) { 192 deduplicate_hlo_ = deduplicate_hlo; 193 } 194 set_device_type(const std::string & device_type)195 void set_device_type(const std::string& device_type) { 196 device_type_ = device_type; 197 } 198 deduplicate_hlo()199 bool deduplicate_hlo() const { return deduplicate_hlo_; } 200 201 // Return a string which unambiguously represents all the fields of this data 202 // structure. Used for generating a cache key for storing the compiled 203 // executable. 
204 std::string compilation_cache_key() const; 205 device_type()206 std::string device_type() const { return device_type_; } 207 debug_options()208 const DebugOptions& debug_options() const { return debug_options_; } 209 set_debug_options(const DebugOptions & debug_options)210 void set_debug_options(const DebugOptions& debug_options) { 211 debug_options_ = debug_options; 212 } 213 214 // Sets/returns the number of intra op threads for this module. set_intra_op_parallelism_threads(const int intra_op_parallelism_threads)215 void set_intra_op_parallelism_threads( 216 const int intra_op_parallelism_threads) { 217 intra_op_parallelism_threads_ = intra_op_parallelism_threads; 218 } intra_op_parallelism_threads()219 int64_t intra_op_parallelism_threads() const { 220 return intra_op_parallelism_threads_; 221 } 222 223 // Checks if this config has a static device assignment. has_static_device_assignment()224 bool has_static_device_assignment() const { 225 return static_device_assignment_.has_value(); 226 } 227 228 // Getter and setter of the compile-time known device assignment. 
static_device_assignment()229 const DeviceAssignment& static_device_assignment() const { 230 CHECK(static_device_assignment_.has_value()); 231 return *static_device_assignment_; 232 } set_static_device_assignment(const DeviceAssignment & device_assignment)233 void set_static_device_assignment(const DeviceAssignment& device_assignment) { 234 static_device_assignment_ = device_assignment; 235 } 236 shardable_value_update_pairs()237 const std::vector<ShardableValueUpdatePair> shardable_value_update_pairs() 238 const { 239 return shardable_value_update_pairs_; 240 } set_shardable_value_update_pairs(std::vector<ShardableValueUpdatePair> pairs)241 void set_shardable_value_update_pairs( 242 std::vector<ShardableValueUpdatePair> pairs) { 243 shardable_value_update_pairs_ = std::move(pairs); 244 } 245 246 // Whether input and output buffers are aliased if the associated parameter is 247 // passed-through XLA modules without being changed. alias_passthrough_params()248 bool alias_passthrough_params() const { return alias_passthrough_params_; } set_alias_passthrough_params(bool alias_passthrough_params)249 void set_alias_passthrough_params(bool alias_passthrough_params) { 250 alias_passthrough_params_ = alias_passthrough_params; 251 } 252 content_aware_computation_sorting()253 bool content_aware_computation_sorting() const { 254 return content_aware_computation_sorting_; 255 } set_content_aware_computation_sorting(bool content_aware_computation_sorting)256 void set_content_aware_computation_sorting( 257 bool content_aware_computation_sorting) { 258 content_aware_computation_sorting_ = content_aware_computation_sorting; 259 } 260 fusion_config_collection()261 FusionConfigCollection fusion_config_collection() const { 262 return fusion_config_collection_; 263 } set_fusion_config_collection(FusionConfigCollection fusion_config_collection)264 void set_fusion_config_collection( 265 FusionConfigCollection fusion_config_collection) { 266 fusion_config_collection_ = 
fusion_config_collection; 267 } 268 fusion_config()269 const std::vector<std::vector<bool>>& fusion_config() const { 270 return fusion_config_; 271 } mutable_fusion_config()272 std::vector<std::vector<bool>>* mutable_fusion_config() { 273 return &fusion_config_; 274 } 275 dot_config()276 const absl::flat_hash_map<std::string, std::vector<int64_t>>& dot_config() 277 const { 278 return dot_config_; 279 } 280 mutable_dot_config()281 absl::flat_hash_map<std::string, std::vector<int64_t>>* mutable_dot_config() { 282 return &dot_config_; 283 } 284 layout_config()285 const std::vector<std::vector<std::vector<int64_t>>>& layout_config() const { 286 return layout_config_; 287 } 288 mutable_layout_config()289 std::vector<std::vector<std::vector<int64_t>>>* mutable_layout_config() { 290 return &layout_config_; 291 } 292 phase_ordering_config()293 const std::vector<std::vector<bool>>& phase_ordering_config() const { 294 return phase_ordering_config_; 295 } 296 mutable_phase_ordering_config()297 std::vector<std::vector<bool>>* mutable_phase_ordering_config() { 298 return &phase_ordering_config_; 299 } 300 flag_config()301 const absl::flat_hash_map<std::string, std::string>& flag_config() const { 302 return flag_config_; 303 } 304 mutable_flag_config()305 absl::flat_hash_map<std::string, std::string>* mutable_flag_config() { 306 return &flag_config_; 307 } 308 phase_index()309 const int phase_index() const { return phase_index_; } set_phase_index(const int phase_index)310 void set_phase_index(const int phase_index) { phase_index_ = phase_index; } 311 set_allow_spmd_sharding_propagation_to_output(bool allow_spmd_sharding_propagation_to_output)312 void set_allow_spmd_sharding_propagation_to_output( 313 bool allow_spmd_sharding_propagation_to_output) { 314 allow_spmd_sharding_propagation_to_output_ = 315 allow_spmd_sharding_propagation_to_output; 316 } allow_spmd_sharding_propagation_to_output()317 bool allow_spmd_sharding_propagation_to_output() const { 318 return 
allow_spmd_sharding_propagation_to_output_; 319 } 320 memory_space_assignment_config()321 const std::vector<uint64_t>& memory_space_assignment_config() const { 322 return memory_space_assignment_config_; 323 } 324 mutable_memory_space_assignment_config()325 std::vector<uint64_t>* mutable_memory_space_assignment_config() { 326 return &memory_space_assignment_config_; 327 } 328 GetAnalysisAllowance(absl::string_view pass_name)329 int64_t GetAnalysisAllowance(absl::string_view pass_name) const { 330 auto it = analysis_allowance_map_.find(pass_name); 331 if (it == analysis_allowance_map_.end()) { 332 return -1; 333 } 334 return (*it).second; 335 } 336 SetAnalysisAllowance(absl::string_view pass_name,int64_t allowance)337 void SetAnalysisAllowance(absl::string_view pass_name, int64_t allowance) { 338 analysis_allowance_map_[pass_name] = allowance; 339 } 340 matrix_unit_operand_precision()341 PrecisionConfig::Precision matrix_unit_operand_precision() const { 342 return matrix_unit_operand_precision_; 343 } set_matrix_unit_operand_precision(PrecisionConfig::Precision matrix_unit_operand_precision)344 void set_matrix_unit_operand_precision( 345 PrecisionConfig::Precision matrix_unit_operand_precision) { 346 matrix_unit_operand_precision_ = matrix_unit_operand_precision; 347 } 348 349 private: 350 // If you add new members, be sure to update compilation_cache_key. 351 352 std::optional<ComputationLayout> entry_computation_layout_; 353 354 // Module/graph-level seed handle. 355 uint64_t seed_ = 0; 356 357 // Program id that identifies a set of program to be launched together. 358 int32_t launch_id_ = 0; 359 360 // The number of replicas (data parallelism) to compile this binary for. 361 int64_t replica_count_ = 1; 362 363 // The number of partitions (model parallelism) to compile this binary for. 364 int64_t num_partitions_ = 1; 365 366 // Whether to broadcast args across all replicas. One entry per arg. 
367 std::vector<bool> param_requires_broadcast_via_collectives_; 368 369 // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA 370 // needs to partition the module. 371 bool use_spmd_partitioning_ = false; 372 373 // Whether to automatically generate XLA shardings for SPMD partitioner. 374 bool use_auto_spmd_partitioning_ = false; 375 376 // Mesh shape and mesh ids used by auto spmd partitioning. 377 std::vector<int64_t> auto_spmd_partitioning_mesh_shape_; 378 379 std::vector<int64_t> auto_spmd_partitioning_mesh_ids_; 380 381 // If enabled, deduplicate equivalent hlos into function calls to reduce code 382 // size. 383 bool deduplicate_hlo_ = false; 384 385 // The target maximum parallelism at which to partition HLOs for parallel 386 // execution on the CPU backend. 387 int64_t intra_op_parallelism_threads_ = -1; 388 389 std::string device_type_; 390 391 DebugOptions debug_options_; 392 393 // Compile-time known device assignment. 394 std::optional<DeviceAssignment> static_device_assignment_; 395 396 std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_; 397 398 bool alias_passthrough_params_ = false; 399 400 bool content_aware_computation_sorting_ = true; 401 402 FusionConfigCollection fusion_config_collection_ = 403 FusionConfigCollection::kOff; 404 405 // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto 406 // similar to backend config. 407 408 // Custom fusion configuration, where fusion_config_[c][v] control if node v 409 // in computation c must be fused to all its consumers (true) or not (false). 410 std::vector<std::vector<bool>> fusion_config_; 411 412 // Custom dot canonicalization configuration, where dot_config_[v] control 413 // how to convert dot operation named 'v' to convolution. 414 absl::flat_hash_map<std::string, std::vector<int64_t>> dot_config_; 415 416 // Layout configuration, where layout_config_[v][i] controls the layout 417 // decision i of operation v. 
418 std::vector<std::vector<std::vector<int64_t>>> layout_config_; 419 420 // Memory Space Assignment configuration, where 421 // memory_space_assignment_config_ controls the order of buffer intervals 422 // of this hlo module. 423 std::vector<uint64_t> memory_space_assignment_config_; 424 425 // Phase ordering configuration, where phase_ordering_config[v][i] controls 426 // whether a specific pass with index i (e.g. 0 = DCE, 1 = CSE, etc.) is 427 // inserted after pass v in pipeline. See tuning::PhaseOrderingConfig for 428 // details on what indices (i) correspond to which passes. 429 std::vector<std::vector<bool>> phase_ordering_config_; 430 // Index (v) corresponding to current passes being added for phase ordering. 431 // This is the variable that stores state to allow us to use the same 432 // config across functions during compilation. 433 int phase_index_ = 0; 434 435 // Flag configuration to use instead of global flags. This allows multiple 436 // HLO modules to be compiled in parallel with different flag values. 437 absl::flat_hash_map<std::string, std::string> flag_config_; 438 439 // Allows sharding propagation to propagate to the outputs. This changes the 440 // output shape of the computation (which is undesirable), but it can be used 441 // to allow to run partial compilation to determine what would be the output 442 // sharding of a computation if XLA would be allowed to propagate the sharding 443 // which can be used by higher level framework as a way to query intermediate 444 // sharding of operations when multiple computation would be chained and 445 // merged together. 446 bool allow_spmd_sharding_propagation_to_output_ = false; 447 448 // Each Hlo analysis is allowed at least a constant number of 449 // abstract cost units, before it is considered for early termination. 
450 absl::flat_hash_map<absl::string_view, int64_t> analysis_allowance_map_; 451 452 PrecisionConfig::Precision matrix_unit_operand_precision_ = 453 PrecisionConfig::DEFAULT; 454 }; 455 456 } // namespace xla 457 458 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_ 459