1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ 18 19 #include <map> 20 #include <optional> 21 #include <string> 22 #include <vector> 23 24 #include "absl/container/inlined_vector.h" 25 #include "tensorflow/compiler/xla/service/hlo_computation.h" 26 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 27 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 28 #include "tensorflow/compiler/xla/service/hlo_module.h" 29 #include "tensorflow/compiler/xla/service/hlo_sharding.h" 30 31 namespace xla { 32 namespace hlo_sharding_util { 33 34 struct GatherParallelDims { 35 absl::InlinedVector<int64_t, 1> indices_parallel_dims; 36 absl::InlinedVector<int64_t, 1> operand_parallel_dims; 37 std::vector<int64_t> index_parallel_in_dim; 38 }; 39 40 // Returns true if the lhs sharding is preferable over the rhs sharding. 41 // The most specific sharding is tile maximal followed by single device tile 42 // maximal and finally replicated. This order aims to primarily reduce memory 43 // usage and secondly reduce total compute. 44 // Note: This does NOT provide a total ordering as we can have 2 different 45 // sharding with same preference level. 46 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs); 47 48 // Tries to refine `to_merge` by combining with `old`. Returns if the final 49 // `to_merge` is more specific than `old`. 50 bool MergeSharding(const HloSharding& old, HloSharding* to_merge, 51 bool may_combine_partial_sharding); 52 53 // Merges `to_merge` into `dst` only if they are compatible, and the merged 54 // sharding has >= minimum_tiles tiles. Returns if merging happened. 55 bool MergeShardingIfCompatible(const HloSharding& to_merge, 56 int64_t minimum_tiles, HloSharding* dst); 57 58 // Given a map<device, occurrence_count>, selects the device with higher 59 // occurrence count (if any). If top_count in not nullptr, it will receive the 60 // count of the dominant device returned. 61 std::optional<int64_t> SelectDominantDevice( 62 const std::map<int64_t, int64_t>& device_map, int64_t* top_count); 63 64 // Assigns all the instructions of a computation, to a given device. 65 // This API does not recurse into called computations, and does not assign 66 // instructions which already have sharding. 67 void AssignComputationDevice(HloComputation* computation, int64_t device); 68 69 // Given an instruction container, returns the device which is most commonly 70 // occurring among the instructions. 71 std::optional<int64_t> GetMostOccurringDevice( 72 absl::Span<HloInstruction* const> instructions); 73 74 // Given a set of computations, tries to extract the dominant device. A device 75 // is dominant if the combined occurrence among all the instructions of the 76 // input computations, is greater/equal than/to dominant_factor (real number 77 // from 0 to 1). 78 // This API does not recurse into called computations. 79 // If no device exists that satisfies the condition, the returned optional will 80 // hold no value. 81 std::optional<int64_t> GetDominantDevice( 82 absl::Span<HloComputation* const> computations, double dominant_factor); 83 84 // Returns the HloSharding with the tile dimensions and tile assignment 85 // transposed based on the specified dimension numbers. In case of a tile 86 // maximal sharding returns the original sharding. 87 HloSharding TransposeSharding(const HloSharding& sharding, 88 absl::Span<const int64_t> dimensions); 89 90 // Returns the HloSharding with the tile shape reshaped based on the source and 91 // target shapes and the tile assignment adjusted to correspond to the new tile 92 // shape or std::nullopt if the resulting reshape would create an invalid 93 // sharding (non continuous or non uniformly sized tiles). In case of a tile 94 // maximal sharding returns the original sharding. 95 std::optional<HloSharding> ReshapeSharding(const Shape& source_shape, 96 const Shape& target_shape, 97 const HloSharding& sharding); 98 99 // Returns the HloSharding with the tile dimensions and tile assignment 100 // reversed based on the specified dimension numbers. In case of a tile 101 // maximal sharding returns the original sharding. 102 HloSharding ReverseSharding(const HloSharding& sharding, 103 absl::Span<const int64_t> dimensions); 104 105 // Returns a sharding tiled on unique dimension dim by reshaping the tile 106 // assignment of the sharding argument. Only dimensions in the dims span 107 // argument are considered for reshaping, the others are ignored. 108 // Assumptions: sharding is tile sharded, and dim must be included in dims. 109 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim, 110 absl::Span<const int64_t> dims); 111 112 // Returns true if the provided module includes one or more instructions with 113 // a tile sharding. 114 bool ContainsTileSharding(const HloModule& module); 115 116 // Returns the preferred output sharding for a gather op based on the sharding 117 // of the indces. 118 HloSharding GatherOutputSharding(const HloSharding& index_sharding, 119 const HloInstruction* hlo); 120 121 // Returns the preferred index sharding for a gather op based on the sharding 122 // of the output. 123 HloSharding GatherIndexSharding(const HloSharding& output_sharding, 124 const HloInstruction* hlo); 125 126 // Returns a new HloSharding for a gather op so that only non offset dimensions 127 // are sharded. Assume "result" is returned by this function. It is ensured that 128 // "GetIndexSharding(result, hlo)" will have the same number of elements as 129 // "result". 130 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo); 131 132 // Returns the preferred index sharding for a scatter op based on the sharding 133 // of the data. 134 HloSharding ScatterIndexSharding(const HloSharding& data_sharding, 135 const HloScatterInstruction* scatter); 136 137 // Returns the preferred data sharding for a scatter op based on the sharding 138 // of the index. 139 HloSharding ScatterDataSharding(const HloSharding& index_sharding, 140 const HloScatterInstruction* scatter); 141 142 // Returns a new index sharding for a scatter op so that we only shard on first 143 // "number of scatter_window_dims" dimensions. Assume "result" is returned by 144 // this function. It is ensured that "ScatterDataSharding(result, hlo)" will 145 // have the same number of elements as "result". 146 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, 147 const HloScatterInstruction& scatter); 148 149 // Returns a new data sharding for a scatter op so that we only shard on 150 // scatter_window_dims. Assume "result" is returned by this function. It is 151 // ensured that "ScatterIndexSharding(result, hlo)" will have the same number of 152 // elements as "result". 153 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, 154 const HloScatterInstruction& scatter); 155 156 // Returns an output sharding of gather by passing through the data operand's 157 // sharding. 158 std::optional<HloSharding> GatherOutputShardingFromDataOperand( 159 const HloSharding& data_operand_sharding, const HloInstruction& hlo, 160 absl::Span<const int64_t> slice_sizes, const Shape& output_shape, 161 const Shape& operand_shape); 162 163 // Returns a data operand sharding of gather by passing through the output's 164 // sharding. 165 std::optional<HloSharding> GatherDataOperandShardingFromOutput( 166 const HloSharding& output_sharding, const HloInstruction& hlo); 167 168 // Returns the slice size for a scatter with given operand and update shapes. 169 std::vector<int64_t> GetScatterSliceSize(const Shape& operand_shape, 170 const Shape& update_shape, 171 const ScatterDimensionNumbers& dnums); 172 173 // Returns an output sharding of scatter by passing through the update operand's 174 // sharding. 175 std::optional<HloSharding> ScatterOutputShardingFromUpdate( 176 const HloSharding& update_sharding, const HloScatterInstruction& scatter); 177 178 // Returns an update operand sharding of scatter by passing through the output's 179 // sharding. 180 std::optional<HloSharding> ScatterUpdateShardingFromOutput( 181 const HloSharding& per_output_sharding, 182 const HloScatterInstruction& scatter); 183 184 // Returns an identity value and an HloOpcode for reduce computation of scatter 185 // instruction. 186 // - If computation is add/or, return 0/false with corresponding op code; 187 // - If computation is multiply/and, return 1/true with corresponding op code. 188 // - If computation is min/max, return max value/min value with corresponding op 189 // code. 190 // - Otherwise, return error status. 191 StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>> 192 IdentityValueAndHloOpcodeForScatterReduceComputation( 193 const HloScatterInstruction& scatter); 194 195 // Given a sharding and a list of devices in the topology, return a 196 // list of the devices that `sharding` applies to. 197 std::vector<int64_t> DevicesForSharding( 198 const HloSharding& sharding, absl::Span<const int64_t> available_devices); 199 200 // Returns a sharding that replicates data across devices along the given 201 // dimensions in the original sharding. 202 HloSharding PartiallyReplicateTiledShardingOnDims( 203 const HloSharding& sharding, absl::Span<const int64_t> dims_to_replicate); 204 205 // Returns a sharding that replicates data across devices along all dimensions 206 // but the given ones to keep in the original sharding. 207 HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept( 208 const HloSharding& sharding, absl::Span<const int64_t> dims_to_keep); 209 210 // Returns a sharding that replicates all data dimensions, but keep manual 211 // subgroups. If data_rank is provided >= 0, the result sharding's data rank 212 // will be set to it. 213 HloSharding ReplicateAllDataDims(const HloSharding& sharding, 214 int64_t data_rank = -1); 215 216 // Returns a sharding the removes given tile dimensions. 217 // 218 // Precondition: if not tile maximal, the size of each tile dimension must be 1. 219 HloSharding RemoveShapeDimensions(const HloSharding& sharding, 220 absl::Span<const int64_t> dims_to_remove); 221 222 // Similar to TransposeSharding(), but allows removing/adding non-partitioned 223 // dimensions. In src_to_tgt and tgt_to_src, -1 represents a non-existing 224 // dimension. 225 std::optional<HloSharding> TransposeShardingWithCollapsedDims( 226 const HloSharding& source, absl::Span<int64_t const> src_to_tgt, 227 absl::Span<int64_t const> tgt_to_src); 228 229 // Returns the iota dimension if maybe_iota is an kIota instruction or 230 // equivalent to kIota. 231 std::optional<int64_t> GetDimensionForIota(const HloInstruction* maybe_iota); 232 233 // Returns identified parallel dimensions for Gather. 234 std::optional<GatherParallelDims> GetGatherBatchParallelDims( 235 const HloInstruction& hlo); 236 237 // Returns the parallel dimensions of the output of a gather based on the 238 // parallel dimensions of the input. 239 absl::InlinedVector<int64_t, 1> GatherParallelOutputDims( 240 const HloInstruction& gather, const GatherParallelDims& parallel_dim); 241 242 // Returns the parallel dimensions of the data operand of a gather with the 243 // order of the parallel dimensions matching that of the parallel dimensions 244 // of the output. 245 absl::InlinedVector<int64_t, 1> GatherOutputAlignedOperandParallelDims( 246 const HloInstruction& gather, const GatherParallelDims& parallel_dims); 247 248 // Represents grouping devices in a tiled sharding along certain dimensions. 249 // Elements in group dimensions define different device groups, and the sharding 250 // represents the in-group sharding. 251 struct GroupedSharding { 252 GroupedSharding(std::vector<std::vector<int64_t>> device_groups, 253 std::vector<int64_t> group_dims, 254 std::vector<int64_t> group_dim_sizes, int64_t data_rank, 255 HloSharding grouped_sharding, bool subgroup_manual = false) device_groupsGroupedSharding256 : device_groups(std::move(device_groups)), 257 group_dims(std::move(group_dims)), 258 group_dim_sizes(std::move(group_dim_sizes)), 259 data_rank(data_rank), 260 sharding(std::move(grouped_sharding)), 261 subgroup_manual(subgroup_manual) {} 262 std::string ToString() const; 263 std::vector<std::vector<int64_t>> device_groups; 264 std::vector<int64_t> group_dims; 265 std::vector<int64_t> group_dim_sizes; 266 int64_t data_rank; 267 HloSharding sharding; 268 bool subgroup_manual; 269 }; 270 271 // Creates a GroupedSharding for a tiled sharding with group dim shard sizes. 272 GroupedSharding GroupShardingOnDims(const HloSharding& sharding, 273 absl::Span<const int64_t> group_dims, 274 absl::Span<const int64_t> group_dim_shards, 275 bool subgroup_manual = false); 276 277 // Creates a GroupedSharding for a tiled sharding. 278 GroupedSharding GroupShardingOnDims(const HloSharding& sharding, 279 absl::Span<const int64_t> group_dims, 280 bool subgroup_manual = false); 281 282 // Get group sharding for each manual subgroup. 283 GroupedSharding GetManualSubgroupSharding(const HloSharding& sharding); 284 285 // Reconstructs the ungrouped sharding from a GroupedSharding. 286 HloSharding UngroupSharding(const GroupedSharding& grouped_sharding); 287 288 // Check if the device groups are match for the LHS or RHS group shardings. 289 bool DeviceGroupsAreMatch(GroupedSharding& lhs, GroupedSharding& rhs, 290 bool ignore_group_order = true); 291 292 // Spawns a new dimension by splitting an existing dimension and generating a 293 // new dimension to its right of the passed down size. The original dimension 294 // will be of size "original_dim_size / new_dim_size". The original dimension 295 // size needs to be divisible by new_dim_size. 296 HloSharding SplitShardingDimension(const HloSharding& sharding, 297 int64_t dimension, int64_t new_dim_size); 298 299 // Merges a dimension 300 // to its left. The new dimension will be of size 301 // dimensions[dimension] * dimensions[dimension+1}. 302 HloSharding MergeShardingDimension(const HloSharding& sharding, 303 int64_t dimension); 304 } // namespace hlo_sharding_util 305 } // namespace xla 306 307 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ 308