1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ 18 19 #include <string> 20 21 #include "absl/container/flat_hash_map.h" 22 #include "absl/container/flat_hash_set.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/compiler/xla/service/hlo_module.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 28 namespace xla { 29 30 // An HLO pass that determines whether each instruction in the module outputs 31 // the same value across replicas or across partitions (depending on the value 32 // `cross_partition_spmd`). It propagates sources of replicated values to 33 // the rest of the module, where sources include cross-replica-sum, annotated 34 // entry parameters, and constants. 35 class HloReplicationAnalysis { 36 public: 37 // Runs the analysis on module and returns the result or an error. 38 static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run( 39 const HloModule* module, bool cross_partition_spmd); 40 41 // Same as above, but the caller can provide additional annotations: a set of 42 // while loops that are known to have the same iteration counts across 43 // replicas or partitions. 44 static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run( 45 const HloModule* module, bool cross_partition_spmd, 46 const absl::flat_hash_set<const HloInstruction*>* 47 loops_known_with_same_iterations); 48 49 // Same as above but supports finding partially replicated HLOs. 50 static StatusOr<std::unique_ptr<HloReplicationAnalysis>> 51 RunWithPartialReplication(const HloModule* module, bool cross_partition_spmd); 52 53 // Returns if the HLO instruction outputs the same value (i.e., replicated) at 54 // the given index across all replicas or partitions. 55 bool HloInstructionIsReplicatedAt(const HloInstruction* inst, 56 const ShapeIndex& index) const; 57 58 bool HloInstructionIsReplicatedAt( 59 const HloInstruction* inst, const ShapeIndex& index, 60 absl::Span<const ReplicaGroup> replica_groups) const; 61 62 private: 63 // A data structure that represents how an HLO is replicated among a set of 64 // devices. Device ID could be either partition ID or replica ID. 65 // We represent partial replication by grouping devices that have the same 66 // value into the same set. 67 class HloReplication { 68 public: 69 static HloReplication ReplicatedOnAllDevices(); 70 static HloReplication UniqueOnAllDevices(); 71 static HloReplication PartiallyReplicated( 72 absl::Span<const absl::Span<const int64_t>> device_sets); 73 HloReplication(); 74 HloReplication(const HloReplication& other) = default; 75 HloReplication(HloReplication&& other) = default; 76 HloReplication& operator=(HloReplication&& other) = default; 77 HloReplication Merge(const HloReplication& other) const; 78 bool Equal(const HloReplication& other) const; 79 bool IsReplicatedOnAllDevices() const; 80 bool IsUniqueOnAllDevices() const; 81 bool IsReplicatedWithinSubgroup(absl::Span<const int64_t> device_ids) const; 82 std::string ToString() const; 83 84 private: 85 enum class State { 86 kReplicatedOnAllDevices = 0, 87 kUniqueOnAllDevices = 1, 88 kPartiallyReplicated = 2, 89 }; 90 explicit HloReplication(State state, 91 absl::Span<const int64_t> device_set_root); 92 State state_; 93 // Empty if state_ is kReplicatedOnAllDevices or kUniqueOnAllDevices. 94 // Otherwise, its size equals to the number of devices (either partitions 95 // or replications). Maps each device ID to the smallest device ID in the 96 // set. 97 std::vector<int64_t> device_set_root_; 98 }; 99 100 static HloReplication DetermineHloInstructionIsReplicated( 101 const HloInstruction* hlo, const ShapeIndex& index, 102 bool cross_partition_spmd, 103 const absl::flat_hash_map<const HloInstruction*, 104 ShapeTree<HloReplication>>& hlo_replication, 105 bool support_partial_replication); 106 HloReplicationAnalysis(const HloModule * module,bool cross_partition_spmd,const absl::flat_hash_set<const HloInstruction * > * loops_known_with_same_iterations,bool support_partial_replication)107 HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd, 108 const absl::flat_hash_set<const HloInstruction*>* 109 loops_known_with_same_iterations, 110 bool support_partial_replication) 111 : module_(module), 112 cross_partition_spmd_(cross_partition_spmd), 113 loops_known_with_same_iterations_(*loops_known_with_same_iterations), 114 support_partial_replication_(support_partial_replication) {} 115 116 // Computes hlo_replication_. 117 void ComputeHloReplication(); 118 119 // A helper function to recursively compute hlo_replication on a computation. 120 // Returns whether hlo_replication_ is changed. 121 bool ComputeHloReplicationOnComputation(const HloComputation* computation, 122 bool mark_everything_not_replicated); 123 124 const HloModule* module_; 125 126 // If true, run this replication analysis for replicated values across 127 // partitions (not across replicas) on an SPMD partitioned module. This means 128 // that HloInstructionIsReplicatedAt() returns true if the value is identical 129 // across partitions for each replica. The module-level parameter and root 130 // instructions may have HloSharding attributes that indicate whether values 131 // are identical across partitions. 132 // 133 // If false, HloReplicationAnalysis runs across replicas. 134 bool cross_partition_spmd_; 135 136 // A set of while loops that are known to have the same iteration counts 137 // across replicas or partitions. This is provided by the caller as additional 138 // annotations. 139 const absl::flat_hash_set<const HloInstruction*>& 140 loops_known_with_same_iterations_; 141 142 const bool support_partial_replication_; 143 144 // A map from each analyzed HLO instruction to a shape tree that represents 145 // whether the instruction outputs the same value across replicas or 146 // partitions at each shape index. 147 absl::flat_hash_map<const HloInstruction*, ShapeTree<HloReplication>> 148 hlo_replication_; 149 }; 150 151 } // namespace xla 152 153 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ 154