/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

#include <algorithm>
#include <atomic>
#include <iterator>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

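// Computes the local (per-shard) tensor type corresponding to the global
// `original_type` under `layout`. Unranked types are returned unchanged; every
// statically known dimension must be evenly divisible by the number of shards
// on that axis.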
StatusOr<mlir::TensorType> LocalTypeFromGlobalType(
    const Layout& layout, const mlir::TensorType& original_type) {
  if (!original_type.hasRank()) {
    return original_type;
  }
  auto shape = llvm::to_vector<4>(original_type.getShape());
  auto shard_values = layout.num_shards();
  for (int output_axis = 0; output_axis < shape.size(); ++output_axis) {
    if (shape[output_axis] != mlir::ShapedType::kDynamicSize) {
      if (shape[output_axis] % shard_values[output_axis] != 0) {
        return errors::InvalidArgument(
            "The sharding spec for axis ", output_axis, " splits among ",
            shard_values[output_axis],
            " values, which does not evenly divide the length of that axis "
            "(",
            shape[output_axis], "). The full requested layout is ",
            layout.ToString(), ".");
      }
      shape[output_axis] /= shard_values[output_axis];
    }
  }
  mlir::RankedTensorType new_output_type =
      mlir::RankedTensorType::get(shape, original_type.getElementType());
  return new_output_type;
}

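// Inverse of LocalTypeFromGlobalType: multiplies every statically known
// dimension by the number of shards on that axis to recover the global type.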
StatusOr<mlir::TensorType> GlobalTypeFromLocalType(
    const Layout& layout, const mlir::TensorType& original_type) {
  if (!original_type.hasRank()) {
    return original_type;
  }
  auto shape = llvm::to_vector<4>(original_type.getShape());
  auto shard_values = layout.num_shards();
  for (int output_axis = 0; output_axis < shape.size(); ++output_axis)
    if (shape[output_axis] != mlir::ShapedType::kDynamicSize)
      shape[output_axis] *= shard_values[output_axis];
  mlir::RankedTensorType new_output_type =
      mlir::RankedTensorType::get(shape, original_type.getElementType());
  return new_output_type;
}

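// Creates a tf.Split op at `location` that splits `src_input` into `num_split`
// equal slices along `split_dimension` and stores it in `*split_op`. Fails if
// the statically known size of `split_dimension` is not divisible by
// `num_split`.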
Status CreateSplitOp(const int num_split, const int split_dimension,
                     const mlir::Location location, mlir::Value src_input,
                     mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op) {
  // Creates a const op to hold the split dimension value.
  auto split_dim_type =
      mlir::RankedTensorType::get({}, builder->getIntegerType(32));
  auto split_dimension_attr =
      mlir::DenseElementsAttr::get(split_dim_type, split_dimension);
  auto split_dimension_op = builder->create<mlir::TF::ConstOp>(
      location, split_dim_type, split_dimension_attr);

  // Correctly sets the output shape of the split op if the input shape is
  // statically known.
  mlir::Type output_type;
  auto input_type = src_input.getType().cast<mlir::TensorType>();

  if (input_type.hasRank()) {
    if (input_type.getShape()[split_dimension] ==
        mlir::ShapedType::kDynamicSize) {
      output_type = input_type;
    } else {
      auto shape = llvm::to_vector<4>(input_type.getShape());
      if (shape[split_dimension] % num_split != 0) {
        return errors::InvalidArgument(
            llvm::formatv(
                "incorrect input sharding configuration received. "
                "{0}-th dimension of the input must be evenly divisible by {1}",
                split_dimension, num_split)
                .str());
      }

      shape[split_dimension] = shape[split_dimension] / num_split;
      output_type =
          mlir::RankedTensorType::get(shape, input_type.getElementType());
    }
  } else {
    output_type = input_type;
  }

  // Creates a split op that splits |src_input| along |split_dimension|.
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
  return OkStatus();
}

// Given layouts + shapes, determines if the two are broadcasting compatible.
// When broadcasting we effectively line up the shapes and layouts by the end.
// The input with lower rank can be thought of as having abs(rank_a-rank_b)
// replicated dims of size 1 prepended to it.
//
// Returns the broadcast layout and the splits in the two inputs needed to run
// an elementwise op efficiently.
//
// Checks that a given mesh dimension is not used in different tensor
// dimensions in the two input layouts.
// E.g. a layout like (unsharded,x,unsharded) is not compatible with
// (unsharded,x) or (x,unsharded,unsharded) but is compatible with
// (x,unsharded), (unsharded,unsharded) or (unsharded,x,unsharded).
// (Note that due to broadcasting, we compare the dimensions from the end).
//
// If dims_to_ignore is > 0, then we ignore when a mesh dimension is used in
// different tensor dimensions when those dimensions are both in the last
// dims_to_ignore tensor dimensions of each input.
// E.g. if dims_to_ignore = 2, then (unsharded,x,unsharded) is now compatible
// with (unsharded,x) and is not compatible with (x,unsharded,unsharded).
//
// The output layout will be of rank max(layout_a.rank(), layout_b.rank()) -
// dims_to_ignore and will be replicated on a dimension if either one of the
// input layouts is replicated on that dimension. Once again recall that due to
// broadcasting, layouts are aligned by their ends and not their beginnings.
// E.g. if dims_to_ignore is zero, the output layout for the inputs
// (unsharded,x,unsharded) and (unsharded,y) is (unsharded,x,y).
// If dims_to_ignore is two, the output for (y,x,unsharded) and
// (unsharded,x) is just (y).
//
// In the case that one tensor is sharded and the other is not on a given
// dimension, elementwise operations *may* need to split the unsharded tensor
// along the same mesh dimension that the other input is split on. Note that
// the split is *not* needed if the unsharded tensor has dimension of size 1,
// due to broadcasting.
//
// To record the needed splits, the vectors to_split_* are resized to the rank
// of each input; if a tensor dimension needs to be split for an elementwise
// op, we record the mesh dimension it should be split along at that position
// in the vector.
// E.g. in the case of input layouts (unsharded,x,unsharded) and
// (unsharded,unsharded) with dimensions (10,10,10) and (10,10),
// to_split_a = {"unsharded", "unsharded", "unsharded"} and to_split_b =
// {"x", "unsharded"}.
// If the shapes were (10,10,10) and (1,10), then to_split_a = {"unsharded",
// "unsharded", "unsharded"} and to_split_b = {"unsharded", "unsharded"}.
//
// Note that "unsharded" == Layout::kUnshardedDim.
// NOTE: shape_a and shape_b are *global* shapes.
StatusOr<Layout> GetBroadcastLayoutForElementWise(
    const Layout& layout_a, const Layout& layout_b,
    mlir::ArrayRef<int64_t> shape_a, mlir::ArrayRef<int64_t> shape_b,
    int64_t dims_to_ignore, std::vector<std::string>& to_split_a,
    std::vector<std::string>& to_split_b) {
  if (layout_a.mesh() != layout_b.mesh())
    return errors::InvalidArgument(
        "layout_a and layout_b cannot be broadcast as they are on different "
        "meshes.");

  const int rank_a = layout_a.rank();
  const int rank_b = layout_b.rank();
  const int rank_offset_a = std::max(0, rank_b - rank_a);
  const int rank_offset_b = std::max(0, rank_a - rank_b);
  absl::flat_hash_map<std::string, int> mesh_dim_map_a;
  absl::flat_hash_map<std::string, int> mesh_dim_map_b;
  std::vector<string> output_layout_specs;

  auto unsharded_specs = [](const int new_size) -> std::vector<std::string> {
    std::vector<std::string> spec_strs(new_size, Layout::kUnshardedDim);
    return spec_strs;
  };

  to_split_a = unsharded_specs(rank_a - dims_to_ignore);
  to_split_b = unsharded_specs(rank_b - dims_to_ignore);

  // Note that we record the mesh dimension usage of all tensor dimensions,
  // even the ones we ignore. We will check that a non-ignored dimension of a
  // tensor does not use a mesh dimension that is used by an ignored dimension
  // in the other tensor.
  for (int i = 0; i < rank_a; ++i)
    if (!Layout::IsUnshardedDimension(layout_a.sharding_spec(i)))
      mesh_dim_map_a[layout_a.sharding_spec(i)] = i;
  for (int i = 0; i < rank_b; ++i)
    if (!Layout::IsUnshardedDimension(layout_b.sharding_spec(i)))
      mesh_dim_map_b[layout_b.sharding_spec(i)] = i;

  for (int i = 0; i < std::max(rank_a, rank_b) - dims_to_ignore; ++i) {
    const int dim_a = i - rank_offset_a;
    const int dim_b = i - rank_offset_b;
    // When ranks are not equal we treat the first rank_offset_* dims of the
    // shorter layout as not sharded.
    const std::string mesh_dim_a =
        dim_a >= 0 ? layout_a.sharding_spec(dim_a) : Layout::kUnshardedDim;
    const std::string mesh_dim_b =
        dim_b >= 0 ? layout_b.sharding_spec(dim_b) : Layout::kUnshardedDim;
    // When ranks are not equal, we treat the first rank_offset_* dims of the
    // shorter shape as if they were 1.
    const int64_t tensor_dim_a = dim_a >= 0 ? shape_a[dim_a] : 1;
    const int64_t tensor_dim_b = dim_b >= 0 ? shape_b[dim_b] : 1;

    // Check for conflicting mesh dimensions: a mesh dimension used by this
    // tensor dimension in one layout but by a different tensor dimension in
    // the other layout. If there is a conflict, the merged result is unsharded
    // on this dimension.
    bool have_conflicted_dim = false;
    if (!Layout::IsUnshardedDimension(mesh_dim_a) &&
        mesh_dim_map_b.contains(mesh_dim_a) &&
        mesh_dim_map_b[mesh_dim_a] != dim_b)
      have_conflicted_dim = true;

    if (!Layout::IsUnshardedDimension(mesh_dim_b) &&
        mesh_dim_map_a.contains(mesh_dim_b) &&
        mesh_dim_map_a[mesh_dim_b] != dim_a)
      have_conflicted_dim = true;

    // If both dimensions are sharded, we have already verified that they are
    // sharded on the same mesh dim.
    if (have_conflicted_dim) {
      output_layout_specs.emplace_back(Layout::kUnshardedDim);
    } else {
      output_layout_specs.emplace_back(
          Layout::IsUnshardedDimension(mesh_dim_a) ? mesh_dim_b : mesh_dim_a);
    }
    if (dim_a >= 0 && tensor_dim_a > 1 &&
        Layout::IsUnshardedDimension(mesh_dim_a) &&
        !Layout::IsUnshardedDimension(mesh_dim_b)) {
      to_split_a[dim_a] = mesh_dim_b;
    }
    if (dim_b >= 0 && tensor_dim_b > 1 &&
        Layout::IsUnshardedDimension(mesh_dim_b) &&
        !Layout::IsUnshardedDimension(mesh_dim_a)) {
      to_split_b[dim_b] = mesh_dim_a;
    }
  }
  return Layout::GetLayout(output_layout_specs, layout_a.mesh());
}

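// Merges the layouts in `operand_layouts` into a single layout that is
// broadcast-compatible with every corresponding operand of `op`. Returns
// absl::nullopt when `operand_layouts` is empty.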
StatusOr<absl::optional<Layout>> GetMergedOperandLayout(
    const llvm::DenseMap<int, Layout>& operand_layouts, mlir::Operation* op) {
  // Holds the list of layouts, each paired with the shape of the operand for
  // which that layout is defined (i.e. the layout is not absl::nullopt).
  llvm::SmallVector<std::pair<const Layout&, llvm::ArrayRef<int64_t>>, 4>
      filtered_preferred_operand_layouts;
  filtered_preferred_operand_layouts.reserve(op->getNumOperands());

  for (const auto& index_and_layout : operand_layouts) {
    TF_ASSIGN_OR_RETURN(
        llvm::ArrayRef<int64_t> shape_to_merge,
        GetShapeOfValue(op->getOperand(index_and_layout.first)));
    filtered_preferred_operand_layouts.emplace_back(index_and_layout.second,
                                                    shape_to_merge);
  }

  if (filtered_preferred_operand_layouts.empty())
    return absl::optional<Layout>();

  // Merge all operand layouts into a single broadcast layout.
  Layout merged_operand_layout = filtered_preferred_operand_layouts[0].first;
  llvm::ArrayRef<int64_t> merged_shape =
      filtered_preferred_operand_layouts[0].second;

  // Statically analyze the merged operand layouts. Broadcasting is allowed,
  // but no cross-device communication should be incurred.
  for (int i = 1; i < filtered_preferred_operand_layouts.size(); ++i) {
    const auto& operand_index_and_layout_to_merge =
        filtered_preferred_operand_layouts[i];
    const Layout& layout_to_merge = operand_index_and_layout_to_merge.first;
    llvm::ArrayRef<int64_t> shape_to_merge =
        operand_index_and_layout_to_merge.second;

    std::vector<std::string> left_splits;
    std::vector<std::string> right_splits;
    TF_ASSIGN_OR_RETURN(merged_operand_layout,
                        GetBroadcastLayoutForElementWise(
                            merged_operand_layout, layout_to_merge,
                            merged_shape, shape_to_merge,
                            /*dims_to_ignore=*/0, left_splits, right_splits));
  }
  return absl::optional<Layout>(merged_operand_layout);
}

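// If `value` is the output of a tf.DTensorLayout op, returns that op's input;
// otherwise returns `value` unchanged.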
mlir::Value GetForwardedDTensorLayoutInput(mlir::Value value) {
  auto layout_op =
      llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(value.getDefiningOp());
  if (!layout_op) return value;

  return layout_op.input();
}

// Takes an operand and traces its use across function call and
// tf_device.cluster boundaries. Note that this may turn one operand into many.
// TODO(bfontain): Assumes that a function is only called once. This is checked
// when creating func_to_caller.
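// For example, a use that is an operand of a tf.PartitionedCall is traced to
// the uses of the corresponding argument inside the callee, and a use that is
// an operand of a func.return is traced to the uses of the corresponding
// result of the caller (if any).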
llvm::SmallVector<mlir::OpOperand*, 4> TraceUseToNextTFOp(
    mlir::OpOperand* operand,
    const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
    llvm::SmallVector<mlir::Value, 4>* skipped_values) {
  mlir::Operation* owner = operand->getOwner();
  llvm::SmallVector<mlir::Value, 4> values;
  if (mlir::isa<mlir::TF::PartitionedCallOp>(owner) ||
      mlir::isa<mlir::TF::StatefulPartitionedCallOp>(owner)) {
    mlir::func::FuncOp func;
    if (mlir::isa<mlir::TF::PartitionedCallOp>(owner))
      func = mlir::cast<mlir::TF::PartitionedCallOp>(owner).func();
    else
      func = mlir::cast<mlir::TF::StatefulPartitionedCallOp>(owner).func();
    values.emplace_back(func.getArgument(operand->getOperandNumber()));
  } else if (mlir::isa<mlir::tf_device::ReturnOp>(owner)) {
    auto device_return = mlir::cast<mlir::tf_device::ReturnOp>(owner);
    auto enclosing_cluster =
        device_return->getParentOfType<mlir::tf_device::ClusterOp>();
    values.emplace_back(
        enclosing_cluster.getResult(operand->getOperandNumber()));
  } else if (mlir::isa<mlir::func::ReturnOp>(owner)) {
    auto func = mlir::cast<mlir::func::ReturnOp>(owner)
                    ->getParentOfType<mlir::func::FuncOp>();
    // The one function we don't have a caller for is the main function.
    // In this case return the empty list as there are no consumers.
    auto caller = func_to_caller.find(func.getName());
    if (caller != func_to_caller.end())
      values.emplace_back(
          caller->second->getOpResult(operand->getOperandNumber()));
  } else if (auto yield = mlir::dyn_cast<mlir::TF::YieldOp>(owner)) {
    if (auto if_op = owner->getParentOfType<mlir::TF::IfRegionOp>()) {
      values.emplace_back(if_op.getResult(operand->getOperandNumber()));
    } else if (auto while_op =
                   owner->getParentOfType<mlir::TF::WhileRegionOp>()) {
      if (while_op && !while_op.cond().isAncestor(yield->getParentRegion()))
        values.emplace_back(while_op.getResult(operand->getOperandNumber()));
    } else {
      LOG(WARNING)
          << "Found terminator op for unsupported control flow operations.";
    }
  } else if (mlir::isa<mlir::TF::DTensorLayout>(owner)) {
    auto dtensor_layout = mlir::cast<mlir::TF::DTensorLayout>(owner);
    values.emplace_back(dtensor_layout.output());
  } else if (auto while_op = mlir::dyn_cast<mlir::TF::WhileRegionOp>(owner)) {
    // Handle loop variant inputs of while op.
    mlir::Region& cond = while_op.cond();
    mlir::Region& body = while_op.body();
    const int operand_index = operand->getOperandNumber();
    values.emplace_back(cond.front().getArgument(operand_index));
    values.emplace_back(body.front().getArgument(operand_index));
  } else {
    return {operand};
  }
  llvm::SmallVector<mlir::OpOperand*, 4> ret;
  for (mlir::Value value : values) {
    if (skipped_values != nullptr) skipped_values->emplace_back(value);
    for (mlir::OpOperand& use : value.getUses()) {
      // TODO(bfontain): Remove recursion here.
      const auto& traced_operands =
          TraceUseToNextTFOp(&use, func_to_caller, skipped_values);
      ret.append(traced_operands.begin(), traced_operands.end());
    }
  }

  return ret;
}

mlir::LogicalResult GetFuncToCaller(
    mlir::ModuleOp module,
    llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller) {
  // For now this is a 1:1 mapping and we will error out if a function is
  // called by more than one op. The layout code assumes there is a 1:many
  // relationship between producers and consumers. If we allow a function to be
  // called multiple times, then its consumers consume from multiple producers,
  // which breaks this assumption.
  // TODO(bfontain): Fix this, possibly by duplicating all functions in order
  // to make this mapping 1:1 in truth.
  auto result = module->walk([&](mlir::Operation* op) -> mlir::WalkResult {
    mlir::StringRef func;
    if (mlir::TF::PartitionedCallOp call_op =
            mlir::dyn_cast<mlir::TF::PartitionedCallOp>(op))
      func = call_op.func().getName();
    else if (mlir::TF::StatefulPartitionedCallOp call_op =
                 mlir::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op))
      func = call_op.func().getName();
    else
      return mlir::WalkResult::advance();
    if (func_to_caller.find(func) != func_to_caller.end())
      return op->emitOpError()
             << "multiple calls to " << func << " found.";
    func_to_caller[func] = op;
    return mlir::WalkResult::advance();
  });
  return mlir::failure(result.wasInterrupted());
}

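// Populates `consumers` with, for every TF op result in the module and every
// argument of the main function, the list of TF op operands that consume that
// value, tracing through function calls, tf_device.cluster boundaries and
// control flow regions.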
mlir::LogicalResult PopulateConsumersFromModule(
    mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers) {
  mlir::func::FuncOp main_func =
      module->lookupSymbol<mlir::func::FuncOp>("main");
  llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;

  if (mlir::failed(GetFuncToCaller(*module, func_to_caller)))
    return mlir::failure();

  module->walk([&](mlir::Operation* op) {
    if (op->getDialect() != tf_dialect) return;

    if (mlir::isa<mlir::TF::PartitionedCallOp>(op) ||
        mlir::isa<mlir::TF::StatefulPartitionedCallOp>(op) ||
        mlir::isa<mlir::TF::WhileRegionOp>(op) ||
        mlir::isa<mlir::TF::IfRegionOp>(op) ||
        mlir::isa<mlir::TF::DTensorLayout>(op))
      return;

    for (const auto& value : op->getOpResults()) {
      // Call clear so that the value is in consumers (with an empty vector)
      // even if there are no uses. This should only happen for ops whose
      // outputs go directly to the main return, e.g. eagerly executed ops.
      consumers[value].clear();
      for (auto& operand : value.getUses())
        for (auto& traced_operand :
             TraceUseToNextTFOp(&operand, func_to_caller))
          consumers[value].emplace_back(traced_operand);
    }
  });

  // Note that we need to add in the inputs from the main function (otherwise
  // we won't have any layouts to propagate!).
  for (auto& value : main_func.getArguments())
    for (auto& operand : value.getUses())
      for (auto* traced_operand : TraceUseToNextTFOp(&operand, func_to_caller))
        consumers[value].emplace_back(traced_operand);
  return mlir::success();
}

// Computes the mesh coordinates from a device id + the current cluster.
//
// If the mesh shape is [a, b, c, d], then the mesh coordinates are
// [device_id/b/c/d, device_id/c/d%b, device_id/d%c, device_id%d].
// Since device_id < a*b*c*d, we can also apply %a to the first coordinate for
// simplicity's sake.
// Thus we can decompose this calculation into the following tf ops:
// tf.FloorMod(tf.Div(device_id, [b*c*d, c*d, d, 1]), [a, b, c, d]) where
// [a, b, c, d] and [b*c*d, c*d, d, 1] are simply precomputed constants.
//
// Note that this returns a tensor of shape [1, mesh.rank()], suitable for
// use with MatMul.
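//
// E.g. (illustrative) for a mesh of shape [2, 4], the precomputed constants
// are [2, 4] and [4, 1], and device_id 6 maps to
// tf.FloorMod(tf.Div(6, [4, 1]), [2, 4]) = [[1, 2]].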
StatusOr<mlir::Value> GetMeshCoordinatesFromCluster(
    mlir::tf_device::ClusterOp cluster) {
  // First try to find a FloorMod op with a kMeshCoordinatesAttr attribute that
  // has the given mesh in it. If it exists, simply return that op's value.
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshFromOp(cluster));
  if (!mesh) return errors::InvalidArgument("missing mesh on cluster");
  string serialized_mesh = mesh->ToString();
  mlir::Value ret_val;
  auto result = cluster.walk([&](mlir::TF::FloorModOp op) -> mlir::WalkResult {
    if (op->hasAttrOfType<mlir::StringAttr>(kMeshCoordinatesAttr) &&
        op->getAttrOfType<mlir::StringAttr>(kMeshCoordinatesAttr)
                .getValue()
                .str() == serialized_mesh) {
      ret_val = op.z();
      return mlir::WalkResult::interrupt();
    }
    return mlir::WalkResult::advance();
  });
  if (result.wasInterrupted()) return ret_val;

  // We didn't find a FloorModOp for the given mesh, so we must produce the
  // FloorModOp and add the attr so we can find it on the next call.
  std::vector<int32> mesh_shape(mesh->rank());
  for (int i = 0; i < mesh->rank(); ++i) mesh_shape[i] = mesh->dim(i).size;

  // This product represents the [b*c*d, c*d, d, 1] from the function
  // documentation.
  std::vector<int32> running_product(mesh->rank());
  running_product[mesh->rank() - 1] = 1;
  for (int i = mesh->rank() - 1; i > 0; --i)
    running_product[i - 1] = running_product[i] * mesh_shape[i];

  mlir::OpBuilder builder(cluster.getContext());
  builder.setInsertionPointToStart(&cluster.GetBody());

  auto mesh_shape_type = mlir::RankedTensorType::get(
      {1, mesh->rank()}, builder.getIntegerType(32));
  mlir::Attribute mesh_shape_attr =
      mlir::DenseIntElementsAttr::get(mesh_shape_type, mesh_shape);
  auto mesh_shape_value =
      builder.create<mlir::TF::ConstOp>(cluster.getLoc(), mesh_shape_attr)
          .getResult();

  auto running_product_value =
      IntConst(builder, cluster.getLoc(), running_product);

  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(cluster));

  auto div_op = builder.create<mlir::TF::DivOp>(cluster.getLoc(), device_id,
                                                running_product_value);

  auto mod_op = builder.create<mlir::TF::FloorModOp>(
      cluster.getLoc(), div_op.z(), mesh_shape_value);

  mod_op->setAttr(kMeshCoordinatesAttr, builder.getStringAttr(serialized_mesh));
  return mod_op.z();
}

mlir::LogicalResult ValidateMetadataAttributes(mlir::Operation* op) {
  // If the cluster function has attributes containing the inferred layouts of
  // resource handle arguments, then add the attributes to the newly created
  // StatefulPartitionedCallOp.
  auto inferred_resource_handle_indices =
      op->getAttrOfType<mlir::DenseIntElementsAttr>(kNewResourceLayoutIndices);
  auto inferred_resource_handle_layouts =
      op->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
  if (inferred_resource_handle_indices || inferred_resource_handle_layouts) {
    if (!inferred_resource_handle_indices ||
        !inferred_resource_handle_layouts ||
        inferred_resource_handle_indices.getNumElements() !=
            inferred_resource_handle_layouts.size())
      return op->emitOpError(
                 "inferred layout args don't match. indices size: ")
             << (inferred_resource_handle_indices
                     ? inferred_resource_handle_indices.getNumElements()
                     : 0)
             << ", layouts size: "
             << (inferred_resource_handle_layouts
                     ? inferred_resource_handle_layouts.size()
                     : 0);
  }

  auto shape_layouts = op->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
  auto shape_op_indices =
      op->getAttrOfType<mlir::DenseIntElementsAttr>(kShapeOpInputLayoutIndices);
  if (shape_op_indices || shape_layouts) {
    if (!shape_op_indices || !shape_layouts ||
        shape_op_indices.getNumElements() != shape_layouts.size())
      return op->emitOpError("shape layout args don't match. indices size: ")
             << (shape_op_indices ? shape_op_indices.getNumElements() : 0)
             << ", layouts size: "
             << (shape_layouts ? shape_layouts.size() : 0);
  }
  return mlir::success();
}

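// Rebuilds `cluster` so that it only returns results that have uses, and
// erases the original cluster. No-op if every result is used.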
void RemoveUnusedClusterResults(mlir::tf_device::ClusterOp cluster) {
  llvm::SmallVector<mlir::OpResult, 4> new_result_values;
  llvm::SmallVector<mlir::Value, 4> result_producing_values;
  new_result_values.reserve(cluster->getNumResults());
  result_producing_values.reserve(cluster->getNumResults());
  for (mlir::OpResult result : cluster.results()) {
    if (!result.use_empty()) {
      new_result_values.emplace_back(result);
      result_producing_values.emplace_back(
          cluster.GetBody().getTerminator()->getOperand(
              result.getResultNumber()));
    }
  }

  if (new_result_values.size() == cluster.getNumResults()) return;

  llvm::SmallVector<mlir::Type, 4> new_result_types;
  llvm::transform(new_result_values, std::back_inserter(new_result_types),
                  [](mlir::Value v) { return v.getType(); });

  mlir::OpBuilder builder(cluster);
  auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
      cluster.getLoc(), new_result_types);
  new_cluster->setAttr(kMeshAttr,
                       cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr));
  new_cluster.body().push_back(new mlir::Block);

  auto& cluster_body = cluster.GetBody().getOperations();
  new_cluster.GetBody().getOperations().splice(
      new_cluster.GetBody().end(), cluster_body, cluster_body.begin(),
      std::prev(cluster_body.end()));

  builder.setInsertionPointToEnd(&new_cluster.GetBody());
  builder.create<mlir::tf_device::ReturnOp>(cluster.getLoc(),
                                            result_producing_values);

  assert(new_cluster.getNumResults() == new_result_values.size());
  for (auto it : llvm::zip(new_result_values, new_cluster.results())) {
    mlir::Value value_to_replace = std::get<0>(it);
    mlir::Value new_result = std::get<1>(it);
    value_to_replace.replaceAllUsesWith(new_result);
  }
  cluster.erase();
}

namespace {

// Keeps track of the number of functions added to the global graph for adding
// control flow. When converting regional control flow to functional control
// flow ops, function names may collide if non-unique branch function names are
// used. In order to ensure that all branch functions of TF control flow ops
// are unique, we keep an atomic counter for control flow functions.
// See b/174253694 for more details.
std::atomic<int32> dtensor_controlflow_function_counter{0};

}  // namespace

mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix,
                                            mlir::OpBuilder& builder) {
  int32 unique_id = dtensor_controlflow_function_counter++;
  return builder.getStringAttr(
      absl::StrCat(prefix, "_dtensor_function_", unique_id));
}

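// Sets `builder`'s insertion point directly after the op that defines `value`.
// If `value` is a block argument, the insertion point is instead set to the
// start of the single tf_device.cluster that uses it; it is an error if the
// uses span multiple clusters or no cluster at all.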
Status SetBuilderInsertionAfterValue(mlir::Value value,
                                     mlir::OpBuilder& builder) {
  if (value.isa<mlir::OpResult>()) {
    builder.setInsertionPointAfterValue(value);
    return OkStatus();
  }
  mlir::tf_device::ClusterOp cluster;
  for (mlir::Operation* op : value.getUsers()) {
    mlir::tf_device::ClusterOp new_cluster =
        op->getParentOfType<mlir::tf_device::ClusterOp>();
    if (!new_cluster) continue;
    if (!cluster) cluster = new_cluster;
    if (cluster != new_cluster)
      return errors::Internal("value has multiple uses in different clusters");
  }
  if (!cluster) return errors::Internal("value not used in any cluster");

  builder.setInsertionPointToStart(cluster.getBody());
  return OkStatus();
}

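// Debugging helper: emits tf.StringFormat and tf.PrintV2 ops that log `value`
// at runtime, prefixed with the device id, using `format_string`.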
Status PrintTensor(mlir::Value value, const std::string& format_string = "%s") {
  mlir::OpBuilder builder(value.getContext());
  builder.setInsertionPointAfterValue(value);
  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(value));
  std::string all_format = absl::StrCat("Core %s: ", format_string);
  // Scalar string type.
  mlir::RankedTensorType scalar_string =
      mlir::RankedTensorType::get({}, builder.getType<mlir::TF::StringType>());
  mlir::TF::StringFormatOp format = builder.create<mlir::TF::StringFormatOp>(
      value.getLoc(), scalar_string, mlir::ValueRange({device_id, value}));
  format->setAttr("template", builder.getStringAttr(all_format));
  builder.create<mlir::TF::PrintV2Op>(value.getLoc(), format.output(),
                                      /*output_stream=*/"log(info)",
                                      /*end=*/"\n");
  return OkStatus();
}

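// Extracts the string values held by the constant string tensor feeding
// `value` (looking through any tf.DTensorLayout op) and appends them to
// `out_vector`.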
Status ExtractConstStringVectorFromValue(
    mlir::Value value, llvm::SmallVectorImpl<std::string>& out_vector) {
  value = GetForwardedDTensorLayoutInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal(
        "Unable to get constant value from block argument.");
  mlir::DenseStringElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(
        llvm::formatv("failed to extract constant string vector from: {0}",
                      value)
            .str());
  }
  for (const auto& str : attr.getRawStringData()) {
    out_vector.push_back(str.str());
  }
  return OkStatus();
}

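// Returns the string held by the scalar string constant feeding `value`
// (looking through any tf.DTensorLayout op). Fails if the value is a block
// argument, is not a constant, or holds more than one element.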
StatusOr<std::string> ExtractConstScalarStringFromValue(mlir::Value value) {
  value = GetForwardedDTensorLayoutInput(value);
  if (value.isa<mlir::BlockArgument>())
    return errors::Internal(
        "Unable to get constant value from block argument.");
  mlir::DenseStringElementsAttr attr;
  if (!matchPattern(value, m_Constant(&attr))) {
    return errors::Internal(absl::StrCat("required constant value for ",
                                         OpName(value.getDefiningOp())));
  }
  if (attr.size() != 1) {
    return errors::Internal(absl::StrCat("expected 1 element, got ",
                                         attr.size(), " for ",
                                         OpName(value.getDefiningOp())));
  }
  return std::string(*attr.getRawStringData().begin());
}

TopologicalIterator::TopologicalIterator(mlir::func::FuncOp main_func)
    : ops_to_visit_{&main_func.front().front()} {
  funcs_visited_.insert(main_func.getName());
  funcs_visited_in_call_stack_.insert(main_func.getName());
}

mlir::Operation* TopologicalIterator::next() {
  if (!hasNext()) return nullptr;

  auto* op = ops_to_visit_.pop_back_val();
  auto* next_op = op->getNextNode();
  if (next_op) ops_to_visit_.push_back(next_op);

  // If this is a function call op, push the first op of the function body so
  // that the function body is converted before the call site.
  absl::optional<mlir::func::FuncOp> func = MaybeFindFunction(op);
  if (func.has_value()) {
    mlir::StringRef func_name = func->getName();

    if (funcs_visited_.contains(func_name)) return next();

    ops_to_visit_.push_back(&(func->front().front()));
    funcs_visited_.insert(func_name);
  }

  // If we have reached the end of a function body, remove the function from
  // our active set.
  if (!next_op && !funcs_visited_in_call_stack_.empty())
    if (auto func = op->getParentOfType<mlir::func::FuncOp>())
      funcs_visited_in_call_stack_.erase(func.getName());

  if (auto cluster_op = mlir::dyn_cast<mlir::tf_device::ClusterOp>(op))
    ops_to_visit_.push_back(&cluster_op.GetBody().front());

  if (auto while_op = mlir::dyn_cast<mlir::TF::WhileRegionOp>(op)) {
    ops_to_visit_.push_back(&while_op.cond().front().front());
    ops_to_visit_.push_back(&while_op.body().front().front());
  }

  if (auto if_op = mlir::dyn_cast<mlir::TF::IfRegionOp>(op)) {
    ops_to_visit_.push_back(&if_op.then_branch().front().front());
    ops_to_visit_.push_back(&if_op.else_branch().front().front());
  }
  return op;
}

bool TopologicalIterator::hasNext() { return !ops_to_visit_.empty(); }

}  // namespace dtensor
}  // namespace tensorflow