1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/dtensor/mlir/expansions/meta_spmd_expander.h"
17
18 #include <cstdint>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/strings/str_join.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/Matchers.h" // from @llvm-project
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
36 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/dtensor/cc/constants.h"
39 #include "tensorflow/dtensor/cc/dstatus.h"
40 #include "tensorflow/dtensor/mlir/collectives.h"
41 #include "tensorflow/dtensor/mlir/dtensor_location.h"
42 #include "tensorflow/dtensor/mlir/layout_parsing.h"
43 #include "tensorflow/dtensor/mlir/spmd_expander.h"
44 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
45 #include "tensorflow/dtensor/mlir/value_utils.h"
46
47 namespace tensorflow {
48 namespace dtensor {
49 namespace {
50
51 // Validates `axis` for pack/unpack and resolves negative values.
52 //
53 // Returns a valid non-negative axis or an error.
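//
// A small illustrative example of the contract (values chosen here, not taken
// from any caller): CanonicalizeAxis(-1, /*packed_rank=*/3) resolves to 2,
// CanonicalizeAxis(0, /*packed_rank=*/3) stays 0, and
// CanonicalizeAxis(3, /*packed_rank=*/3) returns an InvalidArgument error.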
54 StatusOr<int> CanonicalizeAxis(int axis, int packed_rank) {
55 // Axis can be in range [-packed_rank, packed_rank), so we add packed_rank
56 // to wrap it around.
57 if (axis >= -packed_rank && axis < 0) {
58 axis += packed_rank;
59 } else if (axis < -packed_rank || axis >= packed_rank) {
60 return errors::InvalidArgument(
61 "Invalid axis; expected a value in [-packed_rank, packed_rank)");
62 }
63 return axis;
64 }
65
66 // Implements, for pack or unpack, layout propagation from a suggested layout
67 // for the packed tensor to suggested layouts for the unpacked tensors.
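//
// Illustrative example (hypothetical layouts): with axis = 1 and a packed
// layout ['x', 'unsharded', 'y'], every unpacked tensor is suggested the
// layout ['x', 'y'], i.e. the packed layout with the axis dimension dropped.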
68 StatusOr<llvm::DenseMap<int, Layout>> LayoutsFromPackedTensor(
69 int axis, const Layout& packed_layout, size_t num_unpacked_tensors) {
70 TF_ASSIGN_OR_RETURN(axis,
71 CanonicalizeAxis(axis,
72 /*packed_rank=*/packed_layout.rank()));
73 const Layout unpacked_layout =
74 packed_layout.GetLayoutWithReducedDims({axis}, false);
75 llvm::DenseMap<int, Layout> layouts(num_unpacked_tensors);
76 for (int i = 0; i < num_unpacked_tensors; ++i) {
77 layouts[i] = unpacked_layout;
78 }
79 return layouts;
80 }
81
82 // Implements, for pack or unpack, layout propagation from suggested layouts for
83 // the unpacked tensors to a suggested layout for the packed tensor.
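//
// Illustrative example (hypothetical layouts): packing two rank-2 tensors,
// each with layout ['x', 'unsharded'], along axis = 1 suggests the packed
// layout ['x', 'unsharded', 'unsharded']; the new axis itself is always
// unsharded, and conflicting input shardings fall back to unsharded.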
84 StatusOr<llvm::DenseMap<int, Layout>> LayoutFromUnpackedTensors(
85 int axis, const llvm::DenseMap<int, Layout>& unpacked_layouts) {
86 if (unpacked_layouts.empty()) return llvm::DenseMap<int, Layout>();
87
88 auto it = unpacked_layouts.begin();
89 const Layout& first_layout = it->getSecond();
90 const Mesh& mesh = first_layout.mesh();
91
92 // Record the mesh and rank of the first input layout that exists.
93 // The rank and mesh of the others will be the same.
94 const int unpacked_rank = first_layout.rank();
95 TF_ASSIGN_OR_RETURN(axis,
96 CanonicalizeAxis(axis,
97 /*packed_rank=*/unpacked_rank + 1));
98
99 std::vector<std::string> inferred_packed_layout_specs;
100 for (int rank_index = 0; rank_index <= unpacked_rank; ++rank_index) {
101 if (rank_index == axis)
102 inferred_packed_layout_specs.push_back(Layout::kUnshardedDim);
103 if (rank_index == unpacked_rank) {
104 break;
105 }
106 // When we have multiple inputs with conflicting shardings, set that
107 // dimension to replicated (aka unsharded).
108 std::string dimension = Layout::kUnshardedDim;
109 bool found_sharded_dim = false;
110 for (; it != unpacked_layouts.end(); ++it) {
111 const std::string& sharding_spec =
112 it->getSecond().sharding_spec(rank_index);
113 if (!Layout::IsUnshardedDimension(sharding_spec)) {
114 if (!found_sharded_dim) {
115 found_sharded_dim = true;
116 dimension = sharding_spec;
117 } else if (sharding_spec != dimension) {
118 dimension = Layout::kUnshardedDim;
119 }
120 }
121 }
122 inferred_packed_layout_specs.push_back(dimension);
123 }
124 TF_ASSIGN_OR_RETURN(auto inferred_packed_layout,
125 Layout::GetLayout(inferred_packed_layout_specs, mesh));
126 return llvm::DenseMap<int, Layout>({{0, inferred_packed_layout}});
127 }
128
129 } // namespace
130
131 StatusOr<mlir::Operation*> PackSPMDExpander::ExpandOp(mlir::Operation* op) {
132 auto pack = llvm::cast<mlir::TF::PackOp>(op);
133 TF_ASSIGN_OR_RETURN(const absl::optional<Layout> output_layout,
134 ExtractSingleLayoutFromOp(op));
135
136 const int output_rank = ValueRank(pack.output());
137 if (output_rank == -1)
138 return errors::Unimplemented("output must have a rank");
139
140 TF_ASSIGN_OR_RETURN(
141 int axis, CanonicalizeAxis(pack.axis(), /*packed_rank=*/output_rank));
142
143 // TODO(bfontain): This may not be the best approach, but for now relayout all
144 // inputs to match the output layout. E.g. if the output layout is not sharded
145 // but an input is, this forces an AllConcat on every input, rather than first
146 // packing and emitting a single AllConcat.
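// For instance (hypothetical layouts): packing along axis 0 toward an output
// layout ['unsharded', 'x'] relayouts every input to ['x'] before packing.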
147 const Layout new_input_layout =
148 output_layout->GetLayoutWithReducedDims({axis}, /*keep_dims=*/false);
149
150 for (int i = 0; i < op->getNumOperands(); ++i) {
151 TF_ASSIGN_OR_RETURN(const absl::optional<Layout> layout,
152 ExtractLayoutFromOperand(pack.getOperand(i)));
153 if (!layout) return errors::InvalidArgument("missing layout for input ", i);
154
155 TF_ASSIGN_OR_RETURN(
156 mlir::Value new_input,
157 EmitRelayout(pack.getOperand(i), *layout, new_input_layout));
158
159 pack.setOperand(i, new_input);
160 }
161
162 return InferSPMDExpandedLocalShape(op);
163 }
164
165 StatusOr<llvm::DenseMap<int, Layout>> PackSPMDExpander::ComputeLayoutForward(
166 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
167 auto pack = llvm::cast<mlir::TF::PackOp>(op);
168 const int axis = pack.axis();
169 return LayoutFromUnpackedTensors(axis, input_layouts);
170 }
171
172 StatusOr<llvm::DenseMap<int, Layout>> PackSPMDExpander::ComputeLayoutBackward(
173 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
174 if (output_layouts.find(0) == output_layouts.end())
175 return llvm::DenseMap<int, Layout>();
176
177 auto pack = llvm::cast<mlir::TF::PackOp>(op);
178 const int axis = pack.axis();
179 return LayoutsFromPackedTensor(axis, output_layouts.lookup(0),
180 pack->getNumOperands());
181 }
182
183 StatusOr<mlir::Operation*> UnpackSPMDExpander::ExpandOp(mlir::Operation* op) {
184 auto unpack = llvm::cast<mlir::TF::UnpackOp>(op);
185 TF_ASSIGN_OR_RETURN(const absl::optional<Layout> input_layout,
186 ExtractLayoutFromOperand(unpack.getOperand()));
187 if (!input_layout) {
188 return errors::Unimplemented("input must have a layout");
189 }
190
191 const int input_rank = ValueRank(unpack.getOperand());
192 if (input_rank == -1) {
193 return errors::Unimplemented("input must have a rank");
194 }
195
196 TF_ASSIGN_OR_RETURN(
197 int axis, CanonicalizeAxis(unpack.axis(), /*packed_rank=*/input_rank));
198
199 if (input_layout->num_shards_for_dim(input_layout->dim(axis)) != 1) {
200 // If the axis being unpacked is sharded, relayout to replicated along that
201 // axis since each device needs to split across it.
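// For instance (hypothetical layout): unpacking along axis 0 an input with
// layout ['x', 'y'] first relayouts the input to ['unsharded', 'y'].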
202 std::vector<ShardingSpec> new_layout_specs(input_rank);
203 for (int input_index = 0; input_index < input_rank; ++input_index) {
204 if (input_index == axis) {
205 new_layout_specs[input_index].set_sharding_spec(Layout::kUnshardedDim);
206 } else {
207 new_layout_specs[input_index] = input_layout->dim(input_index);
208 }
209 }
210 TF_ASSIGN_OR_RETURN(
211 Layout new_input_layout,
212 Layout::GetLayout(std::move(new_layout_specs), input_layout->mesh()));
213 TF_ASSIGN_OR_RETURN(
214 mlir::Value new_input,
215 EmitRelayout(unpack.getOperand(), *input_layout, new_input_layout));
216 unpack.setOperand(new_input);
217 }
218 return InferSPMDExpandedLocalShape(op);
219 }
220
221 StatusOr<llvm::DenseMap<int, Layout>> UnpackSPMDExpander::ComputeLayoutForward(
222 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
223 if (input_layouts.find(0) == input_layouts.end())
224 return llvm::DenseMap<int, Layout>();
225
226 auto unpack = llvm::cast<mlir::TF::UnpackOp>(op);
227 const int axis = unpack.axis();
228 return LayoutsFromPackedTensor(axis, input_layouts.lookup(0),
229 unpack->getNumResults());
230 }
231
232 StatusOr<llvm::DenseMap<int, Layout>> UnpackSPMDExpander::ComputeLayoutBackward(
233 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
234 auto unpack = llvm::cast<mlir::TF::UnpackOp>(op);
235 const int axis = unpack.axis();
236 return LayoutFromUnpackedTensors(axis, output_layouts);
237 }
238
239 namespace {
240
241 Status VerifyPaddedDimensionNotSharded(const Layout& layout,
242 mlir::Value pad_input,
243 mlir::Value pad_output) {
244 auto input_type = pad_input.getType().dyn_cast<mlir::RankedTensorType>();
245 auto output_type = pad_output.getType().dyn_cast<mlir::RankedTensorType>();
246 if (!input_type || !output_type)
247 return errors::InvalidArgument(
248 "pad op input/output should have statically known shape for SPMD.");
249
250 const auto input_shape = input_type.getShape();
251 const auto output_shape = output_type.getShape();
252 for (const auto& dim_shard_and_index :
253 llvm::enumerate(layout.sharding_specs())) {
254 const int index = dim_shard_and_index.index();
255 const auto& tensor_dimension = dim_shard_and_index.value();
256 const int input_shape_for_dim = input_shape[index];
257 const int output_shape_for_dim = output_shape[index];
258 if ((input_shape_for_dim == -1 || output_shape_for_dim == -1 ||
259 output_shape_for_dim != input_shape_for_dim) &&
260 layout.num_shards_for_dim(tensor_dimension) > 1) {
261 return errors::InvalidArgument(
262 "Padding over sharded dimension is not allowed.");
263 }
264 }
265 return OkStatus();
266 }
267
268 } // namespace
269
270 StatusOr<mlir::Operation*> PadSPMDExpander::ExpandOp(mlir::Operation* op) {
271 // TODO(b/170666884): Implement sharded SPMD logic for tf.Pad op.
272 TF_ASSIGN_OR_RETURN(auto op_layout, ExtractSingleLayoutFromOp(op));
273 auto pad_input = op->getOperand(0);
274 auto pad_output = op->getResult(0);
275
276 TF_ASSIGN_OR_RETURN(auto input_layout, ExtractLayoutFromOperand(pad_input));
277 assert(input_layout && op_layout);
278
279 if (op_layout != input_layout)
280 return errors::Unimplemented(
281 "pad op with input layout different from op output layout is not yet "
282 "supported.");
283
284 TF_RETURN_IF_ERROR(
285 VerifyPaddedDimensionNotSharded(*op_layout, pad_input, pad_output));
286 return InferSPMDExpandedLocalShape(op);
287 }
288
289 StatusOr<llvm::DenseMap<int, Layout>> PadSPMDExpander::ComputeLayoutForward(
290 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
291 if (input_layouts.find(0) == input_layouts.end())
292 return llvm::DenseMap<int, Layout>();
293
294 const Layout input_layout = input_layouts.lookup(0);
295 mlir::Value pad_input;
296 mlir::Value pad_output;
297
298 if (auto pad_v2 = llvm::dyn_cast<mlir::TF::PadV2Op>(op)) {
299 pad_output = pad_v2.output();
300 pad_input = pad_v2.input();
301 } else {
302 auto pad_op = llvm::cast<mlir::TF::PadOp>(op);
303 pad_output = pad_op.output();
304 pad_input = pad_op.input();
305 }
306
307 TF_RETURN_IF_ERROR(
308 VerifyPaddedDimensionNotSharded(input_layout, pad_input, pad_output));
309 return llvm::DenseMap<int, Layout>({{0, input_layout}});
310 }
311
312 StatusOr<llvm::DenseMap<int, Layout>> PadSPMDExpander::ComputeLayoutBackward(
313 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
314 TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
315 mlir::Value pad_input;
316 mlir::Value pad_output;
317
318 llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());
319 // The Pad op `paddings` operand is always a rank 2 tensor.
320 input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/2);
321
322 if (auto pad_v2 = llvm::dyn_cast<mlir::TF::PadV2Op>(op)) {
323 pad_output = pad_v2.output();
324 pad_input = pad_v2.input();
325 // `constant_values` operand
326 input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/0);
327 } else {
328 auto pad_op = llvm::cast<mlir::TF::PadOp>(op);
329 pad_output = pad_op.output();
330 pad_input = pad_op.input();
331 }
332
333 if (output_layouts.find(0) != output_layouts.end()) {
334 const Layout output_layout = output_layouts.lookup(0);
335 TF_RETURN_IF_ERROR(
336 VerifyPaddedDimensionNotSharded(output_layout, pad_input, pad_output));
337 // `input` operand
338 input_layouts[0] = output_layout;
339 }
340 return input_layouts;
341 }
342
343 namespace {
344
345 Status VerifyTileOperandLayout(const Layout& operand_layout,
346 llvm::ArrayRef<int64_t> static_multiples) {
347 for (const auto& tensor_dim_and_multiple :
348 llvm::zip(operand_layout.sharding_specs(), static_multiples)) {
349 const auto& tensor_dimension = std::get<0>(tensor_dim_and_multiple);
350 const int64_t multiple_factor = std::get<1>(tensor_dim_and_multiple);
351 if (multiple_factor > 1 &&
352 operand_layout.num_shards_for_dim(tensor_dimension) > 1)
353 return errors::InvalidArgument(
354 "tile op with input sharded at dimension where `multiple` > 1 is not "
355 "supported.");
356 }
357 return OkStatus();
358 }
359
360 } // namespace
361
362 StatusOr<mlir::Operation*> TileSPMDExpander::ExpandOp(mlir::Operation* op) {
363 auto tile_op = llvm::cast<mlir::TF::TileOp>(op);
364 // After layout propagation, tile op should already have the proper output
365 // layout tagged on itself.
366 TF_ASSIGN_OR_RETURN(absl::optional<Layout> output_layout,
367 ExtractSingleLayoutFromOp(op));
368 if (!output_layout)
369 return errors::InvalidArgument(
370 "TileOP doesn't have a layout after layout propagation");
371
372 TF_ASSIGN_OR_RETURN(absl::optional<Layout> operand_layout,
373 ExtractLayoutFromOperand(tile_op.input()));
374 if (!operand_layout)
375 return errors::InvalidArgument(
376 "Input operand to TileOp doesn't have a layout after layout "
377 "propagation.");
378
379 if (operand_layout->IsFullyReplicated() &&
380 output_layout->IsFullyReplicated()) {
381 // There's nothing to do; we can avoid some unimplemented cases.
382 return InferSPMDExpandedLocalShape(op);
383 }
384
385 llvm::SmallVector<int64_t, 4> static_multiples;
386 auto status =
387 ExtractConstVectorFromValue(tile_op.multiples(), &static_multiples);
388 if (!status.ok())
389 return errors::Unimplemented(
390 "Tile with a sharded output is not implemented for dynamic "
391 "`multiples`.");
392
393 // If `multiples` values can be statically known, verify that all dimensions
394 // with `multiples` > 1 are replicated.
395 TF_RETURN_IF_ERROR(
396 VerifyTileOperandLayout(*operand_layout, static_multiples));
397
398 llvm::SmallVector<int, 4> local_tile_multiples;
399 std::vector<int32> operand_shards = operand_layout->num_shards();
400 std::vector<int32> output_shards = output_layout->num_shards();
401 if (operand_shards.size() != output_shards.size()) {
402 return errors::InvalidArgument(
403 "Expected inputs and outputs to have the same rank.");
404 }
405
406 for (int dim_index = 0; dim_index < operand_shards.size(); ++dim_index) {
407 if (static_multiples[dim_index] == 1) {
408 local_tile_multiples.push_back(static_multiples[dim_index]);
409 continue;
410 }
411 if (output_shards[dim_index] > static_multiples[dim_index])
412 // TODO(b/161012891): Split the input to support sharding the output
413 // more than `multiples` ways.
414 return errors::Unimplemented(
415 "Sharding the output of Tile into more than `multiples` shards is "
416 "not currently supported.");
417 if (static_multiples[dim_index] % output_shards[dim_index] != 0)
418 return errors::Unimplemented(
419 "The output sharding of Tile must evenly divide `multiples`.");
420 if (!Layout::IsUnshardedDimension(
421 operand_layout->sharding_spec(dim_index)) &&
422 (Layout::IsUnshardedDimension(
423 output_layout->sharding_spec(dim_index)) ||
424 (operand_layout->sharding_spec(dim_index) !=
425 output_layout->sharding_spec(dim_index))))
426 return errors::Unimplemented(
427 "Input is replicated on tensor dimension ", dim_index,
428 " but the "
429 "output is not replicated or is replicated on a different mesh "
430 "dimension.");
431 local_tile_multiples.push_back(static_multiples[dim_index] /
432 output_shards[dim_index]);
433 }
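// Illustrative example (hypothetical values): with global `multiples` [1, 4]
// and the output sharded 2 ways on dimension 1, the local multiples become
// [1, 2]; each device tiles its local input twice and holds its shard of the
// 4x-tiled global tensor.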
434 mlir::OpBuilder builder(op);
435 auto location = DT_LOC(tile_op.getLoc());
436 auto multiples_const = IntConst(builder, location, local_tile_multiples);
437
438 auto global_output_type =
439 tile_op.getResult().getType().cast<mlir::TensorType>();
440 TF_ASSIGN_OR_RETURN(
441 auto local_type,
442 LocalTypeFromGlobalType(output_layout.value(), global_output_type));
443
444 auto new_tile =
445 builder.create<mlir::TF::TileOp>(location, /*output=*/local_type,
446 /*input=*/tile_op.input(),
447 /*multiples=*/multiples_const);
448 tile_op.getResult().replaceAllUsesWith(new_tile.output());
449 tile_op.erase();
450 return new_tile.getOperation();
451 }
452
453 StatusOr<llvm::DenseMap<int, Layout>> TileSPMDExpander::ComputeLayoutForward(
454 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
455 if (input_layouts.find(0) == input_layouts.end())
456 return llvm::DenseMap<int, Layout>();
457
458 TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
459 auto tile_op = llvm::cast<mlir::TF::TileOp>(op);
460
461 auto output_ranked_type =
462 tile_op.output().getType().dyn_cast<mlir::RankedTensorType>();
463 if (!output_ranked_type || !output_ranked_type.hasStaticShape()) {
464 return errors::InvalidArgument(
465 llvm::formatv(
466 "requires output type to have statically known rank, but got : {0}",
467 output_ranked_type)
468 .str());
469 }
470 auto tile_output_shape = output_ranked_type.getShape();
471
472 llvm::SmallVector<int64_t, 4> static_multiple;
473 auto status =
474 ExtractConstVectorFromValue(tile_op.multiples(), &static_multiple);
475
476 // If the `multiples` operand cannot be statically known, the output is set
477 // to replicated.
478 if (!status.ok()) {
479 return llvm::DenseMap<int, Layout>(
480 {{0, Layout::ReplicatedOnMesh(mesh, tile_output_shape.size())}});
481 }
482
483 // When a suggested input layout exists, forward the input sharding for all
484 // dimensions where `multiple` == 1.
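// For instance (hypothetical layouts): with `multiples` [1, 2] and input
// layout ['x', 'y'], the suggested output layout is ['x', 'unsharded'].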
485 const Layout input_layout = input_layouts.lookup(0);
486 std::vector<std::string> output_layout_specs;
487 for (const auto& multiple_and_dim_sharding :
488 llvm::zip(static_multiple, input_layout.sharding_specs())) {
489 const int multiple = std::get<0>(multiple_and_dim_sharding);
490 const auto& tensor_dimension = std::get<1>(multiple_and_dim_sharding);
491 output_layout_specs.push_back(multiple == 1
492 ? tensor_dimension.sharding_spec()
493 : Layout::kUnshardedDim);
494 }
495
496 TF_ASSIGN_OR_RETURN(const Layout output_layout,
497 Layout::GetLayout(output_layout_specs, mesh));
498 return llvm::DenseMap<int, Layout>({{0, output_layout}});
499 }
500
501 StatusOr<llvm::DenseMap<int, Layout>> TileSPMDExpander::ComputeLayoutBackward(
502 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
503 TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
504 auto tile_op = llvm::cast<mlir::TF::TileOp>(op);
505
506 // Retrieve operand/output shapes of tile op.
507 auto input_ranked_type =
508 tile_op.input().getType().dyn_cast<mlir::RankedTensorType>();
509 if (!input_ranked_type || !input_ranked_type.hasStaticShape()) {
510 return errors::InvalidArgument(
511 llvm::formatv(
512 "requires input type to have statically known rank, but got : {0}",
513 input_ranked_type)
514 .str());
515 }
516 auto tile_input_shape = input_ranked_type.getShape();
517
518 llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());
519
520 // `multiples` operand is always set to have replicated layout.
521 input_layouts[1] = Layout::ReplicatedOnMesh(
522 mesh,
523 tile_op.multiples().getType().cast<mlir::RankedTensorType>().getRank());
524
525 llvm::SmallVector<int64_t, 4> static_multiple;
526 auto status =
527 ExtractConstVectorFromValue(tile_op.multiples(), &static_multiple);
528
529 // If the `multiples` operand cannot be statically known, set the input to replicated.
530 if (!status.ok()) {
531 input_layouts[0] = Layout::ReplicatedOnMesh(mesh, tile_input_shape.size());
532 return input_layouts;
533 }
534
535 // When a suggested output layout exists, override the operand layout with the
536 // consumer-suggested output layout for dimensions whose `multiple` == 1 and
537 // whose size is evenly divisible by the sharding.
538 if (output_layouts.find(0) != output_layouts.end()) {
539 const Layout output_layout = output_layouts.lookup(0);
540 std::vector<std::string> input_layout_specs;
541 for (const auto& multiple_and_dim_sharding :
542 llvm::zip(static_multiple, output_layout.sharding_specs())) {
543 const int multiple = std::get<0>(multiple_and_dim_sharding);
544 const auto& tensor_dimension = std::get<1>(multiple_and_dim_sharding);
545 input_layout_specs.push_back(multiple == 1
546 ? tensor_dimension.sharding_spec()
547 : Layout::kUnshardedDim);
548 }
549 TF_ASSIGN_OR_RETURN(const Layout input_layout,
550 Layout::GetLayout(input_layout_specs, mesh));
551 input_layouts[0] = input_layout;
552 }
553 return input_layouts;
554 }
555
556 namespace {
557
558 // From the input shape and output shape, extract a maximal list of segments
559 // where the product of the input shape from input_segment_start to
560 // input_segment_end equals the product of the output shape from
561 // output_segment_start to output_segment_end, and no shorter prefix of a
562 // segment has equal products. Note that dimensions of size 1 are skipped over
563 // if they would be at the start of a segment.
564 // Note that shapes with unknown dimension size (represented by -1) are
565 // unsupported.
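//
// Worked example (hypothetical shapes): for input_shape [2, 4, 3] and
// output_shape [8, 3], the segments are input [0, 2) matched with output
// [0, 1) (both products are 8), and input [2, 3) matched with output [1, 2)
// (both products are 3).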
566 void ComputeReshapeSegments(
567 llvm::ArrayRef<int64_t> input_shape, llvm::ArrayRef<int64_t> output_shape,
568 llvm::SmallVectorImpl<int64_t>& input_segment_start,
569 llvm::SmallVectorImpl<int64_t>& input_segment_end,
570 llvm::SmallVectorImpl<int64_t>& output_segment_start,
571 llvm::SmallVectorImpl<int64_t>& output_segment_end) {
572 int input_offset = 0;
573 int output_offset = 0;
574
575 while (input_offset < input_shape.size() &&
576 output_offset < output_shape.size()) {
577 while (input_offset < input_shape.size() && input_shape[input_offset] == 1)
578 input_offset++;
579 while (output_offset < output_shape.size() &&
580 output_shape[output_offset] == 1)
581 output_offset++;
582 if (input_offset >= input_shape.size() ||
583 output_offset >= output_shape.size()) {
584 // Since the input and output tensors have the same number of entries, we are
585 // guaranteed to reach the end of both shapes at the same time.
586 assert(input_offset >= input_shape.size() &&
587 output_offset >= output_shape.size());
588 return;
589 }
590
591 input_segment_start.emplace_back(input_offset);
592 output_segment_start.emplace_back(output_offset);
593
594 int64 input_prod = input_shape[input_offset++];
595 int64 output_prod = output_shape[output_offset++];
596 while (input_prod != output_prod) {
597 if (input_prod < output_prod)
598 input_prod *= input_shape[input_offset++];
599 else
600 output_prod *= output_shape[output_offset++];
601 }
602 input_segment_end.emplace_back(input_offset);
603 output_segment_end.emplace_back(output_offset);
604 }
605 }
606
607 // For reshape we want to reduce the number of all-to-alls and slices needed.
608 // Note that the forward layout propagation for reshape will be the same
609 // algorithm as backwards propagation.
610 //
611 // Suppose we have input shape a_0,...,a_n and output shape b_0,...,b_k such
612 // that a_0*...*a_i != b_0*...*b_j except with (i,j)=(n,k).
613 // The forward propagation of an input layout depends only on size of axis 0
614 // of the output shape and the mesh dimension axis 0 of input is sharded on:
615 //
616 // 1. In any case we must all to all on any input axis from 1 to n and the
617 // output layout from output axis 1 to k will always be replicated.
618 //
619 // 2. If input axis 0 is replicated, we do a local reshape and set the layout
620 // of the output axis 0 to replicated.
621 //
622 // 3. If input axis 0 is sharded and the number of shards *does not divide*
623 // b_0, then we must all-to-all on input axis 0 (as well as the axis
624 // mentioned in 1) do a local reshape and set the layout of output axis 0
625 // to replicated.
626 //
627 // 4. If input axis 0 is sharded and the number of shards does divide b_0, we
628 // can do a local reshape and set the layout of output axis 0 to the same
629 // mesh dimension as the input layout axis 0.
630 //
631 // Finally, for a general input and output shape, if we partition the input
632 // and output shapes into such segments, we can apply the above rule to each
633 // segment. The ComputeReshapeSegments function above computes the starting
634 // and ending index of each segment.
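//
// Worked example (hypothetical shapes and mesh, with mesh dimension 'x' having
// 2 shards): reshaping from shape [8, 3] with layout ['x', 'unsharded'] to
// shape [2, 4, 3] gives segments {[0, 1) -> [0, 2)} and {[1, 2) -> [2, 3)};
// since the first output dimension size 2 is divisible by the 2 shards, it
// keeps 'x', yielding the layout ['x', 'unsharded', 'unsharded'].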
635 StatusOr<Layout> MakeLayoutForReshape(
636 const Layout& input_layout, const llvm::ArrayRef<int64_t> output_shape,
637 llvm::SmallVectorImpl<int64_t>& input_segment_start,
638 llvm::SmallVectorImpl<int64_t>& output_segment_start) {
639 std::vector<std::string> layout_specs;
640 layout_specs.reserve(output_shape.size());
641 // Initially set the layout to be fully replicated.
642 for (int i = 0; i < output_shape.size(); ++i)
643 layout_specs.push_back(Layout::kUnshardedDim);
644 // Now process each segment, for each segment if the number of shards on the
645 // first entry of the input segment divides the output shape on the first
646 // entry of the output segment, we request a sharded layout on that axis.
647 for (int i = 0; i < input_segment_start.size(); ++i) {
648 const int num_shards = input_layout.num_shards_for_dim(
649 input_layout.dim(input_segment_start[i]));
650 if (output_shape[output_segment_start[i]] % num_shards == 0)
651 layout_specs[output_segment_start[i]] =
652 input_layout.sharding_spec(input_segment_start[i]);
653 }
654 return Layout::GetLayout(layout_specs, input_layout.mesh());
655 }
656
657 } // namespace
658
659 // TODO(b/171335075): Implement the SPMD for generic Reshape.
660 StatusOr<mlir::Operation*> ReshapeSPMDExpander::ExpandOp(mlir::Operation* op) {
661 // Update input/output shape based on the sharding information.
662 TF_ASSIGN_OR_RETURN(auto input_layout,
663 ExtractLayoutFromOperand(op->getOperand(0)));
664 TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
665
666 if (!input_layout || !output_layout)
667 return errors::InvalidArgument(
668 "Input and output layouts of Reshape op must be known before SPMD "
669 "expansion.");
670
671 if (input_layout->IsFullyReplicated() && output_layout->IsFullyReplicated())
672 return InferSPMDExpandedLocalShape(op);
673
674 TF_ASSIGN_OR_RETURN(auto global_input_shape,
675 ExtractGlobalInputShape(op->getOpOperand(0)));
676 TF_ASSIGN_OR_RETURN(auto global_output_shape,
677 ExtractGlobalOutputShape(op->getOpResult(0)));
678
679 llvm::SmallVector<int64_t, 4> input_segment_start;
680 llvm::SmallVector<int64_t, 4> input_segment_end;
681 llvm::SmallVector<int64_t, 4> output_segment_start;
682 llvm::SmallVector<int64_t, 4> output_segment_end;
683
684 llvm::SmallVector<int64_t, 4> local_reshape_const;
685
686 // Break up input and output shapes into segments which multiply to the same
687 // number. We will treat each segment separately when constructing the input
688 // shape from the output shape and vice versa.
689 ComputeReshapeSegments(global_input_shape, global_output_shape,
690 input_segment_start, input_segment_end,
691 output_segment_start, output_segment_end);
692
693 // Compute the shape for the local reshape op. For each input segment,
694 // 1) Check the sharding status of all dimensions in that segment.
695 // 2) Create entries in the output shape and layout for the segment.
696 //
697 // Also insert the necessary 1 dimensions between input and output segments.
698 //
699 // Currently the algorithm supports Reshape only in limited cases.
700 // - For example, reshape a [2, 16] shape tensor with layout ['not_sharded',
701 // 'x'],
702 // - to a [2, 4, 4] shape tensor with layout ['not_sharded', 'x',
703 // 'not_sharded'] does not need cross device data shuffling.
704 // - to a [2, 4, 4] shape tensor with layout ['not_sharded',
705 // 'not_sharded', 'x'] needs cross device AllToAll on the input and
706 // a slice on the output.
707 // - For trivial cases, which AllToAll can support, an AllToAll will be
708 // inserted. For example, reshape a [2, 4, 3] shape tensor with layout
709 // ['not_sharded', 'x', 'not_sharded'] to [2, 12] shape tensor fully
710 // replicated can be supported.
711 std::vector<ShardingSpec> tgt_input_layout(input_layout->rank());
712 std::vector<ShardingSpec> tgt_output_layout(output_layout->rank());
713
714 for (int i = 0; i < input_segment_start.size(); ++i) {
715 const int input_start = input_segment_start[i];
716 const int output_start = output_segment_start[i];
717 const int prev_input_segment_end = (i == 0 ? 0 : input_segment_end[i - 1]);
718 const int prev_output_segment_end =
719 (i == 0 ? 0 : output_segment_end[i - 1]);
720
721 // If there is a gap between this segment and the previous one, insert
722 // dimensions of size 1 with kUnshardedDim as the layout dim.
723 for (int j = prev_input_segment_end; j < input_start; ++j)
724 tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim);
725 for (int j = prev_output_segment_end; j < output_start; ++j) {
726 local_reshape_const.emplace_back(1);
727 tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim);
728 }
729
730 const int num_input_shards =
731 input_layout->num_shards_for_dim(input_layout->dim(input_start));
732
733 // Decide on the sharding of the input for this segment.
734 // If the input is already sharded, we try to keep this sharding (unless
735 // the output size of first output dimension is incompatible).
736 // NOTE: If the input is unsharded in a dimension, and the output is sharded
737 // we could 'preshard' the input on this dimension before the reshape.
738 // This is unlikely to have any major gains in performance.
739 if (global_output_shape[output_start] % num_input_shards != 0) {
740 tgt_input_layout[input_start].set_sharding_spec(Layout::kUnshardedDim);
741 tgt_output_layout[output_start].set_sharding_spec(Layout::kUnshardedDim);
742 local_reshape_const.emplace_back(global_output_shape[output_start]);
743 } else {
744 tgt_input_layout[input_start] = input_layout->dim(input_start);
745 tgt_output_layout[output_start] = input_layout->dim(input_start);
746 local_reshape_const.emplace_back(global_output_shape[output_start] /
747 num_input_shards);
748 }
749
750 for (int j = input_start + 1; j < input_segment_end[i]; ++j)
751 tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim);
752 for (int j = output_start + 1; j < output_segment_end[i]; ++j) {
753 local_reshape_const.emplace_back(global_output_shape[j]);
754 tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim);
755 }
756 }
757
758 // Fill any remaining dimensions at the end with size 1 and an unsharded
759 // layout dim.
760 for (int j = input_segment_end.back(); j < tgt_input_layout.size(); ++j)
761 tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim);
762 for (int j = output_segment_end.back(); j < tgt_output_layout.size(); ++j) {
763 local_reshape_const.emplace_back(1);
764 tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim);
765 }
766
767 TF_ASSIGN_OR_RETURN(
768 auto desired_input_layout,
769 Layout::GetLayout(tgt_input_layout, input_layout->mesh()));
770 TF_ASSIGN_OR_RETURN(
771 auto desired_output_layout,
772 Layout::GetLayout(tgt_output_layout, input_layout->mesh()));
773
774 auto reshape_op = mlir::cast<mlir::TF::ReshapeOp>(op);
775 TF_ASSIGN_OR_RETURN(
776 mlir::Value new_input,
777 EmitRelayout(reshape_op.tensor(), *input_layout, desired_input_layout));
778
779 mlir::OpBuilder builder(op);
780
781 // Update the shape op to use the local shape as input. Importantly, this
782 // updates the shape attr in the op, which `InferSPMDExpandedLocalShape` does
783 // not do.
784 auto new_shape = mlir::RankedTensorType::get(
785 {static_cast<int64_t>(local_reshape_const.size())}, builder.getI64Type());
786 auto const_attr =
787 mlir::DenseIntElementsAttr::get(new_shape, local_reshape_const);
788 auto new_reshape_const_op =
789 builder.create<mlir::TF::ConstOp>(DT_LOC(op), const_attr);
790 mlir::TF::ReshapeOp new_reshape_op = builder.create<mlir::TF::ReshapeOp>(
791 op->getLoc(), new_input, new_reshape_const_op);
792
793 TF_ASSIGN_OR_RETURN(auto final_output,
794 EmitRelayout(new_reshape_op.output(),
795 desired_output_layout, *output_layout));
796
797 op->getResult(0).replaceAllUsesWith(final_output);
798 op->erase();
799 return final_output.getDefiningOp();
800 }
801
802 StatusOr<llvm::DenseMap<int, Layout>> ReshapeSPMDExpander::ComputeLayoutForward(
803 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
804 if (input_layouts.find(0) == input_layouts.end())
805 return llvm::DenseMap<int, Layout>();
806
807 auto reshape_op = mlir::cast<mlir::TF::ReshapeOp>(op);
808 TF_ASSIGN_OR_RETURN(
809 auto input_shape,
810 GetShapeOfValue(reshape_op.tensor(), /*fail_on_dynamic=*/true));
811 TF_ASSIGN_OR_RETURN(
812 auto output_shape,
813 GetShapeOfValue(reshape_op.output(), /*fail_on_dynamic=*/true));
814
815 llvm::SmallVector<int64_t, 4> input_segment_start;
816 llvm::SmallVector<int64_t, 4> input_segment_end;
817 llvm::SmallVector<int64_t, 4> output_segment_start;
818 llvm::SmallVector<int64_t, 4> output_segment_end;
819
820 // Break up input and output shapes into segments which multiply to the same
821 // number. We will treat each segment separately when constructing the input
822 // shape from the output shape and vice versa.
823 ComputeReshapeSegments(input_shape, output_shape, input_segment_start,
824 input_segment_end, output_segment_start,
825 output_segment_end);
826
827 TF_ASSIGN_OR_RETURN(
828 const Layout output_layout,
829 MakeLayoutForReshape(input_layouts.lookup(0), output_shape,
830 input_segment_start, output_segment_start));
831 return llvm::DenseMap<int, Layout>({{0, output_layout}});
832 }
833
834 StatusOr<llvm::DenseMap<int, Layout>>
835 ReshapeSPMDExpander::ComputeLayoutBackward(
836 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
837 if (output_layouts.find(0) == output_layouts.end())
838 return llvm::DenseMap<int, Layout>();
839
840 auto reshape_op = mlir::cast<mlir::TF::ReshapeOp>(op);
841 TF_ASSIGN_OR_RETURN(
842 auto input_shape,
843 GetShapeOfValue(reshape_op.tensor(), /*fail_on_dynamic=*/true));
844 TF_ASSIGN_OR_RETURN(
845 auto output_shape,
846 GetShapeOfValue(reshape_op.output(), /*fail_on_dynamic=*/true));
847
848 llvm::SmallVector<int64_t, 4> input_segment_start;
849 llvm::SmallVector<int64_t, 4> input_segment_end;
850 llvm::SmallVector<int64_t, 4> output_segment_start;
851 llvm::SmallVector<int64_t, 4> output_segment_end;
852
853 // Break up input and output shapes into segments which multiply to the same
854 // number. We will treat each segment separately when constructing the input
855 // shape from the output shape and vice versa.
856 ComputeReshapeSegments(input_shape, output_shape, input_segment_start,
857 input_segment_end, output_segment_start,
858 output_segment_end);
859
860 TF_ASSIGN_OR_RETURN(
861 const Layout input_layout,
862 MakeLayoutForReshape(output_layouts.lookup(0), input_shape,
863 output_segment_start, input_segment_start));
864 return llvm::DenseMap<int, Layout>({{0, input_layout}});
865 }
866
867 StatusOr<mlir::Operation*> TransposeSPMDExpander::ExpandOp(
868 mlir::Operation* op) {
869 // Currently we only support transpose without shuffling data. When use cases
870 // arise, we can add support; we will need to figure out the best strategy to
871 // keep the cost as low as possible. Until then, check and emit a clear error.
872 {
873 TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
874 TF_ASSIGN_OR_RETURN(auto operand_layout,
875 ExtractLayoutFromOperand(op->getOperand(0)));
876
877 if (!output_layout)
878 return errors::InvalidArgument(
879 "output layout of TransposeOp must be known before SPMD expansion.");
880 if (!operand_layout)
881 return errors::InvalidArgument(
882 "operand layout of TransposeOp must be known before SPMD expansion.");
883
884 auto transpose = mlir::cast<mlir::TF::TransposeOp>(op);
885 llvm::SmallVector<int64, 4> perm;
886 TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(transpose.perm(), &perm));
887
888 for (const auto& p : llvm::enumerate(perm)) {
889 if (operand_layout->dim(p.value()).sharding_spec() !=
890 output_layout->dim(p.index()).sharding_spec()) {
891 return errors::InvalidArgument(
892 "TransposeOp SPMD needs communication is not supported yet. \n "
893 "operand layout: ",
894 operand_layout->ToString(),
895 "\n output layout: ", output_layout->ToString());
896 }
897 }
898 }
899
900 // Do nothing but infer local shape for now.
901 return InferSPMDExpandedLocalShape(op);
902 }
903
904 StatusOr<llvm::DenseMap<int, Layout>>
905 TransposeSPMDExpander::ComputeLayoutForward(
906 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
907 if (input_layouts.find(0) == input_layouts.end())
908 return llvm::DenseMap<int, Layout>();
909
910 auto transpose = mlir::cast<mlir::TF::TransposeOp>(op);
911 llvm::SmallVector<int64, 4> perm;
912 TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(transpose.perm(), &perm));
913
914 const Layout input_layout = input_layouts.lookup(0);
915 std::vector<std::string> output_layout_specs;
916 for (int64 p : perm)
917 output_layout_specs.push_back(input_layout.sharding_spec(p));
918
919 TF_ASSIGN_OR_RETURN(
920 const Layout output_layout,
921 Layout::GetLayout(output_layout_specs, input_layout.mesh()));
922 return llvm::DenseMap<int, Layout>({{0, output_layout}});
923 }
924
925 StatusOr<llvm::DenseMap<int, Layout>>
926 TransposeSPMDExpander::ComputeLayoutBackward(
927 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
928 auto transpose = mlir::cast<mlir::TF::TransposeOp>(op);
929 llvm::SmallVector<int64, 4> perm;
930 TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(transpose.perm(), &perm));
931 TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
932
933 llvm::DenseMap<int, Layout> input_layouts(transpose->getNumOperands());
934 input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
935
936 if (output_layouts.find(0) != output_layouts.end()) {
937 const Layout output_layout = output_layouts.lookup(0);
938
939 llvm::SmallVector<int64, 4> inverse_perm(perm.size());
940 for (const auto& p : llvm::enumerate(perm)) {
941 inverse_perm[p.value()] = p.index();
942 }
943
944 std::vector<std::string> input_layout_specs;
945 // For example, if perm is [2, 0, 1], then the inverse perm is [1, 2, 0],
946 // so input_layout[i] = output_layout[inverse_perm[i]].
947 for (auto dim_in_output : inverse_perm)
948 input_layout_specs.push_back(output_layout.sharding_spec(dim_in_output));
949
950 TF_ASSIGN_OR_RETURN(const Layout input_layout,
951 Layout::GetLayout(input_layout_specs, mesh));
952 input_layouts[0] = input_layout;
953 }
954
955 return input_layouts;
956 }
957
958 namespace {
959
960 Status RelayoutOneHotInput(const absl::optional<Layout>& input_layout,
961 const absl::optional<Layout>& output_layout,
962 const int axis, mlir::TF::OneHotOp& one_hot) {
963 if (!input_layout || !output_layout)
964 return errors::InvalidArgument(
965 "layout for tf.OneHot operation inputs and outputs must be known before"
966 " SPMD expansion. Consider adding Relayout() op to specify the "
967 "layout.");
968
969 std::vector<ShardingSpec> sharding_specs(input_layout->rank());
970 for (int i = 0; i < input_layout->rank(); ++i) {
971 if (i < axis)
972 sharding_specs[i] = output_layout->dim(i);
973 else
974 sharding_specs[i] = output_layout->dim(i + 1);
975 }
976 TF_ASSIGN_OR_RETURN(const Layout new_input_layout,
977 Layout::GetLayout(sharding_specs, input_layout->mesh()));
978
979 TF_ASSIGN_OR_RETURN(
980 mlir::Value new_input,
981 EmitRelayout(one_hot.indices(), *input_layout, new_input_layout));
982
983 one_hot->setOperand(0, new_input);
984
985 return OkStatus();
986 }
987
988 } // namespace
989
990 StatusOr<mlir::Operation*> OneHotSPMDExpander::ExpandOp(mlir::Operation* op) {
991 auto one_hot_op = llvm::cast<mlir::TF::OneHotOp>(op);
992
993 mlir::OpBuilder builder(op);
994 TF_ASSIGN_OR_RETURN(const auto input_layout,
995 ExtractLayoutFromOperand(one_hot_op->getOperand(0)));
996
997 TF_ASSIGN_OR_RETURN(const auto output_layout,
998 ExtractSingleLayoutFromOp(one_hot_op));
999 int axis = one_hot_op.axisAttr().getInt();
1000 if (axis == -1) axis = output_layout->rank() - 1;
1001
1002 // For tf.OneHot, relayout input so that it matches the output layout (outside
1003 // of the one hot dimension).
1004 TF_RETURN_IF_ERROR(
1005 RelayoutOneHotInput(input_layout, output_layout, axis, one_hot_op));
1006
1007 const int num_shards = output_layout->num_shards()[axis];
1008 const auto depth = ExtractConstIntFromValue(one_hot_op.depth());
1009 const bool depth_statically_divisible_by_sharding =
1010 (depth.ok() && (*depth) % num_shards == 0);
1011
1012 // If the axis dimension of tf.OneHot is sharded and the number of shards
1013 // evenly divides the `depth` input of the one hot operation, we can mutate
1014 // the `depth` and `indices` parameters to calculate the local tensor
1015 // directly.
1016 const std::string& mesh_dim_name = output_layout->sharding_spec(axis);
1017
1018 if (mesh_dim_name != Layout::kUnshardedDim) {
1019 if (!depth_statically_divisible_by_sharding)
1020 return errors::InvalidArgument(
1021 "OneHot axis dimension is sharded with incorrect layout. OneHot op "
1022 "depth should be evenly divisible by number of shards.");
1023
1024 // Recalculate new local depth. Namely: new_depth = depth / num_shards
1025 mlir::Value new_depth = CreateIntScalarConst((*depth) / num_shards, builder,
1026 one_hot_op->getLoc(), false);
1027
1028 // Calculate shard id at mesh dimension for the sharded axis.
1029 TF_ASSIGN_OR_RETURN(const Mesh mesh,
1030 ExtractDeviceMeshEnclosingCluster(one_hot_op));
1031 mlir::tf_device::ClusterOp cluster =
1032 one_hot_op->getParentOfType<mlir::tf_device::ClusterOp>();
1033
1034 // `mesh_coordinates` is tensor of size [1, num mesh dimensions] where each
1035 // element in the tensor refers to shard id for the specified mesh
1036 // dimension.
1037 TF_ASSIGN_OR_RETURN(mlir::Value mesh_coordinates,
1038 GetMeshCoordinatesFromCluster(cluster));
1039 const int num_mesh_dimensions = output_layout->mesh().dims().size();
1040 llvm::SmallVector<int32_t, 4> multiplier(num_mesh_dimensions);
1041 const int mesh_dim_index =
1042 output_layout->mesh().GetMeshDimIndexWithName(mesh_dim_name);
1043
1044 mlir::TF::SliceOp selected_sharding_at_dimension = builder.create<
1045 mlir::TF::SliceOp>(
1046 one_hot_op.getLoc(),
1047 mlir::RankedTensorType::get({1, 1}, mesh_coordinates.getType()
1048 .cast<mlir::TensorType>()
1049 .getElementType()),
1050 /*input=*/mesh_coordinates,
1051 /*begin=*/IntConst(builder, one_hot_op.getLoc(), {0, mesh_dim_index}),
1052 /*size=*/IntConst(builder, one_hot_op.getLoc(), {1, 1}));
1053
1054 // Reshape the sliced (1, 1)-shaped tensor to a rank-0 scalar.
1055 auto scalar_size_type =
1056 mlir::RankedTensorType::get({}, builder.getIntegerType(32));
1057 mlir::Value scalar_shape = mlir::TF::collection_ops_util::GetR1Const(
1058 scalar_size_type.getShape(), builder, one_hot_op->getLoc());
1059 mlir::Value selected_sharding_scalar_value =
1060 builder.create<mlir::TF::ReshapeOp>(
1061 one_hot_op.getLoc(), mlir::ArrayRef<mlir::Type>{scalar_size_type},
1062 mlir::ArrayRef<mlir::Value>{selected_sharding_at_dimension.output(),
1063 scalar_shape},
1064 mlir::ArrayRef<mlir::NamedAttribute>{});
1065
1066 // `new_indices` = `original_indices` - `selected_sharding_scalar_value` *
1067 // (depth/num_shards)
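// For example (hypothetical values): with depth = 8 sharded 2 ways, new_depth
// is 4 and the device with shard id 1 computes new_indices = indices - 4, so a
// global index 5 becomes local index 1; indices that fall outside [0, 4) on a
// shard produce all-off_value rows, matching that shard's slice of the global
// one hot tensor.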
1068 mlir::Value id_offset = builder.create<mlir::TF::MulOp>(
1069 one_hot_op->getLoc(), new_depth, selected_sharding_scalar_value);
1070 mlir::Value original_indices = one_hot_op.indices();
1071 mlir::Value new_indices = builder.create<mlir::TF::SubOp>(
1072 one_hot_op->getLoc(), original_indices, id_offset);
1073
1074 // Replace the one hot operation inputs with the mutated `new_indices` and
1075 // `new_depth` tensors so that local tensors can be calculated directly
1076 // without materializing intermediate global tensors.
1077 one_hot_op->getOpOperand(0).set(new_indices);
1078 one_hot_op->getOpOperand(1).set(new_depth);
1079 }
1080 return InferSPMDExpandedLocalShape(one_hot_op);
1081 }
1082
1083 StatusOr<llvm::DenseMap<int, Layout>> OneHotSPMDExpander::ComputeLayoutForward(
1084 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
1085 if (input_layouts.find(0) == input_layouts.end())
1086 return llvm::DenseMap<int, Layout>();
1087
1088 auto one_hot = mlir::dyn_cast<mlir::TF::OneHotOp>(op);
1089 int axis = one_hot.axis();
1090 if (axis == -1) axis = ValueRank(one_hot.indices());
1091 TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
1092
1093 const Layout indices_layout = input_layouts.lookup(0);
1094 std::vector<std::string> output_layout_specs;
1095 for (int i = 0; i < indices_layout.rank(); ++i) {
1096 // Insert an onehot dimension for expanded axis.
1097 if (i == axis) {
1098 output_layout_specs.push_back(Layout::kUnshardedDim);
1099 }
1100 output_layout_specs.push_back(indices_layout.sharding_spec(i));
1101 }
1102 if (axis == indices_layout.rank() || axis == -1) {
1103 output_layout_specs.push_back(Layout::kUnshardedDim);
1104 }
1105
1106 TF_ASSIGN_OR_RETURN(auto output_layout,
1107 Layout::GetLayout(output_layout_specs, mesh));
1108 return llvm::DenseMap<int, Layout>({{0, output_layout}});
1109 }
1110
1111 StatusOr<llvm::DenseMap<int, Layout>> OneHotSPMDExpander::ComputeLayoutBackward(
1112 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
1113 auto one_hot = mlir::dyn_cast<mlir::TF::OneHotOp>(op);
1114 int axis = one_hot.axis();
1115 if (axis == -1) axis = ValueRank(one_hot.indices());
1116 TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
1117
1118 llvm::DenseMap<int, Layout> input_layouts(one_hot->getNumOperands());
1119 const auto scalar_replicated_layout =
1120 Layout::ReplicatedOnMesh(mesh, /*rank=*/0);
1121 input_layouts[1] = scalar_replicated_layout; // depth
1122 input_layouts[2] = scalar_replicated_layout; // on_value
1123 input_layouts[3] = scalar_replicated_layout; // off_value
1124
1125 // If output layout is specified, then propagate all dimensions (except axis
1126 // dimension) as operand layout.
1127 if (output_layouts.find(0) != output_layouts.end()) {
1128 const Layout output_layout = output_layouts.lookup(0);
1129
1130 std::vector<std::string> indices_layout_specs;
1131 for (int i = 0; i < output_layout.rank(); ++i) {
1132 if (i == axis) continue;
1133 indices_layout_specs.push_back(output_layout.sharding_spec(i));
1134 }
1135
1136 TF_ASSIGN_OR_RETURN(auto input_layout,
1137 Layout::GetLayout(indices_layout_specs, mesh));
1138 input_layouts[0] = input_layout;
1139 }
1140
1141 return input_layouts;
1142 }
1143
1144 StatusOr<mlir::Operation*> ShapeSPMDExpander::ExpandOp(mlir::Operation* op) {
1145 TF_ASSIGN_OR_RETURN(auto result_layouts, ExtractLayoutFromOp(op));
1146 for (const auto& layout : result_layouts) {
1147 if (!layout.has_value())
1148 return errors::Internal(
1149 "All op result layouts must be specified for SPMD expansion.");
1150
1151 if (!layout->IsFullyReplicated()) {
1152 return errors::Internal(
1153 "Shape/Rank ops must output value with replicated layout.");
1154 }
1155 }
1156 InferSPMDExpandedLocalShape(op);
1157
1158 // DTensor shards are always full rank -- local rank == global rank.
1159 if (mlir::isa<mlir::TF::RankOp>(op)) return op;
1160
1161 // We have Shape/ShapeN op.
1162
1163 // Find the enclosing device_cluster op and update its attributes if the
1164 // shape op result is returned from the cluster.
1165 auto enclosing_cluster = op->getParentOfType<mlir::tf_device::ClusterOp>();
1166 if (!enclosing_cluster)
1167 return errors::InvalidArgument(
1168 "Error during SPMD expansion of Shape op. Op must be enclosed in a "
1169 "device_cluster.");
1170
1171 // Record output result index -> input_layout mapping.
1172 llvm::SmallVector<std::string, 4> input_layouts;
1173 std::vector<int> return_indices;
1174
1175 // For each operand, extract the global shape if necessary. If a global shape
1176 // transformation is needed, and the transformed shape is returned to the
1177 // outside of the device cluster, also attach the input layout as additional
1178 // information so that downstream consumers can infer the local shape from the result.
1179 llvm::SmallVector<mlir::TF::MulOp, 4> output_ops;
1180 for (int i = 0; i < op->getNumOperands(); ++i) {
1181 // Fetch layout from _input_, not current op.
1182 TF_ASSIGN_OR_RETURN(auto input_layout,
1183 ExtractLayoutFromOperand(op->getOperand(i)));
1184 if (!input_layout)
1185 return errors::InvalidArgument(
1186 "Input layout to shape op must be known before SPMD expansion.");
1187
1188 // Fully replicated tensors: local shape = global shape.
1189 if (input_layout->IsFullyReplicated()) {
1190 continue;
1191 }
1192
1193 // If a DTensor is sharded over a dimension, shards have equal size.
1194 // GlobalShape[Dim] = LocalShape[Dim] * NumShards[Dim]
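// For example (hypothetical shapes): a global shape [8, 4] sharded 2 ways on
// dimension 0 has local shape [4, 4]; multiplying the local Shape result
// [4, 4] elementwise by num_shards [2, 1] recovers the global shape [8, 4].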
1195 mlir::OpBuilder builder(op->getBlock(), ++mlir::Block::iterator(op));
1196 auto num_shards =
1197 IntConst(builder, op->getLoc(), input_layout->num_shards());
1198 auto global_shape = builder.create<mlir::TF::MulOp>(
1199 op->getLoc(), op->getResult(i), num_shards);
1200
1201 op->getResult(i).replaceAllUsesExcept(
1202 global_shape.getResult(),
1203 llvm::SmallPtrSet<mlir::Operation*, 1>{global_shape});
1204
1205 // Find the returned global shape and attach Metadata information.
1206 for (auto& use : global_shape.getOperation()->getUses()) {
1207 if (use.getOwner() == enclosing_cluster.GetBody().getTerminator()) {
1208 input_layouts.emplace_back(input_layout->ToString());
1209 return_indices.emplace_back(use.getOperandNumber());
1210 break;
1211 }
1212 }
1213 output_ops.emplace_back(global_shape);
1214 }
1215
1216 // Attach the original input layouts to the enclosing device_cluster op.
1217 if (!input_layouts.empty()) {
1218 mlir::OpBuilder builder(op);
1219 enclosing_cluster->setAttr(kShapeOpInputLayoutIndices,
1220 builder.getI32VectorAttr(return_indices));
1221 enclosing_cluster->setAttr(
1222 kShapeOpInputLayout,
1223 builder.getStrArrayAttr(llvm::SmallVector<llvm::StringRef, 4>(
1224 input_layouts.begin(), input_layouts.end())));
1225 }
1226
1227 // TODO(hthu): Support multiple returns in ShapeN op.
1228 return output_ops.empty() ? op : output_ops[0].getOperation();
1229 }
1230
1231 StatusOr<llvm::DenseMap<int, Layout>> ShapeSPMDExpander::ComputeLayoutForward(
1232 mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
1233 assert(op->getNumResults() == 1);
1234 TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
1235 const int output_rank = ValueRank(op->getResult(0));
1236 return llvm::DenseMap<int, Layout>(
1237 {{0, Layout::ReplicatedOnMesh(mesh, output_rank)}});
1238 }
1239
1240 StatusOr<llvm::DenseMap<int, Layout>> ShapeSPMDExpander::ComputeLayoutBackward(
1241 mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
1242 return llvm::DenseMap<int, Layout>();
1243 }
1244
1245 } // namespace dtensor
1246 } // namespace tensorflow
1247