/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

#include <algorithm>
#include <atomic>
#include <iterator>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

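// Computes the local (per-device) tensor type corresponding to
// |original_type| under |layout| by dividing each statically known dimension
// by the number of shards on that dimension. For example, with a layout of
// ("x", unsharded) on a mesh where "x" has 2 devices, a global shape of
// [8, 4] becomes a local shape of [4, 4]. Dynamic dimensions are left
// untouched, and an error is returned if a static dimension is not evenly
// divisible by its shard count.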
StatusOr<mlir::TensorType> LocalTypeFromGlobalType(
    const Layout& layout, const mlir::TensorType& original_type) {
  if (!original_type.hasRank()) {
    return original_type;
  }
  auto shape = llvm::to_vector<4>(original_type.getShape());
  auto shard_values = layout.num_shards();
  for (int output_axis = 0; output_axis < shape.size(); ++output_axis) {
    if (shape[output_axis] != mlir::ShapedType::kDynamicSize) {
      if (shape[output_axis] % shard_values[output_axis] != 0) {
        return errors::InvalidArgument(
            "The sharding spec for axis ", output_axis, " splits among ",
            shard_values[output_axis],
            " values, which does not evenly divide the length of that axis "
            "(",
            shape[output_axis], "). The full requested layout is ",
            layout.ToString(), ".");
      }
      shape[output_axis] /= shard_values[output_axis];
    }
  }
  mlir::RankedTensorType new_output_type =
      mlir::RankedTensorType::get(shape, original_type.getElementType());
  return new_output_type;
}

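// Inverse of LocalTypeFromGlobalType: recovers the global tensor type from a
// local (per-device) type by multiplying each statically known dimension by
// the number of shards on that dimension. Continuing the example above, a
// local shape of [4, 4] with layout ("x", unsharded) and 2 shards on "x" maps
// back to the global shape [8, 4].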
StatusOr<mlir::TensorType> GlobalTypeFromLocalType(
    const Layout& layout, const mlir::TensorType& original_type) {
  if (!original_type.hasRank()) {
    return original_type;
  }
  auto shape = llvm::to_vector<4>(original_type.getShape());
  auto shard_values = layout.num_shards();
  for (int output_axis = 0; output_axis < shape.size(); ++output_axis)
    if (shape[output_axis] != mlir::ShapedType::kDynamicSize)
      shape[output_axis] *= shard_values[output_axis];
  mlir::RankedTensorType new_output_type =
      mlir::RankedTensorType::get(shape, original_type.getElementType());
  return new_output_type;
}

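// Builds a tf.Split op that splits |src_input| into |num_split| equal slices
// along |split_dimension|, materializing the split dimension as a scalar
// tf.Const. When the input shape is statically known, the result types are
// refined accordingly (e.g. splitting a [8, 16] tensor 4 ways along dimension
// 0 yields four [2, 16] results); otherwise the input type is reused as-is.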
Status CreateSplitOp(const int num_split, const int split_dimension,
                     const mlir::Location location, mlir::Value src_input,
                     mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold the split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Correctly set the output shapes of the split op if the input shape is
  // statically known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return errors::InvalidArgument(
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split)
                .str());
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  return OkStatus();
}

// Given layouts + shapes, determines if the two are broadcasting compatible.
// When broadcasting we effectively line up the shapes and layouts by the end.
// The input with lower rank can be thought of as having abs(rank_a-rank_b)
// replicated dims of size 1 prepended to it.
//
// Returns the broadcast layout and the splits in the two inputs needed to run
// an elementwise op efficiently.
//
// Checks that a given mesh dimension is not used in different tensor dimensions
// in the two input layouts.
// E.g. a layout like (unsharded,x,unsharded) is not compatible with
// (unsharded,x) or (x,unsharded,unsharded) but is compatible with
// (x,unsharded), (unsharded,unsharded) or (unsharded,x,unsharded).
// (Note that due to broadcasting, we compare the dimensions from the end).
//
// If dims_to_ignore is > 0, then we ignore when a mesh dimension is used in
// different tensor dimensions when those dimensions are both in the last
// dims_to_ignore tensor dimensions of each input.
// E.g. if dims_to_ignore = 2, then (unsharded,x,unsharded) is now compatible
// with (unsharded,x) and is not compatible with (x,unsharded,unsharded).
//
// The output layout will be of rank max(layout_a.rank(), layout_b.rank()) -
// dims_to_ignore and will be replicated on a dimension if either one of the
// input layouts is replicated on that dimension. Once again, recall that due
// to broadcasting, layouts are aligned by their ends and not their beginnings.
// E.g. if dims_to_ignore is zero, the output layout for the inputs
// (unsharded,x,unsharded) and (unsharded,y) is (unsharded,x,y).
// If dims_to_ignore is two, the output for (y,x,unsharded) and
// (unsharded,x) is just (y).
//
// In the case that one tensor is sharded and the other is not on a given
// dimension, elementwise operations *may* need to split the unsharded tensor
// along the same mesh dimension that the other input is split on. Note that
// the split is *not* needed if the unsharded tensor has a dimension of size 1,
// due to broadcasting.
//
// To help with the needed splits, the vectors to_split_* are resized to the
// rank of each input, and if a dimension of the tensor needs to be split for
// an elementwise op, we record the mesh dimension it should be split along in
// the vector.
// E.g. in the case of input layouts (unsharded,x,unsharded) and
// (unsharded,unsharded) with dimensions (10,10,10) and (10,10),
// to_split_a = {"unsharded", "unsharded", "unsharded"} and to_split_b =
// {"x", "unsharded"}.
// If the shapes were (10,10,10) and (1,10), then to_split_a = {"unsharded",
// "unsharded", "unsharded"} and to_split_b = {"unsharded", "unsharded"}.
//
// Note that "unsharded" == Layout::kUnshardedDim.
// NOTE: shape_a and shape_b are *global* shapes.
StatusOr<Layout> GetBroadcastLayoutForElementWise(
    const Layout& layout_a, const Layout& layout_b,
    mlir::ArrayRef<int64_t> shape_a, mlir::ArrayRef<int64_t> shape_b,
    int64_t dims_to_ignore, std::vector<std::string>& to_split_a,
    std::vector<std::string>& to_split_b) {
  if (layout_a.mesh() != layout_b.mesh())
    return errors::InvalidArgument(
        "layout_a and layout_b cannot be broadcast as they are on different "
        "meshes.");

  const int rank_a = layout_a.rank();
  const int rank_b = layout_b.rank();
  const int rank_offset_a = std::max(0, rank_b - rank_a);
  const int rank_offset_b = std::max(0, rank_a - rank_b);
  absl::flat_hash_map<std::string, int> mesh_dim_map_a;
  absl::flat_hash_map<std::string, int> mesh_dim_map_b;
  std::vector<string> output_layout_specs;

  auto unsharded_specs = [](const int new_size) -> std::vector<std::string> {
    std::vector<std::string> spec_strs(new_size, Layout::kUnshardedDim);
    return spec_strs;
  };

  to_split_a = unsharded_specs(rank_a - dims_to_ignore);
  to_split_b = unsharded_specs(rank_b - dims_to_ignore);

  // Note that we record the sharding over all dimensions, even the ones we
  // ignore. We will check that a non-ignored dimension of a tensor does not
  // use a mesh dimension that is used by an ignored dimension in the other
  // tensor.
  for (int i = 0; i < rank_a; ++i)
    if (!Layout::IsUnshardedDimension(layout_a.sharding_spec(i)))
      mesh_dim_map_a[layout_a.sharding_spec(i)] = i;
  for (int i = 0; i < rank_b; ++i)
    if (!Layout::IsUnshardedDimension(layout_b.sharding_spec(i)))
      mesh_dim_map_b[layout_b.sharding_spec(i)] = i;

  for (int i = 0; i < std::max(rank_a, rank_b) - dims_to_ignore; ++i) {
    const int dim_a = i - rank_offset_a;
    const int dim_b = i - rank_offset_b;
    // When ranks are not equal we treat the first rank_offset_* dims of the
    // shorter layout as not sharded.
    const std::string mesh_dim_a =
        dim_a >= 0 ? layout_a.sharding_spec(dim_a) : Layout::kUnshardedDim;
    const std::string mesh_dim_b =
        dim_b >= 0 ? layout_b.sharding_spec(dim_b) : Layout::kUnshardedDim;
    // When ranks are not equal, we treat the first rank_offset_* dims of the
    // shorter shape as if they were 1.
    const int64_t tensor_dim_a = dim_a >= 0 ? shape_a[dim_a] : 1;
    const int64_t tensor_dim_b = dim_b >= 0 ? shape_b[dim_b] : 1;

    // Check for conflicting dimensions, i.e. a mesh dimension used by
    // different tensor dimensions in the two layouts. If a conflict is found,
    // the merged result is unsharded on this dimension.
    bool have_conflicted_dim = false;
    if (!Layout::IsUnshardedDimension(mesh_dim_a) &&
        mesh_dim_map_b.contains(mesh_dim_a) &&
        mesh_dim_map_b[mesh_dim_a] != dim_b)
      have_conflicted_dim = true;

    if (!Layout::IsUnshardedDimension(mesh_dim_b) &&
        mesh_dim_map_a.contains(mesh_dim_b) &&
        mesh_dim_map_a[mesh_dim_b] != dim_a)
      have_conflicted_dim = true;

    // If both dimensions are sharded, we have already verified that they are
    // sharded on the same mesh dim.
    if (have_conflicted_dim) {
      output_layout_specs.emplace_back(Layout::kUnshardedDim);
    } else {
      output_layout_specs.emplace_back(
          Layout::IsUnshardedDimension(mesh_dim_a) ? mesh_dim_b : mesh_dim_a);
    }
    if (dim_a >= 0 && tensor_dim_a > 1 &&
        Layout::IsUnshardedDimension(mesh_dim_a) &&
        !Layout::IsUnshardedDimension(mesh_dim_b)) {
      to_split_a[dim_a] = mesh_dim_b;
    }
    if (dim_b >= 0 && tensor_dim_b > 1 &&
        Layout::IsUnshardedDimension(mesh_dim_b) &&
        !Layout::IsUnshardedDimension(mesh_dim_a)) {
      to_split_b[dim_b] = mesh_dim_a;
    }
  }
  return Layout::GetLayout(output_layout_specs, layout_a.mesh());
}

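// Merges the layouts of an op's operands into a single broadcast-compatible
// layout by repeatedly applying GetBroadcastLayoutForElementWise over the
// operands whose layouts are known. Returns absl::nullopt when no operand
// layout is known.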
StatusOr<absl::optional<Layout>> GetMergedOperandLayout(
    const llvm::DenseMap<int, Layout>& operand_layouts, mlir::Operation* op) {
  // List of layouts, together with the shape of the operand they belong to,
  // for the operands whose layout is defined (i.e. the layout is not
  // absl::nullopt).
  llvm::SmallVector<std::pair<const Layout&, llvm::ArrayRef<int64_t>>, 4>
      filtered_preferred_operand_layouts;
  filtered_preferred_operand_layouts.reserve(op->getNumOperands());

  for (const auto& index_and_layout : operand_layouts) {
    TF_ASSIGN_OR_RETURN(
        llvm::ArrayRef<int64_t> shape_to_merge,
        GetShapeOfValue(op->getOperand(index_and_layout.first)));
    filtered_preferred_operand_layouts.emplace_back(index_and_layout.second,
                                                    shape_to_merge);
  }

  if (filtered_preferred_operand_layouts.empty())
    return absl::optional<Layout>();

  // Merge all operands and their layouts into a single broadcast layout.
  Layout merged_operand_layout = filtered_preferred_operand_layouts[0].first;
  llvm::ArrayRef<int64_t> merged_shape =
      filtered_preferred_operand_layouts[0].second;

  // Statically analyze the merged input operand layouts. Broadcasting is
  // allowed, but no cross-device communication should be incurred.
  for (int i = 1; i < filtered_preferred_operand_layouts.size(); ++i) {
    const auto& operand_index_and_layout_to_merge =
        filtered_preferred_operand_layouts[i];
    const Layout& layout_to_merge = operand_index_and_layout_to_merge.first;
    llvm::ArrayRef<int64_t> shape_to_merge =
        operand_index_and_layout_to_merge.second;

    std::vector<std::string> left_splits;
    std::vector<std::string> right_splits;
    TF_ASSIGN_OR_RETURN(merged_operand_layout,
                        GetBroadcastLayoutForElementWise(
                            merged_operand_layout, layout_to_merge,
                            merged_shape, shape_to_merge,
                            /*dims_to_ignore=*/0, left_splits, right_splits));
  }
  return absl::optional<Layout>(merged_operand_layout);
}

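// If |value| is produced by a tf.DTensorLayout op, returns that op's input;
// otherwise returns |value| unchanged. This lets callers look through layout
// annotations when searching for constants or producers.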
mlir::Value GetForwardedDTensorLayoutInput(mlir::Value value) {
  auto layout_op =
      llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(value.getDefiningOp());
  if (!layout_op) return value;

  return layout_op.input();
}

// Takes an operand and traces its use across function call and
// tf_device.cluster boundaries. Note that this may turn one operand into many.
// TODO(bfontain): Assumes that a function is only called once. This is checked
// when creating func_to_caller.
llvm::SmallVector<mlir::OpOperand*, 4> TraceUseToNextTFOp(
    mlir::OpOperand* operand,
    const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
    llvm::SmallVector<mlir::Value, 4>* skipped_values) {
  mlir::Operation* owner = operand->getOwner();
  llvm::SmallVector<mlir::Value, 4> values;
  if (mlir::isa<mlir::TF::PartitionedCallOp>(owner) ||
      mlir::isa<mlir::TF::StatefulPartitionedCallOp>(owner)) {
    mlir::func::FuncOp func;
    if (mlir::isa<mlir::TF::PartitionedCallOp>(owner))
      func = mlir::cast<mlir::TF::PartitionedCallOp>(owner).func();
    else
      func = mlir::cast<mlir::TF::StatefulPartitionedCallOp>(owner).func();
    values.emplace_back(func.getArgument(operand->getOperandNumber()));
  } else if (mlir::isa<mlir::tf_device::ReturnOp>(owner)) {
    auto device_return = mlir::cast<mlir::tf_device::ReturnOp>(owner);
    auto enclosing_cluster =
        device_return->getParentOfType<mlir::tf_device::ClusterOp>();
    values.emplace_back(
        enclosing_cluster.getResult(operand->getOperandNumber()));
  } else if (mlir::isa<mlir::func::ReturnOp>(owner)) {
    auto func = mlir::cast<mlir::func::ReturnOp>(owner)
                    ->getParentOfType<mlir::func::FuncOp>();
    // The one function we don't have a caller for is the main function.
    // In this case return the empty list as there are no consumers.
    auto caller = func_to_caller.find(func.getName());
    if (caller != func_to_caller.end())
      values.emplace_back(
          caller->second->getOpResult(operand->getOperandNumber()));
  } else if (auto yield = mlir::dyn_cast<mlir::TF::YieldOp>(owner)) {
    if (auto if_op = owner->getParentOfType<mlir::TF::IfRegionOp>()) {
      values.emplace_back(if_op.getResult(operand->getOperandNumber()));
    } else if (auto while_op =
                   owner->getParentOfType<mlir::TF::WhileRegionOp>()) {
      if (while_op && !while_op.cond().isAncestor(yield->getParentRegion()))
        values.emplace_back(while_op.getResult(operand->getOperandNumber()));
    } else {
      LOG(WARNING)
          << "Found terminator op for unsupported controlflow operations.";
    }
  } else if (mlir::isa<mlir::TF::DTensorLayout>(owner)) {
    auto dtensor_layout = mlir::cast<mlir::TF::DTensorLayout>(owner);
    values.emplace_back(dtensor_layout.output());
  } else if (auto while_op = mlir::dyn_cast<mlir::TF::WhileRegionOp>(owner)) {
    // Handle loop variant inputs of while op.
    mlir::Region& cond = while_op.cond();
    mlir::Region& body = while_op.body();
    const int operand_index = operand->getOperandNumber();
    values.emplace_back(cond.front().getArgument(operand_index));
    values.emplace_back(body.front().getArgument(operand_index));
  } else {
    return {operand};
  }
  llvm::SmallVector<mlir::OpOperand*, 4> ret;
  for (mlir::Value value : values) {
    if (skipped_values != nullptr) skipped_values->emplace_back(value);
    for (mlir::OpOperand& use : value.getUses()) {
      // TODO(bfontain): Remove recursion here.
      const auto& traced_operands =
          TraceUseToNextTFOp(&use, func_to_caller, skipped_values);
      ret.append(traced_operands.begin(), traced_operands.end());
    }
  }

  return ret;
}

mlir::LogicalResult GetFuncToCaller(
    mlir::ModuleOp module,
    llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller) {
  // For now this is a 1:1 mapping and we will error out if a function is
  // called by more than one op. The layout code assumes there is a 1:many
  // relationship between producers and consumers. If we allow a function to
  // be called multiple times, then its consumers consume from multiple
  // producers, which breaks this assumption.
  // TODO(bfontain): Fix this, possibly by duplicating all functions in order to
  // make this mapping 1:1 in truth.
  auto result = module->walk([&](mlir::Operation* op) -> mlir::WalkResult {
    mlir::StringRef func;
    if (mlir::TF::PartitionedCallOp call_op =
            mlir::dyn_cast<mlir::TF::PartitionedCallOp>(op))
      func = call_op.func().getName();
    else if (mlir::TF::StatefulPartitionedCallOp call_op =
                 mlir::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op))
      func = call_op.func().getName();
    else
      return mlir::WalkResult::advance();
    if (func_to_caller.find(func) != func_to_caller.end())
      return op->emitOpError()
             << "multiple calls to " << func << " found.";
    func_to_caller[func] = op;
    return mlir::WalkResult::advance();
  });
  return mlir::failure(result.wasInterrupted());
}

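// Populates |consumers| with, for every result value of a TF op in |module|
// (and every argument of the main function), the list of OpOperands that
// consume it, tracing uses through function calls, tf_device.cluster
// boundaries, control flow regions and tf.DTensorLayout ops via
// TraceUseToNextTFOp.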
mlir::LogicalResult PopulateConsumersFromModule(
    mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers) {
  mlir::func::FuncOp main_func =
      module->lookupSymbol<mlir::func::FuncOp>("main");
  llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;

  if (mlir::failed(GetFuncToCaller(*module, func_to_caller)))
    return mlir::failure();

  module->walk([&](mlir::Operation* op) {
    if (op->getDialect() != tf_dialect) return;

    if (mlir::isa<mlir::TF::PartitionedCallOp>(op) ||
        mlir::isa<mlir::TF::StatefulPartitionedCallOp>(op) ||
        mlir::isa<mlir::TF::WhileRegionOp>(op) ||
        mlir::isa<mlir::TF::IfRegionOp>(op) ||
        mlir::isa<mlir::TF::DTensorLayout>(op))
      return;

    for (const auto& value : op->getOpResults()) {
      // Call clear so that the value is in consumers (with an empty vector)
      // even if there are no uses. This should only happen for ops whose
      // outputs are returned directly from main, e.g. eagerly executed ops.
      consumers[value].clear();
      for (auto& operand : value.getUses())
        for (auto& traced_operand :
             TraceUseToNextTFOp(&operand, func_to_caller))
          consumers[value].emplace_back(traced_operand);
    }
  });

  // Note that we need to add in the inputs from the main function (otherwise
  // we won't have any layouts to propagate!).
  for (auto& value : main_func.getArguments())
    for (auto& operand : value.getUses())
      for (auto* traced_operand : TraceUseToNextTFOp(&operand, func_to_caller))
        consumers[value].emplace_back(traced_operand);
  return mlir::success();
}

// Compute the mesh coordinates from a device id + the current cluster.
//
// If the mesh shape is [a, b, c, d], then the mesh coordinates are
// [device_id/b/c/d, device_id/c/d%b, device_id/d%c, device_id%d].
// For convenience, since device_id < a*b*c*d, we can apply %a on the first
// coordinate as well for simplicity's sake.
// Thus we can decompose this calculation into the following TF ops:
// tf.FloorMod(tf.Div(device_id, [b*c*d, c*d, d, 1]), [a, b, c, d]) where
// [a, b, c, d] and [b*c*d, c*d, d, 1] are simply precomputed constants.
//
// Note that this returns a tensor of shape [1, mesh.rank()], suitable for
// using with MatMul.
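//
// As a concrete example, for a mesh of shape [2, 3, 4] the precomputed
// divisors are [3*4, 4, 1] = [12, 4, 1]. For device_id = 17 this gives
// tf.Div(17, [12, 4, 1]) = [1, 4, 17] and
// tf.FloorMod([1, 4, 17], [2, 3, 4]) = [1, 1, 1], i.e. mesh coordinates
// (1, 1, 1), which indeed satisfy 1*12 + 1*4 + 1 = 17.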
StatusOr<mlir::Value> GetMeshCoordinatesFromCluster(
    mlir::tf_device::ClusterOp cluster) {
  // First try to find a FloorMod op with a kMeshCoordinatesAttr attribute that
  // has the given mesh in it. If it exists, simply return that op's value.
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshFromOp(cluster));
  if (!mesh) return errors::InvalidArgument("missing mesh on cluster");
  string serialized_mesh = mesh->ToString();
  mlir::Value ret_val;
  auto result = cluster.walk([&](mlir::TF::FloorModOp op) -> mlir::WalkResult {
    if (op->hasAttrOfType<mlir::StringAttr>(kMeshCoordinatesAttr) &&
        op->getAttrOfType<mlir::StringAttr>(kMeshCoordinatesAttr)
                .getValue()
                .str() == serialized_mesh) {
      ret_val = op.z();
      return mlir::WalkResult::interrupt();
    }
    return mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) return ret_val;

  // We didn't find a FloorModOp for the given mesh, so we must produce the
  // FloorModOp and add the attr so we can find it on the next call.
  std::vector<int32> mesh_shape(mesh->rank());
  for (int i = 0; i < mesh->rank(); ++i) mesh_shape[i] = mesh->dim(i).size;

  // This product represents the [b*c*d, c*d, d, 1] from the function
  // documentation.
  std::vector<int32> running_product(mesh->rank());
  running_product[mesh->rank() - 1] = 1;
  for (int i = mesh->rank() - 1; i > 0; --i)
    running_product[i - 1] = running_product[i] * mesh_shape[i];

  mlir::OpBuilder builder(cluster.getContext());
  builder.setInsertionPointToStart(&cluster.GetBody());

  auto mesh_shape_type = mlir::RankedTensorType::get(
      {1, mesh->rank()}, builder.getIntegerType(32));
  mlir::Attribute mesh_shape_attr =
      mlir::DenseIntElementsAttr::get(mesh_shape_type, mesh_shape);
  auto mesh_shape_value =
      builder.create<mlir::TF::ConstOp>(cluster.getLoc(), mesh_shape_attr)
          .getResult();

  auto running_product_value =
      IntConst(builder, cluster.getLoc(), running_product);

  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(cluster));

  auto div_op = builder.create<mlir::TF::DivOp>(cluster.getLoc(), device_id,
                                                running_product_value);

  auto mod_op = builder.create<mlir::TF::FloorModOp>(
      cluster.getLoc(), div_op.z(), mesh_shape_value);

  mod_op->setAttr(kMeshCoordinatesAttr, builder.getStringAttr(serialized_mesh));
  return mod_op.z();
}

mlir::LogicalResult ValidateMetadataAttributes(mlir::Operation* op) {
  // If the cluster function has attributes containing the inferred layout of
  // resource handle arguments, then validate that the attributes are
  // consistent before they are added to the newly created
  // StatefulPartitionedCallOp.
  auto inferred_resource_handle_indices =
      op->getAttrOfType<mlir::DenseIntElementsAttr>(kNewResourceLayoutIndices);
  auto inferred_resource_handle_layouts =
      op->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
  if (inferred_resource_handle_indices || inferred_resource_handle_layouts) {
    if (!inferred_resource_handle_indices ||
        !inferred_resource_handle_layouts ||
        inferred_resource_handle_indices.getNumElements() !=
            inferred_resource_handle_layouts.size())
      return op->emitOpError(
                 "inferred layout args don't match. indices size: ")
             << (inferred_resource_handle_indices
                     ? inferred_resource_handle_indices.getNumElements()
                     : 0)
             << ", layouts size : "
             << (inferred_resource_handle_layouts
                     ? inferred_resource_handle_layouts.size()
                     : 0);
  }

  auto shape_layouts = op->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
  auto shape_op_indices =
      op->getAttrOfType<mlir::DenseIntElementsAttr>(kShapeOpInputLayoutIndices);
  if (shape_op_indices || shape_layouts) {
    if (!shape_op_indices || !shape_layouts ||
        shape_op_indices.getNumElements() != shape_layouts.size())
      return op->emitOpError("shape layout args don't match. indices size: ")
             << (shape_op_indices ? shape_op_indices.getNumElements() : 0)
             << ", layouts size : "
             << (shape_layouts ? shape_layouts.size() : 0);
  }
  return mlir::success();
}

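// Rewrites |cluster| so that only results with at least one use remain: a new
// tf_device.cluster is created with the reduced result list, the original body
// is spliced into it, uses of the surviving results are redirected to the new
// cluster, and the original cluster is erased. If every result is used, the
// cluster is left untouched.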
void RemoveUnusedClusterResults(mlir::tf_device::ClusterOp cluster) {
  llvm::SmallVector<mlir::OpResult, 4> new_result_values;
  llvm::SmallVector<mlir::Value, 4> result_producing_values;
  new_result_values.reserve(cluster->getNumResults());
  result_producing_values.reserve(cluster->getNumResults());
  for (mlir::OpResult result : cluster.results()) {
    if (!result.use_empty()) {
      new_result_values.emplace_back(result);
      result_producing_values.emplace_back(
          cluster.GetBody().getTerminator()->getOperand(
              result.getResultNumber()));
    }
  }

  if (new_result_values.size() == cluster.getNumResults()) return;

  llvm::SmallVector<mlir::Type, 4> new_result_types;
  llvm::transform(new_result_values, std::back_inserter(new_result_types),
                  [](mlir::Value v) { return v.getType(); });

  mlir::OpBuilder builder(cluster);
  auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
      cluster.getLoc(), new_result_types);
  new_cluster->setAttr(kMeshAttr,
                       cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr));
  new_cluster.body().push_back(new mlir::Block);

  auto& cluster_body = cluster.GetBody().getOperations();
  new_cluster.GetBody().getOperations().splice(
      new_cluster.GetBody().end(), cluster_body, cluster_body.begin(),
      std::prev(cluster_body.end()));

  builder.setInsertionPointToEnd(&new_cluster.GetBody());
  builder.create<mlir::tf_device::ReturnOp>(cluster.getLoc(),
                                            result_producing_values);

  assert(new_cluster.getNumResults() == new_result_values.size());
  for (auto it : llvm::zip(new_result_values, new_cluster.results())) {
    mlir::Value value_to_replace = std::get<0>(it);
    mlir::Value new_result = std::get<1>(it);
    value_to_replace.replaceAllUsesWith(new_result);
  }
  cluster.erase();
}

namespace {

// Keeps track of the number of functions added to the global graph for adding
// control flow. When converting regional control flow to functional control
// flow ops, function names may collide if non-unique branch function names are
// used. In order to ensure that all branch functions of TF control flow ops
// are unique, we keep an atomic counter for control flow functions.
// See b/174253694 for more details.
std::atomic<int32> dtensor_controlflow_function_counter{0};

}  // namespace

mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix,
                                            mlir::OpBuilder& builder) {
  int32 unique_id = dtensor_controlflow_function_counter++;
  return builder.getStringAttr(
      absl::StrCat(prefix, "_dtensor_function_", unique_id));
}

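// Positions |builder| so that newly created ops are dominated by |value|: for
// an op result the insertion point is set directly after the defining op; for
// a block argument it is set to the start of the (unique) tf_device.cluster
// that uses the value. Returns an error if the block argument is used in
// multiple clusters or in none.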
Status SetBuilderInsertionAfterValue(mlir::Value value,
                                     mlir::OpBuilder& builder) {
  if (value.isa<mlir::OpResult>()) {
    builder.setInsertionPointAfterValue(value);
    return OkStatus();
  }
  mlir::tf_device::ClusterOp cluster;
  for (mlir::Operation* op : value.getUsers()) {
    mlir::tf_device::ClusterOp new_cluster =
        op->getParentOfType<mlir::tf_device::ClusterOp>();
    if (!new_cluster) continue;
    if (!cluster) cluster = new_cluster;
    if (cluster != new_cluster)
      return errors::Internal("value has multiple uses in different clusters");
  }
  if (!cluster) return errors::Internal("value not used in any cluster");

  builder.setInsertionPointToStart(cluster.getBody());
  return OkStatus();
}

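// Debugging helper: emits a tf.StringFormat + tf.PrintV2 pair right after
// |value| that logs the tensor prefixed with the device id, i.e.
// "Core <device_id>: <value>". |format_string| controls how the tensor itself
// is formatted and defaults to "%s".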
Status PrintTensor(mlir::Value value, const std::string& format_string = "%s") {
  mlir::OpBuilder builder(value.getContext());
  builder.setInsertionPointAfterValue(value);
  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(value));
  std::string all_format = absl::StrCat("Core %s: ", format_string);
  // Scalar string type
  mlir::RankedTensorType scalar_string =
      mlir::RankedTensorType::get({}, builder.getType<mlir::TF::StringType>());
  mlir::TF::StringFormatOp format = builder.create<mlir::TF::StringFormatOp>(
      value.getLoc(), scalar_string, mlir::ValueRange({device_id, value}));
  format->setAttr("template", builder.getStringAttr(all_format));
  builder.create<mlir::TF::PrintV2Op>(value.getLoc(), format.output(),
                                      /*output_stream=*/"log(info)",
                                      /*end=*/"\n");
  return OkStatus();
}

Status ExtractConstStringVectorFromValue(
    mlir::Value value, llvm::SmallVectorImpl<std::string>& out_vector) {
  value = GetForwardedDTensorLayoutInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal(
        "Unable to get constant value from block argument.");
  mlir::DenseStringElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(
        llvm::formatv("failed to extract constant string vector from : {0}",
                      value)
            .str());
  }
  for (const auto& str : attr.getRawStringData()) {
    out_vector.push_back(str.str());
  }
  return OkStatus();
}

StatusOr<std::string> ExtractConstScalarStringFromValue(mlir::Value value) {
  value = GetForwardedDTensorLayoutInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal(
        "Unable to get constant value from block argument.");
  mlir::DenseStringElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(absl::StrCat("required constant value for ",
                                         OpName(value.getDefiningOp())));
  }
  if (attr.size() != 1) {
    return errors::Internal(absl::StrCat("expected 1 element, got ",
                                         attr.size(), " for ",
                                         OpName(value.getDefiningOp())));
  }
  return std::string(*attr.getRawStringData().begin());
}

TopologicalIterator::TopologicalIterator(mlir::func::FuncOp main_func)
    : ops_to_visit_{&main_func.front().front()} {
  funcs_visited_.insert(main_func.getName());
  funcs_visited_in_call_stack_.insert(main_func.getName());
}

mlir::Operation* TopologicalIterator::next() {
  if (!hasNext()) return nullptr;

  auto* op = ops_to_visit_.pop_back_val();
  auto* next_op = op->getNextNode();
  if (next_op) ops_to_visit_.push_back(next_op);

  // If this is a function call op, push the first op of the function body so
  // that the function body is converted before the call site.
  absl::optional<mlir::func::FuncOp> func = MaybeFindFunction(op);
  if (func.has_value()) {
    mlir::StringRef func_name = func->getName();

    if (funcs_visited_.contains(func_name)) return next();

    ops_to_visit_.push_back(&(func->front().front()));
    funcs_visited_.insert(func_name);
  }

  // If we have reached the end of a function body, remove the function from
  // our active set.
  if (!next_op && !funcs_visited_in_call_stack_.empty())
    if (auto func = op->getParentOfType<mlir::func::FuncOp>())
      funcs_visited_in_call_stack_.erase(func.getName());

  if (auto cluster_op = mlir::dyn_cast<mlir::tf_device::ClusterOp>(op))
    ops_to_visit_.push_back(&cluster_op.GetBody().front());

  if (auto while_op = mlir::dyn_cast<mlir::TF::WhileRegionOp>(op)) {
    ops_to_visit_.push_back(&while_op.cond().front().front());
    ops_to_visit_.push_back(&while_op.body().front().front());
  }

  if (auto if_op = mlir::dyn_cast<mlir::TF::IfRegionOp>(op)) {
    ops_to_visit_.push_back(&if_op.then_branch().front().front());
    ops_to_visit_.push_back(&if_op.else_branch().front().front());
  }
  return op;
}

bool TopologicalIterator::hasNext() { return !ops_to_visit_.empty(); }

}  // namespace dtensor
}  // namespace tensorflow