/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/meta_spmd_expander.h"

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/dtensor_location.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Validates `axis` for pack/unpack and resolves negative values.
//
// Returns a valid non-negative axis or an error.
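// For example, with packed_rank = 3, axis = -1 resolves to 2, axis = 0 stays
// 0, and axis = 3 (or -4) is rejected as out of range.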
StatusOr<int> CanonicalizeAxis(int axis, int packed_rank) {
  // Axis can be in range [-packed_rank, packed_rank), so we add packed_rank
  // to wrap it around.
  if (axis >= -packed_rank && axis < 0) {
    axis += packed_rank;
  } else if (axis < -packed_rank || axis >= packed_rank) {
    return errors::InvalidArgument(
        "Invalid axis; expected a value in [-packed_rank, packed_rank)");
  }
  return axis;
}

// Implements, for pack or unpack, layout propagation from a suggested layout
// for the packed tensor to suggested layouts for the unpacked tensors.
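// For example, a packed layout ["x", unsharded, "y"] unpacked along axis 1
// suggests the layout ["x", "y"] for every unpacked tensor.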
StatusOr<llvm::DenseMap<int, Layout>> LayoutsFromPackedTensor(
    int axis, const Layout& packed_layout, size_t num_unpacked_tensors) {
  TF_ASSIGN_OR_RETURN(axis,
                      CanonicalizeAxis(axis,
                                       /*packed_rank=*/packed_layout.rank()));
  const Layout unpacked_layout =
      packed_layout.GetLayoutWithReducedDims({axis}, false);
  llvm::DenseMap<int, Layout> layouts(num_unpacked_tensors);
  for (int i = 0; i < num_unpacked_tensors; ++i) {
    layouts[i] = unpacked_layout;
  }
  return layouts;
}

// Implements, for pack or unpack, layout propagation from suggested layouts for
// the unpacked tensors to a suggested layout for the packed tensor.
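// For example, packing rank-1 tensors that are each laid out as ["x"] along
// axis 0 suggests the rank-2 packed layout [unsharded, "x"].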
StatusOr<llvm::DenseMap<int, Layout>> LayoutFromUnpackedTensors(
    int axis, const llvm::DenseMap<int, Layout>& unpacked_layouts) {
  if (unpacked_layouts.empty()) return llvm::DenseMap<int, Layout>();

  auto it = unpacked_layouts.begin();
  const Layout& first_layout = it->getSecond();
  const Mesh& mesh = first_layout.mesh();

  // Record the mesh and rank of the first input layout that exists.
  // The rank + mesh for others will be the same.
  const int unpacked_rank = first_layout.rank();
  TF_ASSIGN_OR_RETURN(axis,
                      CanonicalizeAxis(axis,
                                       /*packed_rank=*/unpacked_rank + 1));

  std::vector<std::string> inferred_packed_layout_specs;
  for (int rank_index = 0; rank_index <= unpacked_rank; ++rank_index) {
    if (rank_index == axis)
      inferred_packed_layout_specs.push_back(Layout::kUnshardedDim);
    if (rank_index == unpacked_rank) {
      break;
    }
    // When we have multiple inputs with conflicting shardings, set that
    // dimension to replicated (aka unsharded).
    std::string dimension = Layout::kUnshardedDim;
    bool found_sharded_dim = false;
    for (; it != unpacked_layouts.end(); ++it) {
      const std::string& sharding_spec =
          it->getSecond().sharding_spec(rank_index);
      if (!Layout::IsUnshardedDimension(sharding_spec)) {
        if (!found_sharded_dim) {
          found_sharded_dim = true;
          dimension = sharding_spec;
        } else if (sharding_spec != dimension) {
          dimension = Layout::kUnshardedDim;
        }
      }
    }
    inferred_packed_layout_specs.push_back(dimension);
  }
  TF_ASSIGN_OR_RETURN(auto inferred_packed_layout,
                      Layout::GetLayout(inferred_packed_layout_specs, mesh));
  return llvm::DenseMap<int, Layout>({{0, inferred_packed_layout}});
}

}  // namespace

StatusOr<mlir::Operation*> PackSPMDExpander::ExpandOp(mlir::Operation* op) {
  auto pack = llvm::cast<mlir::TF::PackOp>(op);
  TF_ASSIGN_OR_RETURN(const absl::optional<Layout> output_layout,
                      ExtractSingleLayoutFromOp(op));

  const int output_rank = ValueRank(pack.output());
  if (output_rank == -1)
    return errors::Unimplemented("output must have a rank");

  TF_ASSIGN_OR_RETURN(
      int axis, CanonicalizeAxis(pack.axis(), /*packed_rank=*/output_rank));

  // TODO(bfontain): This may not be the best approach, but for now relayout
  // all inputs to match the output layout. E.g. if the output layout is
  // unsharded but an input is sharded, this forces an AllConcat on every
  // input, rather than first packing and emitting a single AllConcat.
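  // For example (illustrative), if the output layout is ["y", "x", unsharded]
  // and axis == 1, each input below is relaid out to ["y", unsharded].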
  const Layout new_input_layout =
      output_layout->GetLayoutWithReducedDims({axis}, /*keep_dims=*/false);

  for (int i = 0; i < op->getNumOperands(); ++i) {
    TF_ASSIGN_OR_RETURN(const absl::optional<Layout> layout,
                        ExtractLayoutFromOperand(pack.getOperand(i)));
    if (!layout) return errors::InvalidArgument("missing layout for input ", i);

    TF_ASSIGN_OR_RETURN(
        mlir::Value new_input,
        EmitRelayout(pack.getOperand(i), *layout, new_input_layout));

    pack.setOperand(i, new_input);
  }

  return InferSPMDExpandedLocalShape(op);
}

StatusOr<llvm::DenseMap<int, Layout>> PackSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  auto pack = llvm::cast<mlir::TF::PackOp>(op);
  const int axis = pack.axis();
  return LayoutFromUnpackedTensors(axis, input_layouts);
}

StatusOr<llvm::DenseMap<int, Layout>> PackSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  if (output_layouts.find(0) == output_layouts.end())
    return llvm::DenseMap<int, Layout>();

  auto pack = llvm::cast<mlir::TF::PackOp>(op);
  const int axis = pack.axis();
  return LayoutsFromPackedTensor(axis, output_layouts.lookup(0),
                                 pack->getNumOperands());
}

StatusOr<mlir::Operation*> UnpackSPMDExpander::ExpandOp(mlir::Operation* op) {
  auto unpack = llvm::cast<mlir::TF::UnpackOp>(op);
  TF_ASSIGN_OR_RETURN(const absl::optional<Layout> input_layout,
                      ExtractLayoutFromOperand(unpack.getOperand()));
  if (!input_layout) {
    return errors::Unimplemented("input must have a layout");
  }

  const int input_rank = ValueRank(unpack.getOperand());
  if (input_rank == -1) {
    return errors::Unimplemented("input must have a rank");
  }

  TF_ASSIGN_OR_RETURN(
      int axis, CanonicalizeAxis(unpack.axis(), /*packed_rank=*/input_rank));

  if (input_layout->num_shards_for_dim(input_layout->dim(axis)) != 1) {
    // If the axis being unpacked is sharded, relayout to replicated along that
    // axis, since unpacking splits the tensor across it and each device needs
    // the full dimension.
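    // For example (illustrative), unpacking along axis 0 of a tensor with
    // layout ["x", "y"] first relays out the input to [unsharded, "y"].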
    std::vector<ShardingSpec> new_layout_specs(input_rank);
    for (int input_index = 0; input_index < input_rank; ++input_index) {
      if (input_index == axis) {
        new_layout_specs[input_index].set_sharding_spec(Layout::kUnshardedDim);
      } else {
        new_layout_specs[input_index] = input_layout->dim(input_index);
      }
    }
    TF_ASSIGN_OR_RETURN(
        Layout new_input_layout,
        Layout::GetLayout(std::move(new_layout_specs), input_layout->mesh()));
    TF_ASSIGN_OR_RETURN(
        mlir::Value new_input,
        EmitRelayout(unpack.getOperand(), *input_layout, new_input_layout));
    unpack.setOperand(new_input);
  }
  return InferSPMDExpandedLocalShape(op);
}

StatusOr<llvm::DenseMap<int, Layout>> UnpackSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  if (input_layouts.find(0) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  auto unpack = llvm::cast<mlir::TF::UnpackOp>(op);
  const int axis = unpack.axis();
  return LayoutsFromPackedTensor(axis, input_layouts.lookup(0),
                                 unpack->getNumResults());
}

StatusOr<llvm::DenseMap<int, Layout>> UnpackSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  auto unpack = llvm::cast<mlir::TF::UnpackOp>(op);
  const int axis = unpack.axis();
  return LayoutFromUnpackedTensors(axis, output_layouts);
}

namespace {

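// Checks that Pad only adds padding along replicated dimensions: for every
// dimension where the padded output size differs from the input size (or
// either size is unknown), the layout must not be sharded.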
Status VerifyPaddedDimensionNotSharded(const Layout& layout,
                                       mlir::Value pad_input,
                                       mlir::Value pad_output) {
  auto input_type = pad_input.getType().dyn_cast<mlir::RankedTensorType>();
  auto output_type = pad_output.getType().dyn_cast<mlir::RankedTensorType>();
  if (!input_type || !output_type)
    return errors::InvalidArgument(
        "pad op input/output should have statically known shape for SPMD.");

  const auto input_shape = input_type.getShape();
  const auto output_shape = output_type.getShape();
  for (const auto& dim_shard_and_index :
       llvm::enumerate(layout.sharding_specs())) {
    const int index = dim_shard_and_index.index();
    const auto& tensor_dimension = dim_shard_and_index.value();
    const int input_shape_for_dim = input_shape[index];
    const int output_shape_for_dim = output_shape[index];
    if ((input_shape_for_dim == -1 || output_shape_for_dim == -1 ||
         output_shape_for_dim != input_shape_for_dim) &&
        layout.num_shards_for_dim(tensor_dimension) > 1) {
      return errors::InvalidArgument(
          "Padding over sharded dimension is not allowed.");
    }
  }
  return OkStatus();
}

}  // namespace

StatusOr<mlir::Operation*> PadSPMDExpander::ExpandOp(mlir::Operation* op) {
  // TODO(b/170666884): Implement sharded SPMD logic for tf.Pad op.
  TF_ASSIGN_OR_RETURN(auto op_layout, ExtractSingleLayoutFromOp(op));
  auto pad_input = op->getOperand(0);
  auto pad_output = op->getResult(0);

  TF_ASSIGN_OR_RETURN(auto input_layout, ExtractLayoutFromOperand(pad_input));
  assert(input_layout && op_layout);

  if (op_layout != input_layout)
    return errors::Unimplemented(
        "pad op with input layout different from op output layout is not yet "
        "supported.");

  TF_RETURN_IF_ERROR(
      VerifyPaddedDimensionNotSharded(*op_layout, pad_input, pad_output));
  return InferSPMDExpandedLocalShape(op);
}

StatusOr<llvm::DenseMap<int, Layout>> PadSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  if (input_layouts.find(0) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  const Layout input_layout = input_layouts.lookup(0);
  mlir::Value pad_input;
  mlir::Value pad_output;

  if (auto pad_v2 = llvm::dyn_cast<mlir::TF::PadV2Op>(op)) {
    pad_output = pad_v2.output();
    pad_input = pad_v2.input();
  } else {
    auto pad_op = llvm::cast<mlir::TF::PadOp>(op);
    pad_output = pad_op.output();
    pad_input = pad_op.input();
  }

  TF_RETURN_IF_ERROR(
      VerifyPaddedDimensionNotSharded(input_layout, pad_input, pad_output));
  return llvm::DenseMap<int, Layout>({{0, input_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>> PadSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  mlir::Value pad_input;
  mlir::Value pad_output;

  llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());
  // The Pad op `paddings` operand is always a rank-2 tensor.
  input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/2);

  if (auto pad_v2 = llvm::dyn_cast<mlir::TF::PadV2Op>(op)) {
    pad_output = pad_v2.output();
    pad_input = pad_v2.input();
    // `constant_values` operand
    input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/0);
  } else {
    auto pad_op = llvm::cast<mlir::TF::PadOp>(op);
    pad_output = pad_op.output();
    pad_input = pad_op.input();
  }

  if (output_layouts.find(0) != output_layouts.end()) {
    const Layout output_layout = output_layouts.lookup(0);
    TF_RETURN_IF_ERROR(
        VerifyPaddedDimensionNotSharded(output_layout, pad_input, pad_output));
    // `input` operand
    input_layouts[0] = output_layout;
  }
  return input_layouts;
}

namespace {

Status VerifyTileOperandLayout(const Layout& operand_layout,
                               llvm::ArrayRef<int64_t> static_multiples) {
  for (const auto& tensor_dim_and_multiple :
       llvm::zip(operand_layout.sharding_specs(), static_multiples)) {
    const auto& tensor_dimension = std::get<0>(tensor_dim_and_multiple);
    const int64_t multiple_factor = std::get<1>(tensor_dim_and_multiple);
    if (multiple_factor > 1 &&
        operand_layout.num_shards_for_dim(tensor_dimension) > 1)
      return errors::InvalidArgument(
          "tile op with input sharded at a dimension where `multiple` > 1 is "
          "not supported.");
  }
  return OkStatus();
}

}  // namespace

StatusOr<mlir::Operation*> TileSPMDExpander::ExpandOp(mlir::Operation* op) {
  auto tile_op = llvm::cast<mlir::TF::TileOp>(op);
  // After layout propagation, tile op should already have the proper output
  // layout tagged on itself.
  TF_ASSIGN_OR_RETURN(absl::optional<Layout> output_layout,
                      ExtractSingleLayoutFromOp(op));
  if (!output_layout)
    return errors::InvalidArgument(
        "TileOp doesn't have a layout after layout propagation");

  TF_ASSIGN_OR_RETURN(absl::optional<Layout> operand_layout,
                      ExtractLayoutFromOperand(tile_op.input()));
  if (!operand_layout)
    return errors::InvalidArgument(
        "Input operand to TileOp doesn't have a layout after layout "
        "propagation.");

  if (operand_layout->IsFullyReplicated() &&
      output_layout->IsFullyReplicated()) {
    // There's nothing to do; we can avoid some unimplemented cases.
    return InferSPMDExpandedLocalShape(op);
  }

  llvm::SmallVector<int64_t, 4> static_multiples;
  auto status =
      ExtractConstVectorFromValue(tile_op.multiples(), &static_multiples);
  if (!status.ok())
    return errors::Unimplemented(
        "Tile with a sharded output is not implemented for dynamic "
        "`multiples`.");

  // If the `multiples` values are statically known, verify that every
  // dimension with `multiples` > 1 is replicated on the input.
  TF_RETURN_IF_ERROR(
      VerifyTileOperandLayout(*operand_layout, static_multiples));

  llvm::SmallVector<int, 4> local_tile_multiples;
  std::vector<int32> operand_shards = operand_layout->num_shards();
  std::vector<int32> output_shards = output_layout->num_shards();
  if (operand_shards.size() != output_shards.size()) {
    return errors::InvalidArgument(
        "Expected inputs and outputs to have the same rank.");
  }

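  // For example (illustrative): with multiples = [2, 4], a replicated input,
  // and the output sharded 2 ways on dim 1, the local multiples become [2, 2]:
  // each device tiles its local input half as many times along dim 1.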
  for (int dim_index = 0; dim_index < operand_shards.size(); ++dim_index) {
    if (static_multiples[dim_index] == 1) {
      local_tile_multiples.push_back(static_multiples[dim_index]);
      continue;
    }
    if (output_shards[dim_index] > static_multiples[dim_index])
      // TODO(b/161012891): Split the input to support sharding the output
      // more than `multiples` ways.
      return errors::Unimplemented(
          "Sharding the output of Tile into more than `multiples` shards is "
          "not currently supported.");
    if (static_multiples[dim_index] % output_shards[dim_index] != 0)
      return errors::Unimplemented(
          "The output sharding of Tile must evenly divide `multiples`.");
    if (!Layout::IsUnshardedDimension(
            operand_layout->sharding_spec(dim_index)) &&
        (Layout::IsUnshardedDimension(
             output_layout->sharding_spec(dim_index)) ||
         (operand_layout->sharding_spec(dim_index) !=
          output_layout->sharding_spec(dim_index))))
      return errors::Unimplemented(
          "Input is sharded on tensor dimension ", dim_index,
          " but the output is not sharded, or is sharded on a different mesh "
          "dimension.");
    local_tile_multiples.push_back(static_multiples[dim_index] /
                                   output_shards[dim_index]);
  }
  mlir::OpBuilder builder(op);
  auto location = DT_LOC(tile_op.getLoc());
  auto multiples_const = IntConst(builder, location, local_tile_multiples);

  auto global_output_type =
      tile_op.getResult().getType().cast<mlir::TensorType>();
  TF_ASSIGN_OR_RETURN(
      auto local_type,
      LocalTypeFromGlobalType(output_layout.value(), global_output_type));

  auto new_tile =
      builder.create<mlir::TF::TileOp>(location, /*output=*/local_type,
                                       /*input=*/tile_op.input(),
                                       /*multiples=*/multiples_const);
  tile_op.getResult().replaceAllUsesWith(new_tile.output());
  tile_op.erase();
  return new_tile.getOperation();
}

StatusOr<llvm::DenseMap<int, Layout>> TileSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  if (input_layouts.find(0) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  auto tile_op = llvm::cast<mlir::TF::TileOp>(op);

  auto output_ranked_type =
      tile_op.output().getType().dyn_cast<mlir::RankedTensorType>();
  if (!output_ranked_type || !output_ranked_type.hasStaticShape()) {
    return errors::InvalidArgument(
        llvm::formatv(
            "requires output type to have statically known rank, but got : {0}",
            output_ranked_type)
            .str());
  }
  auto tile_output_shape = output_ranked_type.getShape();

  llvm::SmallVector<int64_t, 4> static_multiple;
  auto status =
      ExtractConstVectorFromValue(tile_op.multiples(), &static_multiple);

  // If the `multiples` operand cannot be statically known, the output is set
  // to replicated.
  if (!status.ok()) {
    return llvm::DenseMap<int, Layout>(
        {{0, Layout::ReplicatedOnMesh(mesh, tile_output_shape.size())}});
  }

  // When a suggested input layout exists, forward the input sharding for all
  // dimensions where `multiple` == 1.
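  // For example (illustrative), input layout ["x", "y"] with multiples [1, 3]
  // suggests the output layout ["x", unsharded].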
  const Layout input_layout = input_layouts.lookup(0);
  std::vector<std::string> output_layout_specs;
  for (const auto& multiple_and_dim_sharding :
       llvm::zip(static_multiple, input_layout.sharding_specs())) {
    const int multiple = std::get<0>(multiple_and_dim_sharding);
    const auto& tensor_dimension = std::get<1>(multiple_and_dim_sharding);
    output_layout_specs.push_back(multiple == 1
                                      ? tensor_dimension.sharding_spec()
                                      : Layout::kUnshardedDim);
  }

  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      Layout::GetLayout(output_layout_specs, mesh));
  return llvm::DenseMap<int, Layout>({{0, output_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>> TileSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  auto tile_op = llvm::cast<mlir::TF::TileOp>(op);

  // Retrieve operand/output shapes of tile op.
  auto input_ranked_type =
      tile_op.input().getType().dyn_cast<mlir::RankedTensorType>();
  if (!input_ranked_type || !input_ranked_type.hasStaticShape()) {
    return errors::InvalidArgument(
        llvm::formatv(
            "requires input type to have statically known rank, but got : {0}",
            input_ranked_type)
            .str());
  }
  auto tile_input_shape = input_ranked_type.getShape();

  llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());

  // The `multiples` operand is always set to have a replicated layout.
  input_layouts[1] = Layout::ReplicatedOnMesh(
      mesh,
      tile_op.multiples().getType().cast<mlir::RankedTensorType>().getRank());

  llvm::SmallVector<int64_t, 4> static_multiple;
  auto status =
      ExtractConstVectorFromValue(tile_op.multiples(), &static_multiple);

  // If the `multiples` operand cannot be statically known, the input is set to
  // replicated.
  if (!status.ok()) {
    input_layouts[0] = Layout::ReplicatedOnMesh(mesh, tile_input_shape.size());
    return input_layouts;
  }

  // When a suggested output layout exists, override the operand layout with
  // the consumer-suggested output layout for dimensions where `multiple` == 1
  // and the dimension size is evenly divisible by the sharding.
  if (output_layouts.find(0) != output_layouts.end()) {
    const Layout output_layout = output_layouts.lookup(0);
    std::vector<std::string> input_layout_specs;
    for (const auto& multiple_and_dim_sharding :
         llvm::zip(static_multiple, output_layout.sharding_specs())) {
      const int multiple = std::get<0>(multiple_and_dim_sharding);
      const auto& tensor_dimension = std::get<1>(multiple_and_dim_sharding);
      input_layout_specs.push_back(multiple == 1
                                       ? tensor_dimension.sharding_spec()
                                       : Layout::kUnshardedDim);
    }
    TF_ASSIGN_OR_RETURN(const Layout input_layout,
                        Layout::GetLayout(input_layout_specs, mesh));
    input_layouts[0] = input_layout;
  }
  return input_layouts;
}

namespace {

// From the input shape and output shape, extract a maximal list of segments
// where the product of the input shape from input_segment_start to
// input_segment_end equals the product of the output shape from
// output_segment_start to output_segment_end, and no shorter prefix of a
// segment has equal products.
// Note that dimensions of size 1 are skipped over if they would be at the
// start of a segment.
// Note that shapes with unknown dimension size (represented by -1) are
// unsupported.
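// For example (illustrative): input shape [2, 3, 4] and output shape [6, 4]
// produce two segments, pairing input dims [0, 2) with output dim [0, 1) and
// input dim [2, 3) with output dim [1, 2).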
void ComputeReshapeSegments(
    llvm::ArrayRef<int64_t> input_shape, llvm::ArrayRef<int64_t> output_shape,
    llvm::SmallVectorImpl<int64_t>& input_segment_start,
    llvm::SmallVectorImpl<int64_t>& input_segment_end,
    llvm::SmallVectorImpl<int64_t>& output_segment_start,
    llvm::SmallVectorImpl<int64_t>& output_segment_end) {
  int input_offset = 0;
  int output_offset = 0;

  while (input_offset < input_shape.size() &&
         output_offset < output_shape.size()) {
    while (input_offset < input_shape.size() && input_shape[input_offset] == 1)
      input_offset++;
    while (output_offset < output_shape.size() &&
           output_shape[output_offset] == 1)
      output_offset++;
    if (input_offset >= input_shape.size() ||
        output_offset >= output_shape.size()) {
      // Since the input and output tensors have the same number of entries, we
      // are guaranteed to reach the end of both shapes at the same time.
      assert(input_offset >= input_shape.size() &&
             output_offset >= output_shape.size());
      return;
    }

    input_segment_start.emplace_back(input_offset);
    output_segment_start.emplace_back(output_offset);

    int64 input_prod = input_shape[input_offset++];
    int64 output_prod = output_shape[output_offset++];
    while (input_prod != output_prod) {
      if (input_prod < output_prod)
        input_prod *= input_shape[input_offset++];
      else
        output_prod *= output_shape[output_offset++];
    }
    input_segment_end.emplace_back(input_offset);
    output_segment_end.emplace_back(output_offset);
  }
}

// For reshape we want to reduce the number of all-to-alls and slices needed.
// Note that the forward layout propagation for reshape uses the same
// algorithm as backwards propagation.
//
// Suppose we have input shape a_0,...,a_n and output shape b_0,...,b_k  such
// that a_0*...*a_i != b_0*...*b_j except with (i,j)=(n,k).
// The forward propagation of an input layout depends only on the size of axis
// 0 of the output shape and the mesh dimension axis 0 of the input is sharded
// on:
//
// 1. In any case we must all-to-all on any input axis from 1 to n and the
//    output layout from output axis 1 to k will always be replicated.
//
// 2. If input axis 0 is replicated, we do a local reshape and set the layout
//    of the output axis 0 to replicated.
//
// 3. If input axis 0 is sharded and the number of shards *does not divide*
//    b_0, then we must all-to-all on input axis 0 (as well as the axes
//    mentioned in 1), do a local reshape, and set the layout of output axis 0
//    to replicated.
//
// 4. If input axis 0 is sharded and the number of shards does divide b_0, we
//    can do a local reshape and set the layout of output axis 0 to the same
//    mesh dimension as the input layout axis 0.
//
// Finally, for a general input and output shape, if we partition the input
// and output shape into such segments, we can apply the above rule to each
// segment. The ComputeReshapeSegments function above computes the starting
// and ending index of each segment.
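// For example (illustrative, assuming mesh dimension "x" has 4 shards):
// reshaping an [8, 3] tensor with layout ["x", unsharded] to [4, 6] suggests
// the output layout ["x", unsharded] since 4 divides 4, while reshaping it to
// [2, 12] suggests [unsharded, unsharded] since 4 does not divide 2.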
StatusOr<Layout> MakeLayoutForReshape(
    const Layout& input_layout, const llvm::ArrayRef<int64_t> output_shape,
    llvm::SmallVectorImpl<int64_t>& input_segment_start,
    llvm::SmallVectorImpl<int64_t>& output_segment_start) {
  std::vector<std::string> layout_specs;
  layout_specs.reserve(output_shape.size());
  // Initially set the layout to be all replicated.
  for (int i = 0; i < output_shape.size(); ++i)
    layout_specs.push_back(Layout::kUnshardedDim);
  // Now process each segment. For each segment, if the number of shards on the
  // first entry of the input segment divides the output shape on the first
  // entry of the output segment, we request a sharded layout on that axis.
  for (int i = 0; i < input_segment_start.size(); ++i) {
    const int num_shards = input_layout.num_shards_for_dim(
        input_layout.dim(input_segment_start[i]));
    if (output_shape[output_segment_start[i]] % num_shards == 0)
      layout_specs[output_segment_start[i]] =
          input_layout.sharding_spec(input_segment_start[i]);
  }
  return Layout::GetLayout(layout_specs, input_layout.mesh());
}

}  // namespace

// TODO(b/171335075): Implement the SPMD for generic Reshape.
StatusOr<mlir::Operation*> ReshapeSPMDExpander::ExpandOp(mlir::Operation* op) {
  // Update input/output shape based on the sharding information.
  TF_ASSIGN_OR_RETURN(auto input_layout,
                      ExtractLayoutFromOperand(op->getOperand(0)));
  TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));

  if (!input_layout || !output_layout)
    return errors::InvalidArgument(
        "Input and output layouts of Reshape op must be known before SPMD "
        "expansion.");

  if (input_layout->IsFullyReplicated() && output_layout->IsFullyReplicated())
    return InferSPMDExpandedLocalShape(op);

  TF_ASSIGN_OR_RETURN(auto global_input_shape,
                      ExtractGlobalInputShape(op->getOpOperand(0)));
  TF_ASSIGN_OR_RETURN(auto global_output_shape,
                      ExtractGlobalOutputShape(op->getOpResult(0)));

  llvm::SmallVector<int64_t, 4> input_segment_start;
  llvm::SmallVector<int64_t, 4> input_segment_end;
  llvm::SmallVector<int64_t, 4> output_segment_start;
  llvm::SmallVector<int64_t, 4> output_segment_end;

  llvm::SmallVector<int64_t, 4> local_reshape_const;

  // Break up input and output shapes into segments which multiply to the same
  // number. We will treat each segment separately when constructing the input
  // shape from the output shape and vice versa.
  ComputeReshapeSegments(global_input_shape, global_output_shape,
                         input_segment_start, input_segment_end,
                         output_segment_start, output_segment_end);

  // Compute the shape for the local reshape op. For each input segment,
  // 1) Check the sharding status of all dimensions in that segment.
  // 2) Create entries in the output shape and layout for the segment.
  //
  // Also insert the necessary 1 dimensions between input and output segments.
  //
  // Currently the algorithm supports Reshape with limited cases.
  // - For example, reshape a [2, 16] shape tensor with layout ['not_sharded',
  //   'x'],
  //      - to a [2, 4, 4] shape tensor with layout ['not_sharded', 'x',
  //        'not_sharded'] does not need cross device data shuffling.
  //      - to a [2, 4, 4] shape tensor with layout ['not_sharded',
  //        'not_sharded', 'x'] needs cross device AllToAll on the input and
  //        a slice on the output.
  // - For trivial cases, which AllToAll can support, an AllToAll will be
  //   inserted. For example, reshape a [2, 4, 3] shape tensor with layout
  //   ['not_sharded', 'x', 'not_sharded'] to a [2, 12] shape tensor fully
  //   replicated can be supported.
  std::vector<ShardingSpec> tgt_input_layout(input_layout->rank());
  std::vector<ShardingSpec> tgt_output_layout(output_layout->rank());

  for (int i = 0; i < input_segment_start.size(); ++i) {
    const int input_start = input_segment_start[i];
    const int output_start = output_segment_start[i];
    const int prev_input_segment_end = (i == 0 ? 0 : input_segment_end[i - 1]);
    const int prev_output_segment_end =
        (i == 0 ? 0 : output_segment_end[i - 1]);

    // Between this segment and the previous segment, if there is a gap, insert
    // dimensions of size 1 and use kUnshardedDim as the layout dim.
    for (int j = prev_input_segment_end; j < input_start; ++j)
      tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim);
    for (int j = prev_output_segment_end; j < output_start; ++j) {
      local_reshape_const.emplace_back(1);
      tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim);
    }

    const int num_input_shards =
        input_layout->num_shards_for_dim(input_layout->dim(input_start));

    // Decide on the sharding of the input for this segment.
    // If the input is already sharded, we try to keep this sharding (unless
    // the size of the first output dimension is incompatible).
    // NOTE: If the input is unsharded in a dimension and the output is sharded,
    // we could 'preshard' the input on this dimension before the reshape.
    // This is unlikely to have any major gains in performance.
    if (global_output_shape[output_start] % num_input_shards != 0) {
      tgt_input_layout[input_start].set_sharding_spec(Layout::kUnshardedDim);
      tgt_output_layout[output_start].set_sharding_spec(Layout::kUnshardedDim);
      local_reshape_const.emplace_back(global_output_shape[output_start]);
    } else {
      tgt_input_layout[input_start] = input_layout->dim(input_start);
      tgt_output_layout[output_start] = input_layout->dim(input_start);
      local_reshape_const.emplace_back(global_output_shape[output_start] /
                                       num_input_shards);
    }

    for (int j = input_start + 1; j < input_segment_end[i]; ++j)
      tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim);
    for (int j = output_start + 1; j < output_segment_end[i]; ++j) {
      local_reshape_const.emplace_back(global_output_shape[j]);
      tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim);
    }
  }

  // Fill any remaining dimensions at the end with size 1 and an unsharded
  // layout dim.
  for (int j = input_segment_end.back(); j < tgt_input_layout.size(); ++j)
    tgt_input_layout[j].set_sharding_spec(Layout::kUnshardedDim);
  for (int j = output_segment_end.back(); j < tgt_output_layout.size(); ++j) {
    local_reshape_const.emplace_back(1);
    tgt_output_layout[j].set_sharding_spec(Layout::kUnshardedDim);
  }

  TF_ASSIGN_OR_RETURN(
      auto desired_input_layout,
      Layout::GetLayout(tgt_input_layout, input_layout->mesh()));
  TF_ASSIGN_OR_RETURN(
      auto desired_output_layout,
      Layout::GetLayout(tgt_output_layout, input_layout->mesh()));

  auto reshape_op = mlir::cast<mlir::TF::ReshapeOp>(op);
  TF_ASSIGN_OR_RETURN(
      mlir::Value new_input,
      EmitRelayout(reshape_op.tensor(), *input_layout, desired_input_layout));

  mlir::OpBuilder builder(op);

  // Update the reshape op to use the local shape as its shape operand.
  // Importantly, this updates the shape attr in the op, which
  // `InferSPMDExpandedLocalShape` does not do.
  auto new_shape = mlir::RankedTensorType::get(
      {static_cast<int64_t>(local_reshape_const.size())}, builder.getI64Type());
  auto const_attr =
      mlir::DenseIntElementsAttr::get(new_shape, local_reshape_const);
  auto new_reshape_const_op =
      builder.create<mlir::TF::ConstOp>(DT_LOC(op), const_attr);
  mlir::TF::ReshapeOp new_reshape_op = builder.create<mlir::TF::ReshapeOp>(
      op->getLoc(), new_input, new_reshape_const_op);

  TF_ASSIGN_OR_RETURN(auto final_output,
                      EmitRelayout(new_reshape_op.output(),
                                   desired_output_layout, *output_layout));

  op->getResult(0).replaceAllUsesWith(final_output);
  op->erase();
  return final_output.getDefiningOp();
}

StatusOr<llvm::DenseMap<int, Layout>> ReshapeSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  if (input_layouts.find(0) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  auto reshape_op = mlir::cast<mlir::TF::ReshapeOp>(op);
  TF_ASSIGN_OR_RETURN(
      auto input_shape,
      GetShapeOfValue(reshape_op.tensor(), /*fail_on_dynamic=*/true));
  TF_ASSIGN_OR_RETURN(
      auto output_shape,
      GetShapeOfValue(reshape_op.output(), /*fail_on_dynamic=*/true));

  llvm::SmallVector<int64_t, 4> input_segment_start;
  llvm::SmallVector<int64_t, 4> input_segment_end;
  llvm::SmallVector<int64_t, 4> output_segment_start;
  llvm::SmallVector<int64_t, 4> output_segment_end;

  // Break up input and output shapes into segments which multiply to the same
  // number. We will treat each segment separately when constructing the input
  // shape from the output shape and vice versa.
  ComputeReshapeSegments(input_shape, output_shape, input_segment_start,
                         input_segment_end, output_segment_start,
                         output_segment_end);

  TF_ASSIGN_OR_RETURN(
      const Layout output_layout,
      MakeLayoutForReshape(input_layouts.lookup(0), output_shape,
                           input_segment_start, output_segment_start));
  return llvm::DenseMap<int, Layout>({{0, output_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>>
ReshapeSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  if (output_layouts.find(0) == output_layouts.end())
    return llvm::DenseMap<int, Layout>();

  auto reshape_op = mlir::cast<mlir::TF::ReshapeOp>(op);
  TF_ASSIGN_OR_RETURN(
      auto input_shape,
      GetShapeOfValue(reshape_op.tensor(), /*fail_on_dynamic=*/true));
  TF_ASSIGN_OR_RETURN(
      auto output_shape,
      GetShapeOfValue(reshape_op.output(), /*fail_on_dynamic=*/true));

  llvm::SmallVector<int64_t, 4> input_segment_start;
  llvm::SmallVector<int64_t, 4> input_segment_end;
  llvm::SmallVector<int64_t, 4> output_segment_start;
  llvm::SmallVector<int64_t, 4> output_segment_end;

  // Break up input and output shapes into segments which multiply to the same
  // number. We will treat each segment separately when constructing the input
  // shape from the output shape and vice versa.
  ComputeReshapeSegments(input_shape, output_shape, input_segment_start,
                         input_segment_end, output_segment_start,
                         output_segment_end);

  TF_ASSIGN_OR_RETURN(
      const Layout input_layout,
      MakeLayoutForReshape(output_layouts.lookup(0), input_shape,
                           output_segment_start, input_segment_start));
  return llvm::DenseMap<int, Layout>({{0, input_layout}});
}

StatusOr<mlir::Operation*> TransposeSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  // Currently we only support transposes that do not shuffle data across
  // devices. When use cases arise, we can add support; we will need to figure
  // out the best strategy to keep the cost as low as possible. Until then,
  // add a check with a clear error message.
  {
    TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
    TF_ASSIGN_OR_RETURN(auto operand_layout,
                        ExtractLayoutFromOperand(op->getOperand(0)));

    if (!output_layout)
      return errors::InvalidArgument(
          "output layout of TransposeOp must be known before SPMD expansion.");
    if (!operand_layout)
      return errors::InvalidArgument(
          "operand layout of TransposeOp must be known before SPMD expansion.");

    auto transpose = mlir::cast<mlir::TF::TransposeOp>(op);
    llvm::SmallVector<int64, 4> perm;
    TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(transpose.perm(), &perm));

    for (const auto& p : llvm::enumerate(perm)) {
      if (operand_layout->dim(p.value()).sharding_spec() !=
          output_layout->dim(p.index()).sharding_spec()) {
        return errors::InvalidArgument(
            "TransposeOp SPMD that needs communication is not supported yet. "
            "\n operand layout: ",
            operand_layout->ToString(),
            "\n output layout: ", output_layout->ToString());
      }
    }
  }

  // Do nothing but infer local shape for now.
  return InferSPMDExpandedLocalShape(op);
}

StatusOr<llvm::DenseMap<int, Layout>>
TransposeSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  if (input_layouts.find(0) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  auto transpose = mlir::cast<mlir::TF::TransposeOp>(op);
  llvm::SmallVector<int64, 4> perm;
  TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(transpose.perm(), &perm));

  const Layout input_layout = input_layouts.lookup(0);
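  // Each output dimension i inherits the sharding of input dimension perm[i];
  // e.g. (illustrative) perm = [2, 0, 1] maps input layout ["x", "y",
  // unsharded] to output layout [unsharded, "x", "y"].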
  std::vector<std::string> output_layout_specs;
  for (int64 p : perm)
    output_layout_specs.push_back(input_layout.sharding_spec(p));

  TF_ASSIGN_OR_RETURN(
      const Layout output_layout,
      Layout::GetLayout(output_layout_specs, input_layout.mesh()));
  return llvm::DenseMap<int, Layout>({{0, output_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>>
TransposeSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  auto transpose = mlir::cast<mlir::TF::TransposeOp>(op);
  llvm::SmallVector<int64, 4> perm;
  TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(transpose.perm(), &perm));
  TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));

  llvm::DenseMap<int, Layout> input_layouts(transpose->getNumOperands());
  input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);

  if (output_layouts.find(0) != output_layouts.end()) {
    const Layout output_layout = output_layouts.lookup(0);

    llvm::SmallVector<int64, 4> inverse_perm(perm.size());
    for (const auto& p : llvm::enumerate(perm)) {
      inverse_perm[p.value()] = p.index();
    }

    std::vector<std::string> input_layout_specs;
    // For example, if perm is [2, 0, 1], then inverse_perm is [1, 2, 0].
    // So input dim i takes its sharding from output dim inverse_perm[i].
    for (auto dim_in_output : inverse_perm)
      input_layout_specs.push_back(output_layout.sharding_spec(dim_in_output));

    TF_ASSIGN_OR_RETURN(const Layout input_layout,
                        Layout::GetLayout(input_layout_specs, mesh));
    input_layouts[0] = input_layout;
  }

  return input_layouts;
}

namespace {

Status RelayoutOneHotInput(const absl::optional<Layout>& input_layout,
                           const absl::optional<Layout>& output_layout,
                           const int axis, mlir::TF::OneHotOp& one_hot) {
  if (!input_layout || !output_layout)
    return errors::InvalidArgument(
        "layout for tf.OneHot operation inputs and outputs must be known "
        "before SPMD expansion. Consider adding a Relayout() op to specify the "
        "layout.");

  std::vector<ShardingSpec> sharding_specs(input_layout->rank());
  for (int i = 0; i < input_layout->rank(); ++i) {
    if (i < axis)
      sharding_specs[i] = output_layout->dim(i);
    else
      sharding_specs[i] = output_layout->dim(i + 1);
  }
  TF_ASSIGN_OR_RETURN(const Layout new_input_layout,
                      Layout::GetLayout(sharding_specs, input_layout->mesh()));

  TF_ASSIGN_OR_RETURN(
      mlir::Value new_input,
      EmitRelayout(one_hot.indices(), *input_layout, new_input_layout));

  one_hot->setOperand(0, new_input);

  return OkStatus();
}

}  // namespace

StatusOr<mlir::Operation*> OneHotSPMDExpander::ExpandOp(mlir::Operation* op) {
  auto one_hot_op = llvm::cast<mlir::TF::OneHotOp>(op);

  mlir::OpBuilder builder(op);
  TF_ASSIGN_OR_RETURN(const auto input_layout,
                      ExtractLayoutFromOperand(one_hot_op->getOperand(0)));

  TF_ASSIGN_OR_RETURN(const auto output_layout,
                      ExtractSingleLayoutFromOp(one_hot_op));
  int axis = one_hot_op.axisAttr().getInt();
  if (axis == -1) axis = output_layout->rank() - 1;

  // For tf.OneHot, relayout the input so that it matches the output layout
  // (outside of the one-hot dimension).
  TF_RETURN_IF_ERROR(
      RelayoutOneHotInput(input_layout, output_layout, axis, one_hot_op));

  const int num_shards = output_layout->num_shards()[axis];
  const auto depth = ExtractConstIntFromValue(one_hot_op.depth());
  const bool depth_statically_divisible_by_sharding =
      (depth.ok() && (*depth) % num_shards == 0);

  // If the axis dimension of tf.OneHot is sharded and the number of shards
  // evenly divides the `depth` input of the one-hot operation, we can mutate
  // the `depth` and `indices` parameters to calculate the local tensor
  // directly.
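  // For example (illustrative): depth = 8 sharded 4 ways gives a local depth
  // of 2; the shard with mesh coordinate 3 shifts global indices 6 and 7 to
  // local indices 0 and 1, while out-of-range indices produce all-off rows.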
  const std::string& mesh_dim_name = output_layout->sharding_spec(axis);

  if (mesh_dim_name != Layout::kUnshardedDim) {
    if (!depth_statically_divisible_by_sharding)
      return errors::InvalidArgument(
          "OneHot axis dimension is sharded with incorrect layout. OneHot op "
          "depth should be evenly divisible by number of shards.");

    // Recalculate the new local depth. Namely: new_depth = depth / num_shards
    mlir::Value new_depth = CreateIntScalarConst((*depth) / num_shards, builder,
                                                 one_hot_op->getLoc(), false);

    // Calculate shard id at mesh dimension for the sharded axis.
    TF_ASSIGN_OR_RETURN(const Mesh mesh,
                        ExtractDeviceMeshEnclosingCluster(one_hot_op));
    mlir::tf_device::ClusterOp cluster =
        one_hot_op->getParentOfType<mlir::tf_device::ClusterOp>();

    // `mesh_coordinates` is a tensor of size [1, num mesh dimensions] where
    // each element in the tensor refers to the shard id for the specified mesh
    // dimension.
    TF_ASSIGN_OR_RETURN(mlir::Value mesh_coordinates,
                        GetMeshCoordinatesFromCluster(cluster));
    const int num_mesh_dimensions = output_layout->mesh().dims().size();
    llvm::SmallVector<int32_t, 4> multiplier(num_mesh_dimensions);
    const int mesh_dim_index =
        output_layout->mesh().GetMeshDimIndexWithName(mesh_dim_name);

    mlir::TF::SliceOp selected_sharding_at_dimension = builder.create<
        mlir::TF::SliceOp>(
        one_hot_op.getLoc(),
        mlir::RankedTensorType::get({1, 1}, mesh_coordinates.getType()
                                                .cast<mlir::TensorType>()
                                                .getElementType()),
        /*input=*/mesh_coordinates,
        /*begin=*/IntConst(builder, one_hot_op.getLoc(), {0, mesh_dim_index}),
        /*size=*/IntConst(builder, one_hot_op.getLoc(), {1, 1}));

    // Reshape the sliced (1, 1) tensor to a rank-0 scalar.
    auto scalar_size_type =
        mlir::RankedTensorType::get({}, builder.getIntegerType(32));
    mlir::Value scalar_shape = mlir::TF::collection_ops_util::GetR1Const(
        scalar_size_type.getShape(), builder, one_hot_op->getLoc());
    mlir::Value selected_sharding_scalar_value =
        builder.create<mlir::TF::ReshapeOp>(
            one_hot_op.getLoc(), mlir::ArrayRef<mlir::Type>{scalar_size_type},
            mlir::ArrayRef<mlir::Value>{selected_sharding_at_dimension.output(),
                                        scalar_shape},
            mlir::ArrayRef<mlir::NamedAttribute>{});

    // `new_indices` = `original_indices` - `selected_sharding_scalar_value` *
    // (depth / num_shards)
    mlir::Value id_offset = builder.create<mlir::TF::MulOp>(
        one_hot_op->getLoc(), new_depth, selected_sharding_scalar_value);
    mlir::Value original_indices = one_hot_op.indices();
    mlir::Value new_indices = builder.create<mlir::TF::SubOp>(
        one_hot_op->getLoc(), original_indices, id_offset);

    // Replace the one-hot operation inputs with the mutated `new_indices` and
    // `new_depth` tensors so that local tensors can be calculated directly
    // without materializing intermediate global tensors.
    one_hot_op->getOpOperand(0).set(new_indices);
    one_hot_op->getOpOperand(1).set(new_depth);
  }
  return InferSPMDExpandedLocalShape(one_hot_op);
}

StatusOr<llvm::DenseMap<int, Layout>> OneHotSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  if (input_layouts.find(0) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  auto one_hot = mlir::dyn_cast<mlir::TF::OneHotOp>(op);
  int axis = one_hot.axis();
  if (axis == -1) axis = ValueRank(one_hot.indices());
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  const Layout indices_layout = input_layouts.lookup(0);
  std::vector<std::string> output_layout_specs;
  for (int i = 0; i < indices_layout.rank(); ++i) {
    // Insert an unsharded dimension for the expanded one-hot axis.
    if (i == axis) {
      output_layout_specs.push_back(Layout::kUnshardedDim);
    }
    output_layout_specs.push_back(indices_layout.sharding_spec(i));
  }
  if (axis == indices_layout.rank() || axis == -1) {
    output_layout_specs.push_back(Layout::kUnshardedDim);
  }

  TF_ASSIGN_OR_RETURN(auto output_layout,
                      Layout::GetLayout(output_layout_specs, mesh));
  return llvm::DenseMap<int, Layout>({{0, output_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>> OneHotSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  auto one_hot = mlir::dyn_cast<mlir::TF::OneHotOp>(op);
  int axis = one_hot.axis();
  if (axis == -1) axis = ValueRank(one_hot.indices());
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  llvm::DenseMap<int, Layout> input_layouts(one_hot->getNumOperands());
  const auto scalar_replicated_layout =
      Layout::ReplicatedOnMesh(mesh, /*rank=*/0);
  input_layouts[1] = scalar_replicated_layout;  // depth
  input_layouts[2] = scalar_replicated_layout;  // on_value
  input_layouts[3] = scalar_replicated_layout;  // off_value

  // If output layout is specified, then propagate all dimensions (except axis
  // dimension) as operand layout.
  if (output_layouts.find(0) != output_layouts.end()) {
    const Layout output_layout = output_layouts.lookup(0);

    std::vector<std::string> indices_layout_specs;
    for (int i = 0; i < output_layout.rank(); ++i) {
      if (i == axis) continue;
      indices_layout_specs.push_back(output_layout.sharding_spec(i));
    }

    TF_ASSIGN_OR_RETURN(auto input_layout,
                        Layout::GetLayout(indices_layout_specs, mesh));
    input_layouts[0] = input_layout;
  }

  return input_layouts;
}

StatusOr<mlir::Operation*> ShapeSPMDExpander::ExpandOp(mlir::Operation* op) {
  TF_ASSIGN_OR_RETURN(auto result_layouts, ExtractLayoutFromOp(op));
  for (const auto& layout : result_layouts) {
    if (!layout.has_value())
      return errors::Internal(
          "All op result layouts must be specified for SPMD expansion.");

    if (!layout->IsFullyReplicated()) {
      return errors::Internal(
          "Shape/Rank ops must output value with replicated layout.");
    }
  }
  InferSPMDExpandedLocalShape(op);

  // DTensor shards always have full rank -- local rank == global rank.
  if (mlir::isa<mlir::TF::RankOp>(op)) return op;

  // We have a Shape/ShapeN op.

  // Find the enclosing device_cluster op and update its attributes if the
  // shape op result is returned from the cluster.
  auto enclosing_cluster = op->getParentOfType<mlir::tf_device::ClusterOp>();
  if (!enclosing_cluster)
    return errors::InvalidArgument(
        "Error during SPMD expansion of Shape op. Op must be enclosed in a "
        "device_cluster.");

  // Record output result index -> input_layout mapping.
  llvm::SmallVector<std::string, 4> input_layouts;
  std::vector<int> return_indices;

  // For each operand, extract the global shape if necessary. If a global shape
  // transformation is needed, and the transformed shape is returned to the
  // outside of the device cluster, also attach the input layout as additional
  // information so that downstream consumers can infer the local shape from
  // the result.
  llvm::SmallVector<mlir::TF::MulOp, 4> output_ops;
  for (int i = 0; i < op->getNumOperands(); ++i) {
    // Fetch the layout from the _input_, not the current op.
    TF_ASSIGN_OR_RETURN(auto input_layout,
                        ExtractLayoutFromOperand(op->getOperand(i)));
    if (!input_layout)
      return errors::InvalidArgument(
          "Input layout to shape op must be known before SPMD expansion.");

    // Fully replicated tensors: local shape = global shape.
    if (input_layout->IsFullyReplicated()) {
      continue;
    }

    // If a DTensor is sharded over a dimension, shards have equal size.
    // GlobalShape[Dim] = LocalShape[Dim] * NumShards[Dim]
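    // For example (illustrative), a global dimension of size 8 sharded 4 ways
    // has local size 2, so multiplying the local Shape result by the number of
    // shards recovers the global size of 8.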
    mlir::OpBuilder builder(op->getBlock(), ++mlir::Block::iterator(op));
    auto num_shards =
        IntConst(builder, op->getLoc(), input_layout->num_shards());
    auto global_shape = builder.create<mlir::TF::MulOp>(
        op->getLoc(), op->getResult(i), num_shards);

    op->getResult(i).replaceAllUsesExcept(
        global_shape.getResult(),
        llvm::SmallPtrSet<mlir::Operation*, 1>{global_shape});

    // Find the returned global shape and attach metadata information.
    for (auto& use : global_shape.getOperation()->getUses()) {
      if (use.getOwner() == enclosing_cluster.GetBody().getTerminator()) {
        input_layouts.emplace_back(input_layout->ToString());
        return_indices.emplace_back(use.getOperandNumber());
        break;
      }
    }
    output_ops.emplace_back(global_shape);
  }

  // Attach the original input layouts to the enclosing device_cluster op.
  if (!input_layouts.empty()) {
    mlir::OpBuilder builder(op);
    enclosing_cluster->setAttr(kShapeOpInputLayoutIndices,
                               builder.getI32VectorAttr(return_indices));
    enclosing_cluster->setAttr(
        kShapeOpInputLayout,
        builder.getStrArrayAttr(llvm::SmallVector<llvm::StringRef, 4>(
            input_layouts.begin(), input_layouts.end())));
  }

  // TODO(hthu): Support multiple returns in ShapeN op.
  return output_ops.empty() ? op : output_ops[0].getOperation();
}

StatusOr<llvm::DenseMap<int, Layout>> ShapeSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  assert(op->getNumResults() == 1);
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  const int output_rank = ValueRank(op->getResult(0));
  return llvm::DenseMap<int, Layout>(
      {{0, Layout::ReplicatedOnMesh(mesh, output_rank)}});
}

StatusOr<llvm::DenseMap<int, Layout>> ShapeSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return llvm::DenseMap<int, Layout>();
}

}  // namespace dtensor
}  // namespace tensorflow