/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_

#include <map>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {
namespace hlo_sharding_util {

struct GatherParallelDims {
  absl::InlinedVector<int64_t, 1> indices_parallel_dims;
  absl::InlinedVector<int64_t, 1> operand_parallel_dims;
  std::vector<int64_t> index_parallel_in_dim;
};

// Returns true if the lhs sharding is preferable over the rhs sharding.
// The most specific sharding is tile maximal followed by single device tile
// maximal and finally replicated. This order aims primarily to reduce memory
// usage and secondarily to reduce total compute.
// Note: This does NOT provide a total ordering, as two different shardings can
// have the same preference level.
bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs);
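
// Example (illustrative sketch): a single-device assignment is more specific
// than full replication, so the following comparison would be expected to
// return true:
//
//   IsShardingMoreSpecific(HloSharding::AssignDevice(0),
//                          HloSharding::Replicate());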

// Tries to refine `to_merge` by combining it with `old`. Returns true if the
// final `to_merge` is more specific than `old`.
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
                   bool may_combine_partial_sharding);

// Merges `to_merge` into `dst` only if they are compatible and the merged
// sharding has at least `minimum_tiles` tiles. Returns whether merging
// happened.
bool MergeShardingIfCompatible(const HloSharding& to_merge,
                               int64_t minimum_tiles, HloSharding* dst);

// Given a map<device, occurrence_count>, selects the device with the highest
// occurrence count (if any). If top_count is not nullptr, it will receive the
// count of the dominant device returned.
std::optional<int64_t> SelectDominantDevice(
    const std::map<int64_t, int64_t>& device_map, int64_t* top_count);
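
// Example (illustrative sketch): with device 1 occurring more often than
// device 0, SelectDominantDevice picks device 1 and reports its count:
//
//   std::map<int64_t, int64_t> device_map = {{0, 3}, {1, 5}};
//   int64_t top_count = 0;
//   std::optional<int64_t> device =
//       SelectDominantDevice(device_map, &top_count);
//   // Expected: device holds 1 and top_count == 5.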

// Assigns all the instructions of a computation to a given device.
// This API does not recurse into called computations, and does not assign
// instructions which already have sharding.
void AssignComputationDevice(HloComputation* computation, int64_t device);

// Given an instruction container, returns the device which occurs most
// commonly among the instructions.
std::optional<int64_t> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions);

// Given a set of computations, tries to extract the dominant device. A device
// is dominant if its combined occurrence among all the instructions of the
// input computations is greater than or equal to dominant_factor (a real
// number from 0 to 1).
// This API does not recurse into called computations.
// If no device exists that satisfies the condition, the returned optional will
// hold no value.
std::optional<int64_t> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor);
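
// Example (illustrative): with dominant_factor = 0.9, GetDominantDevice would
// be expected to return a device only if that device accounts for at least 90%
// of the device occurrences counted across all instructions of the given
// computations.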

// Returns the HloSharding with the tile dimensions and tile assignment
// transposed based on the specified dimension numbers. In case of a tile
// maximal sharding, returns the original sharding.
HloSharding TransposeSharding(const HloSharding& sharding,
                              absl::Span<const int64_t> dimensions);
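
// Example (illustrative sketch): transposing a 2x2 tiled sharding with the
// permutation {1, 0} swaps the two tile dimensions:
//
//   HloSharding tiled = HloSharding::Tile(Array<int64_t>({{0, 1}, {2, 3}}));
//   HloSharding transposed = TransposeSharding(tiled, {1, 0});
//   // Expected tile assignment: {{0, 2}, {1, 3}}.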

// Returns the HloSharding with the tile shape reshaped based on the source and
// target shapes and the tile assignment adjusted to correspond to the new tile
// shape, or std::nullopt if the resulting reshape would create an invalid
// sharding (non-contiguous or non-uniformly sized tiles). In case of a tile
// maximal sharding, returns the original sharding.
std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                           const Shape& target_shape,
                                           const HloSharding& sharding);
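
// Example (illustrative sketch, shapes made up): reshaping f32[16,8] tiled two
// ways along dim 0 into f32[128] keeps each tile contiguous, so a sharding
// would be expected; a reshape that split tiles unevenly would yield
// std::nullopt instead.
//
//   Shape source = ShapeUtil::MakeShape(F32, {16, 8});
//   Shape target = ShapeUtil::MakeShape(F32, {128});
//   HloSharding tiled = HloSharding::Tile(Array<int64_t>({{0}, {1}}));
//   std::optional<HloSharding> reshaped =
//       ReshapeSharding(source, target, tiled);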

// Returns the HloSharding with the tile dimensions and tile assignment
// reversed based on the specified dimension numbers. In case of a tile
// maximal sharding, returns the original sharding.
HloSharding ReverseSharding(const HloSharding& sharding,
                            absl::Span<const int64_t> dimensions);

// Returns a sharding tiled on the unique dimension dim, by reshaping the tile
// assignment of the sharding argument. Only dimensions in the dims span
// argument are considered for reshaping; the others are ignored.
// Assumptions: sharding is tile sharded, and dim must be included in dims.
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim,
                                   absl::Span<const int64_t> dims);
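
// Example (illustrative sketch): given a sharding tiled [2, 2] over dims
// {0, 1}, ReshapeToTileDimension would be expected to move all the tiling in
// dims onto dim 1, yielding a [1, 4] tile assignment:
//
//   HloSharding tiled = HloSharding::Tile(Array<int64_t>({{0, 1}, {2, 3}}));
//   HloSharding reshaped =
//       ReshapeToTileDimension(tiled, /*dim=*/1, /*dims=*/{0, 1});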

// Returns true if the provided module includes one or more instructions with
// a tile sharding.
bool ContainsTileSharding(const HloModule& module);

// Returns the preferred output sharding for a gather op based on the sharding
// of the indices.
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred index sharding for a gather op based on the sharding
// of the output.
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo);

// Returns a new HloSharding for a gather op so that only non-offset dimensions
// are sharded. Assuming "result" is returned by this function, it is ensured
// that "GetIndexSharding(result, hlo)" will have the same number of elements
// as "result".
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);

// Returns the preferred index sharding for a scatter op based on the sharding
// of the data.
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloScatterInstruction* scatter);

// Returns the preferred data sharding for a scatter op based on the sharding
// of the index.
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloScatterInstruction* scatter);

// Returns a new index sharding for a scatter op so that we only shard on the
// first "number of scatter_window_dims" dimensions. Assuming "result" is
// returned by this function, it is ensured that "ScatterDataSharding(result,
// hlo)" will have the same number of elements as "result".
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloScatterInstruction& scatter);

// Returns a new data sharding for a scatter op so that we only shard on
// scatter_window_dims. Assuming "result" is returned by this function, it is
// ensured that "ScatterIndexSharding(result, hlo)" will have the same number
// of elements as "result".
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloScatterInstruction& scatter);

// Returns an output sharding of gather by passing through the data operand's
// sharding.
std::optional<HloSharding> GatherOutputShardingFromDataOperand(
    const HloSharding& data_operand_sharding, const HloInstruction& hlo,
    absl::Span<const int64_t> slice_sizes, const Shape& output_shape,
    const Shape& operand_shape);

// Returns a data operand sharding of gather by passing through the output's
// sharding.
std::optional<HloSharding> GatherDataOperandShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);

// Returns the slice sizes for a scatter with the given operand and update
// shapes.
std::vector<int64_t> GetScatterSliceSize(const Shape& operand_shape,
                                         const Shape& update_shape,
                                         const ScatterDimensionNumbers& dnums);

// Returns an output sharding of scatter by passing through the update
// operand's sharding.
std::optional<HloSharding> ScatterOutputShardingFromUpdate(
    const HloSharding& update_sharding, const HloScatterInstruction& scatter);

// Returns an update operand sharding of scatter by passing through the
// output's sharding.
std::optional<HloSharding> ScatterUpdateShardingFromOutput(
    const HloSharding& per_output_sharding,
    const HloScatterInstruction& scatter);

// Returns an identity value and an HloOpcode for the reduce computation of a
// scatter instruction.
// - If computation is add/or, return 0/false with the corresponding op code;
// - If computation is multiply/and, return 1/true with the corresponding op
//   code;
// - If computation is min/max, return the max value/min value with the
//   corresponding op code;
// - Otherwise, return an error status.
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter);
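
// Example (illustrative sketch, `scatter` is an assumed pointer to a scatter
// whose to_apply computation is a simple addition):
//
//   auto status_or =
//       IdentityValueAndHloOpcodeForScatterReduceComputation(*scatter);
//   if (status_or.ok()) {
//     auto& [identity, opcode] = status_or.value();
//     // Expected: identity is a zero constant, opcode == HloOpcode::kAdd.
//   }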

// Given a sharding and a list of devices in the topology, returns a
// list of the devices that `sharding` applies to.
std::vector<int64_t> DevicesForSharding(
    const HloSharding& sharding, absl::Span<const int64_t> available_devices);
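
// Example (illustrative): a single-device sharding only uses that device, so
// DevicesForSharding(HloSharding::AssignDevice(3), {0, 1, 2, 3}) would be
// expected to return {3}.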

// Returns a sharding that replicates data across devices along the given
// dimensions in the original sharding.
HloSharding PartiallyReplicateTiledShardingOnDims(
    const HloSharding& sharding, absl::Span<const int64_t> dims_to_replicate);
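
// Example (illustrative sketch): partially replicating a [2, 2] tiled sharding
// along dim 1 would be expected to keep tiling only along dim 0, with the two
// devices of each row forming a replication subgroup (a sharding like
// {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}):
//
//   HloSharding tiled = HloSharding::Tile(Array<int64_t>({{0, 1}, {2, 3}}));
//   HloSharding partial = PartiallyReplicateTiledShardingOnDims(tiled, {1});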

// Returns a sharding that replicates data across devices along all dimensions
// but the given ones to keep in the original sharding.
HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept(
    const HloSharding& sharding, absl::Span<const int64_t> dims_to_keep);

// Returns a sharding that replicates all data dimensions but keeps manual
// subgroups. If data_rank is provided and is >= 0, the resulting sharding's
// data rank will be set to it.
HloSharding ReplicateAllDataDims(const HloSharding& sharding,
                                 int64_t data_rank = -1);

// Returns a sharding that removes the given tile dimensions.
//
// Precondition: if not tile maximal, the size of each removed tile dimension
// must be 1.
HloSharding RemoveShapeDimensions(const HloSharding& sharding,
                                  absl::Span<const int64_t> dims_to_remove);

// Similar to TransposeSharding(), but allows removing/adding non-partitioned
// dimensions. In src_to_tgt and tgt_to_src, -1 represents a non-existing
// dimension.
std::optional<HloSharding> TransposeShardingWithCollapsedDims(
    const HloSharding& source, absl::Span<int64_t const> src_to_tgt,
    absl::Span<int64_t const> tgt_to_src);

// Returns the iota dimension if maybe_iota is a kIota instruction or is
// equivalent to kIota.
std::optional<int64_t> GetDimensionForIota(const HloInstruction* maybe_iota);

// Returns identified parallel dimensions for Gather.
std::optional<GatherParallelDims> GetGatherBatchParallelDims(
    const HloInstruction& hlo);

// Returns the parallel dimensions of the output of a gather based on the
// parallel dimensions of the input.
absl::InlinedVector<int64_t, 1> GatherParallelOutputDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dim);

// Returns the parallel dimensions of the data operand of a gather with the
// order of the parallel dimensions matching that of the parallel dimensions
// of the output.
absl::InlinedVector<int64_t, 1> GatherOutputAlignedOperandParallelDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dims);

// Represents grouping devices in a tiled sharding along certain dimensions.
// Elements in group dimensions define different device groups, and the
// sharding represents the in-group sharding.
struct GroupedSharding {
  GroupedSharding(std::vector<std::vector<int64_t>> device_groups,
                  std::vector<int64_t> group_dims,
                  std::vector<int64_t> group_dim_sizes, int64_t data_rank,
                  HloSharding grouped_sharding, bool subgroup_manual = false)
      : device_groups(std::move(device_groups)),
        group_dims(std::move(group_dims)),
        group_dim_sizes(std::move(group_dim_sizes)),
        data_rank(data_rank),
        sharding(std::move(grouped_sharding)),
        subgroup_manual(subgroup_manual) {}
  std::string ToString() const;
  std::vector<std::vector<int64_t>> device_groups;
  std::vector<int64_t> group_dims;
  std::vector<int64_t> group_dim_sizes;
  int64_t data_rank;
  HloSharding sharding;
  bool subgroup_manual;
};

// Creates a GroupedSharding for a tiled sharding with group dim shard sizes.
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
                                    absl::Span<const int64_t> group_dims,
                                    absl::Span<const int64_t> group_dim_shards,
                                    bool subgroup_manual = false);

// Creates a GroupedSharding for a tiled sharding.
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
                                    absl::Span<const int64_t> group_dims,
                                    bool subgroup_manual = false);
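
// Example (illustrative sketch): grouping a {devices=[2,2]0,1,2,3} sharding on
// dim 0 with GroupShardingOnDims would be expected to produce two device
// groups, {0, 1} and {2, 3}, with an in-group sharding tiled only along dim 1:
//
//   HloSharding tiled = HloSharding::Tile(Array<int64_t>({{0, 1}, {2, 3}}));
//   GroupedSharding grouped = GroupShardingOnDims(tiled, {0});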

// Gets the group sharding for each manual subgroup.
GroupedSharding GetManualSubgroupSharding(const HloSharding& sharding);

// Reconstructs the ungrouped sharding from a GroupedSharding.
HloSharding UngroupSharding(const GroupedSharding& grouped_sharding);
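
// For shardings grouped with GroupShardingOnDims, UngroupSharding acts as the
// inverse: ungrouping the example above would be expected to reconstruct the
// original {devices=[2,2]0,1,2,3} sharding.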

// Checks whether the device groups match between the LHS and RHS group
// shardings.
bool DeviceGroupsAreMatch(GroupedSharding& lhs, GroupedSharding& rhs,
                          bool ignore_group_order = true);

// Spawns a new dimension by splitting an existing dimension, generating a new
// dimension to its right of the given size. The original dimension will be of
// size "original_dim_size / new_dim_size". The original dimension size must be
// divisible by new_dim_size.
HloSharding SplitShardingDimension(const HloSharding& sharding,
                                   int64_t dimension, int64_t new_dim_size);
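
// Example (illustrative sketch): splitting dim 0 of a sharding tiled four ways
// along that dim with new_dim_size = 2 would be expected to yield two size-2
// tile dimensions in its place (tile assignment [4] becomes [2, 2]).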

// Merges a dimension into the dimension to its left. The new dimension will be
// of size dimensions[dimension] * dimensions[dimension + 1].
HloSharding MergeShardingDimension(const HloSharding& sharding,
                                   int64_t dimension);
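
// Example (illustrative): merging at dim 0 of a sharding with a [2, 2] tile
// assignment would be expected to yield a [4] tile assignment, the inverse of
// SplitShardingDimension above.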
}  // namespace hlo_sharding_util
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_