// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_replication_analysis.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 
30 // An HLO pass that determines whether each instruction in the module outputs
31 // the same value across replicas or across partitions (depending on the value
32 // `cross_partition_spmd`). It propagates sources of replicated values to
33 // the rest of the module, where sources include cross-replica-sum, annotated
34 // entry parameters, and constants.
35 class HloReplicationAnalysis {
36  public:
37   // Runs the analysis on module and returns the result or an error.
38   static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run(
39       const HloModule* module, bool cross_partition_spmd);
40 
41   // Same as above, but the caller can provide additional annotations: a set of
42   // while loops that are known to have the same iteration counts across
43   // replicas or partitions.
44   static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run(
45       const HloModule* module, bool cross_partition_spmd,
46       const absl::flat_hash_set<const HloInstruction*>*
47           loops_known_with_same_iterations);
48 
49   // Same as above but supports finding partially replicated HLOs.
50   static StatusOr<std::unique_ptr<HloReplicationAnalysis>>
51   RunWithPartialReplication(const HloModule* module, bool cross_partition_spmd);
52 
53   // Returns if the HLO instruction outputs the same value (i.e., replicated) at
54   // the given index across all replicas or partitions.
55   bool HloInstructionIsReplicatedAt(const HloInstruction* inst,
56                                     const ShapeIndex& index) const;
57 
58   bool HloInstructionIsReplicatedAt(
59       const HloInstruction* inst, const ShapeIndex& index,
60       absl::Span<const ReplicaGroup> replica_groups) const;
61 
62  private:
63   // A data structure that represents how an HLO is replicated among a set of
64   // devices. Device ID could be either partition ID or replica ID.
65   // We represent partial replication by grouping devices that have the same
66   // value into the same set.
67   class HloReplication {
68    public:
69     static HloReplication ReplicatedOnAllDevices();
70     static HloReplication UniqueOnAllDevices();
71     static HloReplication PartiallyReplicated(
72         absl::Span<const absl::Span<const int64_t>> device_sets);
73     HloReplication();
74     HloReplication(const HloReplication& other) = default;
75     HloReplication(HloReplication&& other) = default;
76     HloReplication& operator=(HloReplication&& other) = default;
77     HloReplication Merge(const HloReplication& other) const;
78     bool Equal(const HloReplication& other) const;
79     bool IsReplicatedOnAllDevices() const;
80     bool IsUniqueOnAllDevices() const;
81     bool IsReplicatedWithinSubgroup(absl::Span<const int64_t> device_ids) const;
82     std::string ToString() const;
83 
84    private:
85     enum class State {
86       kReplicatedOnAllDevices = 0,
87       kUniqueOnAllDevices = 1,
88       kPartiallyReplicated = 2,
89     };
90     explicit HloReplication(State state,
91                             absl::Span<const int64_t> device_set_root);
92     State state_;
93     // Empty if state_ is kReplicatedOnAllDevices or kUniqueOnAllDevices.
94     // Otherwise, its size equals to the number of devices (either partitions
95     // or replications). Maps each device ID to the smallest device ID in the
96     // set.
97     std::vector<int64_t> device_set_root_;
98   };
99 
100   static HloReplication DetermineHloInstructionIsReplicated(
101       const HloInstruction* hlo, const ShapeIndex& index,
102       bool cross_partition_spmd,
103       const absl::flat_hash_map<const HloInstruction*,
104                                 ShapeTree<HloReplication>>& hlo_replication,
105       bool support_partial_replication);
106 
HloReplicationAnalysis(const HloModule * module,bool cross_partition_spmd,const absl::flat_hash_set<const HloInstruction * > * loops_known_with_same_iterations,bool support_partial_replication)107   HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd,
108                          const absl::flat_hash_set<const HloInstruction*>*
109                              loops_known_with_same_iterations,
110                          bool support_partial_replication)
111       : module_(module),
112         cross_partition_spmd_(cross_partition_spmd),
113         loops_known_with_same_iterations_(*loops_known_with_same_iterations),
114         support_partial_replication_(support_partial_replication) {}
115 
116   // Computes hlo_replication_.
117   void ComputeHloReplication();
118 
119   // A helper function to recursively compute hlo_replication on a computation.
120   // Returns whether hlo_replication_ is changed.
121   bool ComputeHloReplicationOnComputation(const HloComputation* computation,
122                                           bool mark_everything_not_replicated);
123 
124   const HloModule* module_;
125 
126   // If true, run this replication analysis for replicated values across
127   // partitions (not across replicas) on an SPMD partitioned module. This means
128   // that HloInstructionIsReplicatedAt() returns true if the value is identical
129   // across partitions for each replica. The module-level parameter and root
130   // instructions may have HloSharding attributes that indicate whether values
131   // are identical across partitions.
132   //
133   // If false, HloReplicationAnalysis runs across replicas.
134   bool cross_partition_spmd_;
135 
136   // A set of while loops that are known to have the same iteration counts
137   // across replicas or partitions. This is provided by the caller as additional
138   // annotations.
139   const absl::flat_hash_set<const HloInstruction*>&
140       loops_known_with_same_iterations_;
141 
142   const bool support_partial_replication_;
143 
144   // A map from each analyzed HLO instruction to a shape tree that represents
145   // whether the instruction outputs the same value across replicas or
146   // partitions at each shape index.
147   absl::flat_hash_map<const HloInstruction*, ShapeTree<HloReplication>>
148       hlo_replication_;
149 };
150 
151 }  // namespace xla
152 
153 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_
154