/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/conv_spmd_expander.h"

#include <cmath>
#include <string>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {

template <typename ConvOp>
Status VerifyConvLayout(const Layout& input_layout, const Layout& filter_layout,
                        ConvOp conv_op) {
  if (!filter_layout.IsFullyReplicated())
    return errors::InvalidArgument(
        "Filter for convolution must have fully replicated layout.");

  // The channel dimension index defaults to 1, which covers the "NCHW" and
  // "NCDHW" data formats.
  int channel_dim = 1;
  if (conv_op.data_format() == "NHWC")
    channel_dim = 3;
  else if (conv_op.data_format() == "NDHWC")
    channel_dim = 4;

  if (input_layout.sharding_spec(channel_dim) != Layout::kUnshardedDim)
    return errors::InvalidArgument(
        "Conv input's channel dimension must be replicated.");

  if (input_layout.IsBatchParallel())
    // No further checks are needed for the batch-parallel case.
    return OkStatus();

  if (conv_op.padding() == "EXPLICIT")
    return errors::InvalidArgument(
        "Explicit padding not supported for convolution with spatial "
        "partitions.");

  const int num_non_default_dilations =
      llvm::count_if(conv_op.dilations(), [](mlir::Attribute dilation) {
        return dilation.cast<mlir::IntegerAttr>().getInt() != 1;
      });
  if (num_non_default_dilations > 0)
    return errors::InvalidArgument(
        "Only dilation rate 1 is supported for convolution with spatial "
        "partitions.");

  // TODO(b/208700444): support convolution with strides greater than 1.
  const int num_non_default_strides =
      llvm::count_if(conv_op.strides(), [](mlir::Attribute stride) {
        return stride.cast<mlir::IntegerAttr>().getInt() != 1;
      });
  if (num_non_default_strides > 0)
    return errors::InvalidArgument(
        "Only stride 1 is supported for convolution with spatial partitions.");

  mlir::Value input = conv_op.input();
  auto input_type = input.getType().dyn_cast<mlir::RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape())
    return errors::InvalidArgument(
        "Input must have static shapes for convolution with spatial "
        "partitions.");

  mlir::Value filter = conv_op.filter();
  auto filter_type = filter.getType().dyn_cast<mlir::RankedTensorType>();
  if (!filter_type || !filter_type.hasStaticShape())
    return errors::InvalidArgument(
        "Filter must have static shapes for convolution with spatial "
        "partitions.");

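  // Spatially partitioned convolution pads or halo-exchanges
  // (filter_size - 1) / 2 elements on each side of a shard before emitting a
  // "VALID" convolution; that split is only symmetric when every spatial
  // filter dimension is odd, hence the check below.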
  llvm::ArrayRef<int64_t> filter_shape = filter_type.getShape();
  for (auto it = filter_shape.begin(); it != filter_shape.end() - 2; ++it) {
    if (*it % 2 != 1)
      return errors::InvalidArgument(
          "Filter dimensions must be odd numbers for convolution with "
          "spatial partitions.");
  }

  return OkStatus();
}

mlir::Value PadInputOnUnshardedDim(mlir::OpBuilder& builder,
                                   mlir::Location location,
                                   mlir::Value input_tensor, int curr_input_dim,
                                   int64_t curr_filter_dim_size) {
  auto input_tensor_type =
      input_tensor.getType().dyn_cast<mlir::RankedTensorType>();
  auto input_tensor_shape = input_tensor_type.getShape();

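  // tf.Pad expects a [rank, 2] paddings tensor where row i holds the number of
  // elements to add before and after dimension i. Build it as a flat vector
  // first and reshape it below; only the current spatial dimension gets a
  // non-zero, symmetric padding of (filter_size - 1) / 2 on each side.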
  const size_t paddings_flat_length = input_tensor_type.getRank() * 2;
  llvm::SmallVector<int64_t, 4> paddings_flat_vec(paddings_flat_length, 0);
  int64_t padding_size = curr_filter_dim_size - 1;
  paddings_flat_vec[2 * curr_input_dim] = padding_size / 2;
  paddings_flat_vec[2 * curr_input_dim + 1] = padding_size / 2;

  llvm::SmallVector<int64_t, 4> paddings_shape(input_tensor_shape.begin(),
                                               input_tensor_shape.end());
  paddings_shape[curr_input_dim] += padding_size;

  mlir::Value paddings_flat = Int64Const(builder, location, paddings_flat_vec);
  mlir::RankedTensorType paddings_type = mlir::RankedTensorType::get(
      paddings_shape, input_tensor_type.getElementType());
  mlir::Value paddings = builder.create<mlir::TF::ReshapeOp>(
      location, paddings_flat,
      Int64Const(builder, location, {input_tensor_type.getRank(), 2}));
  return builder.create<mlir::TF::PadOp>(location, paddings_type, input_tensor,
                                         paddings);
}

template <typename ConvOp>
StatusOr<mlir::Operation*> HandleConv(ConvOp conv_op) {
  mlir::OpBuilder builder(conv_op);
  TF_ASSIGN_OR_RETURN(const Layout input_layout,
                      ExtractRequiredLayoutFromOperand(conv_op.input()));
  TF_ASSIGN_OR_RETURN(const Layout filter_layout,
                      ExtractRequiredLayoutFromOperand(conv_op.filter()));
  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(conv_op));

  TF_RETURN_IF_ERROR(VerifyConvLayout(input_layout, filter_layout, conv_op));

  if (input_layout.IsBatchParallel())
    // No special handling is needed for the batch-parallel case.
    return InferSPMDExpandedLocalShape(conv_op);

  mlir::tf_device::ClusterOp cluster =
      conv_op->template getParentOfType<mlir::tf_device::ClusterOp>();
  TF_ASSIGN_OR_RETURN(mlir::Value mesh_coordinates,
                      GetMeshCoordinatesFromCluster(cluster));
  const Mesh& mesh = input_layout.mesh();
  mlir::Location location = conv_op->getLoc();

  const std::vector<std::string> input_sharding_spec =
      input_layout.sharding_spec_strs();
  const std::vector<std::string> output_sharding_spec =
      output_layout.sharding_spec_strs();
  llvm::StringRef format = conv_op.data_format();
  llvm::StringRef padding = conv_op.padding();

  const auto input_num_shards = input_layout.num_shards();
  const auto output_num_shards = output_layout.num_shards();

  auto filter_type =
      conv_op.filter().getType().template dyn_cast<mlir::RankedTensorType>();
  auto filter_shape = filter_type.getShape();

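  // Identify the range of spatial dimensions for the given data format; the
  // dimensions in [begin_input_dim, end_input_dim] are the ones that may be
  // spatially sharded.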
  int begin_input_dim = -1, end_input_dim = -1;
  if (format == "NCHW") {
    begin_input_dim = 2;
    end_input_dim = 3;
  } else if (format == "NHWC") {
    begin_input_dim = 1;
    end_input_dim = 2;
  } else if (format == "NCDHW") {
    begin_input_dim = 2;
    end_input_dim = 4;
  } else if (format == "NDHWC") {
    begin_input_dim = 1;
    end_input_dim = 3;
  }

  // For non-batch, non-channel dimension sharding, conduct halo exchange.
  for (int curr_input_dim = begin_input_dim; curr_input_dim <= end_input_dim;
       ++curr_input_dim) {
    int curr_filter_dim = curr_input_dim - begin_input_dim;

    auto input_type =
        conv_op.input().getType().template dyn_cast<mlir::RankedTensorType>();
    auto input_shape = input_type.getShape();

    if (input_sharding_spec[curr_input_dim] == Layout::kUnshardedDim) {
      if (padding == "SAME") {
        // Since we always emit a Conv op with "VALID" padding, we need to
        // manually pad the input tensor.
        conv_op->setOperand(
            0, PadInputOnUnshardedDim(builder, location, conv_op.input(),
                                      curr_input_dim,
                                      filter_shape[curr_filter_dim]));
      }
      // No halo exchange is needed for unsharded dims.
      continue;
    }

    TF_ASSIGN_OR_RETURN(const int mesh_dim_index,
                        mesh.idx_for_dim(input_sharding_spec[curr_input_dim]));
    TF_ASSIGN_OR_RETURN(mlir::Value scalar_mesh_coordinate,
                        SelectScalarValueFromArray(builder, mesh_dim_index,
                                                   location, mesh_coordinates));

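    // Determine how many elements this shard needs from its neighbors. With
    // "SAME" padding each shard needs floor(filter_size / 2) extra elements at
    // the boundary; with "VALID" padding the halo covers the gap between the
    // local input size and the size required to produce the local output:
    // output_local + (filter_size - 1) - input_local.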
    int halo_size;
    if (padding == "SAME") {
      halo_size = std::floor(filter_shape[curr_filter_dim] / 2);
    } else if (padding == "VALID") {
      int input_local_size = input_shape[curr_input_dim];
      int input_size = input_local_size * input_num_shards[curr_input_dim];
      int output_size = input_size - (filter_shape[curr_filter_dim] - 1);
      int output_local_size = output_size / output_num_shards[curr_input_dim];
      halo_size = output_local_size + (filter_shape[curr_filter_dim] - 1) -
                  input_local_size;
    } else {
      return errors::Unimplemented(
          "Spatially partitioned convolution with padding \"", padding.str(),
          "\" is not supported.");
    }

    if (halo_size == 0)
      // No exchange is needed for empty halos.
      continue;

    builder.setInsertionPoint(conv_op);
    TF_ASSIGN_OR_RETURN(
        mlir::Value halo_exchanged_input,
        EmitHaloExchange(builder, halo_size,
                         input_sharding_spec[curr_input_dim], input_layout,
                         mesh_coordinates, cluster, location, conv_op.input()));

    if (padding == "SAME") {
      conv_op->setOperand(0, halo_exchanged_input);
    } else if (padding == "VALID") {
      // Slice the halo exchanged tensor to the desired size based on the index
      // of the shard on the current dimension.

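      // Each shard keeps a window of (local_size + halo_size) elements from
      // the halo-exchanged tensor, starting at
      //   halo_size - shard_index * (halo_size / (num_shards - 1))
      // along the current dimension: shard 0 begins its slice at halo_size and
      // the last shard begins at 0, with intermediate shards in between.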
      llvm::SmallVector<int32_t, 4> halo_sizes(input_layout.rank(), 0);
      halo_sizes[curr_input_dim] = halo_size;
      mlir::Value halo_sizes_const = IntConst(builder, location, halo_sizes);

      llvm::SmallVector<int32_t, 4> halo_increments(input_layout.rank(), 0);
      halo_increments[curr_input_dim] =
          halo_size / (input_num_shards[curr_input_dim] - 1);
      mlir::Value halo_increments_const =
          IntConst(builder, location, halo_increments);

      mlir::Value offset = builder.create<mlir::TF::MulOp>(
          location, halo_increments_const.getType(), scalar_mesh_coordinate,
          halo_increments_const);
      mlir::Value slice_begin =
          builder.create<mlir::TF::SubOp>(location, halo_sizes_const, offset);

      llvm::SmallVector<int64_t, 4> slice_size(input_shape.begin(),
                                               input_shape.end());
      slice_size[curr_input_dim] += halo_size;
      mlir::Value slice_size_const = Int64Const(builder, location, slice_size);

      mlir::RankedTensorType sliced_input_type =
          mlir::RankedTensorType::get(slice_size, input_type.getElementType());
      mlir::Value sliced_input = builder.create<mlir::TF::SliceOp>(
          location, sliced_input_type, /*input=*/halo_exchanged_input,
          /*begin=*/slice_begin, /*size=*/slice_size_const);
      conv_op->setOperand(0, sliced_input);
    }

    // Spatially partitioned convolution always uses VALID padding after halo
    // exchange.
    conv_op.paddingAttr(builder.getStringAttr("VALID"));
  }

  return InferSPMDExpandedLocalShape(conv_op);
}

template <typename ConvBackpropInputOp>
StatusOr<mlir::Operation*> HandleConvBackpropInput(
    const Layout& output_layout, ConvBackpropInputOp conv_op) {
  llvm::SmallVector<int64_t, 4> global_shape;
  Status extract_status =
      ExtractConstVectorFromValue(conv_op.input_sizes(), &global_shape);

  // Recover local shape in SPMD expansion.
  if (extract_status.ok()) {
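    // 'input_sizes' holds the global shape of the input being generated; each
    // device only produces its local slice, so replace the constant with the
    // local shape implied by the output layout.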
    auto local_shape = output_layout.LocalShapeFromGlobalShape(global_shape);
    mlir::OpBuilder builder(conv_op->getBlock(), conv_op->getBlock()->begin());
    auto new_const = IntConst(
        builder, conv_op->getLoc(),
        llvm::SmallVector<int32_t, 4>(local_shape.begin(), local_shape.end()));
    conv_op.input_sizesMutable().assign(new_const);
  }

  return InferSPMDExpandedLocalShape(conv_op);
}

template <typename ConvBackpropFilterOp>
StatusOr<mlir::Operation*> HandleConvBackpropFilter(
    const Layout& output_layout, ConvBackpropFilterOp conv_op) {
  TF_ASSIGN_OR_RETURN(Layout input_layout,
                      ExtractRequiredLayoutFromOperand(conv_op.input()));

  TF_ASSIGN_OR_RETURN(
      Layout out_backprop_layout,
      ExtractRequiredLayoutFromOperand((conv_op.out_backprop())));
  // Split along the batch dimension so that each local device performs a
  // purely local operation.
  // TODO(hthu): Make this work on input with rank higher than 4.
  if (input_layout.IsBatchParallel()) {
    mlir::OpBuilder builder(conv_op);
    if (out_backprop_layout.IsFullyReplicated()) {
      TF_ASSIGN_OR_RETURN(const mlir::Value batch_sharded,
                          EmitAllScatter(builder, conv_op.out_backprop(),
                                         out_backprop_layout, input_layout));
      conv_op.out_backpropMutable().assign(batch_sharded);
    }

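    // Each device now computes a partial filter gradient from its local batch
    // shard; summing these partial gradients across the batch mesh dimension
    // yields the full filter gradient.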
    // Perform all reduce over batch dim.
    builder.setInsertionPointAfter(conv_op);
    return DT_CTX(EmitAllReduce(builder, output_layout,
                                {input_layout.sharding_spec(0)}, conv_op,
                                kReduceOpAdd));
  } else {
    return errors::InvalidArgument(
        "Convolution backprop for spatially partitioned input not supported.");
  }
  return InferSPMDExpandedLocalShape(conv_op);
}

StatusOr<mlir::Operation*> HandleMaxPoolGradOp(
    const Layout& output_layout, mlir::TF::MaxPoolGradOp max_pool_grad_op) {
  // MaxPoolGrad has 3 inputs: Original Input to MaxPool, Output of MaxPool and
  // Gradients.
  assert(max_pool_grad_op->getOpOperands().size() == 3);

  // Relayout gradient input to match layout of output of maxpool.
  mlir::OpOperand& max_pool_output = max_pool_grad_op->getOpOperand(1);
  TF_ASSIGN_OR_RETURN(Layout max_pool_output_layout,
                      ExtractRequiredLayoutFromOperand(max_pool_output.get()));

  mlir::OpOperand& grad_input = max_pool_grad_op->getOpOperand(2);
  TF_ASSIGN_OR_RETURN(Layout grad_input_layout,
                      ExtractRequiredLayoutFromOperand(grad_input.get()));
  TF_ASSIGN_OR_RETURN(mlir::Value new_grad_input,
                      EmitRelayout(grad_input.get(), grad_input_layout,
                                   max_pool_output_layout));
  grad_input.set(new_grad_input);

  return InferSPMDExpandedLocalShape(max_pool_grad_op);
}

}  // namespace

StatusOr<mlir::Operation*> ConvSPMDExpander::ExpandOp(mlir::Operation* op) {
  // The first argument to Conv2DBackpropInputOp is the shape of the input we
  // are generating. Since this is almost always the output of a call to
  // `shape`, we lose the ability to infer the original input layout (cf. if
  // Conv2DBackpropInput accepted the input _tensor_ instead of the shape).
  // Since in eager execution we cannot look ahead at consumer operations, we
  // instead attach the original input layout as a secondary attribute on the
  // output of the shape operation, and use this to infer the desired layout
  // for this op.

  TF_ASSIGN_OR_RETURN(const auto output_layout, ExtractSingleLayoutFromOp(op));

  // Forward prop ops.
  if (llvm::isa<mlir::TF::Conv2DOp>(op))
    return HandleConv<>(llvm::cast<mlir::TF::Conv2DOp>(op));
  if (llvm::isa<mlir::TF::Conv3DOp>(op))
    return HandleConv<>(llvm::cast<mlir::TF::Conv3DOp>(op));

  // Backward prop input ops.
  if (llvm::isa<mlir::TF::Conv2DBackpropInputOp>(op))
    return HandleConvBackpropInput<>(
        *output_layout, llvm::cast<mlir::TF::Conv2DBackpropInputOp>(op));
  if (llvm::isa<mlir::TF::Conv3DBackpropInputV2Op>(op))
    return HandleConvBackpropInput<>(
        *output_layout, llvm::cast<mlir::TF::Conv3DBackpropInputV2Op>(op));

  // Backward prop filter ops.
  if (llvm::isa<mlir::TF::Conv2DBackpropFilterOp>(op))
    return HandleConvBackpropFilter<>(
        *output_layout, llvm::cast<mlir::TF::Conv2DBackpropFilterOp>(op));
  if (llvm::isa<mlir::TF::Conv3DBackpropFilterV2Op>(op))
    return HandleConvBackpropFilter<>(
        *output_layout, llvm::cast<mlir::TF::Conv3DBackpropFilterV2Op>(op));

  // For all other ops, only batch sharded or fully replicated sharding is
  // supported for now.
  if (!output_layout->IsFullyReplicated() && !output_layout->IsBatchParallel())
    return errors::Unimplemented(
        llvm::formatv(
            "Only replicated or batch parallel layout is supported in "
            "expansion of {0}, but got output layout: {1}",
            op->getName().getStringRef().str(), output_layout->ToString())
            .str());

  if (auto max_pool_grad = mlir::dyn_cast<mlir::TF::MaxPoolGradOp>(op))
    return HandleMaxPoolGradOp(*output_layout, max_pool_grad);

  // All remaining ops only need local shape expansion.
  return InferSPMDExpandedLocalShape(op);
}

StatusOr<llvm::DenseMap<int, Layout>> ConvSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  llvm::DenseMap<int, Layout> output_layouts(op->getNumResults());

  if (llvm::isa<mlir::TF::Conv2DOp, mlir::TF::Conv3DOp, mlir::TF::MaxPoolOp,
                mlir::TF::MaxPoolGradOp>(op)) {
    // Conv2D/Conv3D and MaxPool ops are grouped together as they all try to
    // propagate the layout from the input image (operand 0).

    // If a layout was requested for 'input', request the same layout for the
    // output.
    if (input_layouts.find(0) != input_layouts.end()) {
      output_layouts[0] = input_layouts.lookup(0);
    } else {
      // For MaxPoolGrad, request the same layout as 'orig_output' or 'grad',
      // whichever is present.
      if (llvm::isa<mlir::TF::MaxPoolGradOp>(op)) {
        if (input_layouts.find(1) != input_layouts.end())
          output_layouts[0] = input_layouts.lookup(1);
        else if (input_layouts.find(2) != input_layouts.end())
          output_layouts[0] = input_layouts.lookup(2);
      }
    }
  } else if (llvm::isa<mlir::TF::Conv2DBackpropInputOp,
                       mlir::TF::Conv3DBackpropInputV2Op,
                       mlir::TF::Conv2DBackpropFilterOp,
                       mlir::TF::Conv3DBackpropFilterV2Op>(op)) {
    // Conv backprop ops usually take their layout from the gradient, for both
    // the input and the filter variants.

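    // For both backprop variants, operand 2 is the incoming gradient
    // ('out_backprop').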
    // 'grad' layout
    if (input_layouts.find(2) != input_layouts.end()) {
      if (llvm::isa<mlir::TF::Conv2DBackpropInputOp,
                    mlir::TF::Conv3DBackpropInputV2Op>(op)) {
        // BackProp ops try to respect layout from gradients for inputs.
        output_layouts[0] = input_layouts.lookup(2);
      }

      // For filters, we currently only try to request a replicated output
      // layout.
      if (llvm::isa<mlir::TF::Conv2DBackpropFilterOp,
                    mlir::TF::Conv3DBackpropFilterV2Op>(op)) {
        output_layouts[0] =
            Layout::ReplicatedOnMesh(mesh, ValueRank(op->getOpResult(0)));
      }
    }
  } else {
    return errors::InvalidArgument(
        llvm::formatv(
            "Layout propagation for unrecognized convolution op {0} not "
            "supported.",
            OpName(op))
            .str());
  }

  return output_layouts;
}

StatusOr<llvm::DenseMap<int, Layout>> ConvSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());

  if (llvm::isa<mlir::TF::Conv2DOp, mlir::TF::Conv3DOp, mlir::TF::MaxPoolOp,
                mlir::TF::MaxPoolGradOp>(op)) {
    // If a suggested output layout exists, request the same layout for the
    // input image so that all computation stays local.
    if (output_layouts.find(0) != output_layouts.end()) {
      const Layout output_layout = output_layouts.lookup(0);

      input_layouts[0] = output_layout;

      // Request a replicated layout for the filter input of Conv2D/Conv3D.
      if (llvm::isa<mlir::TF::Conv2DOp, mlir::TF::Conv3DOp>(op)) {
        input_layouts[1] =
            Layout::ReplicatedOnMesh(mesh, ValueRank(op->getOperand(1)));
      }
      if (llvm::isa<mlir::TF::MaxPoolGradOp>(op)) {
        input_layouts[1] = output_layout;  // 'orig_output'
        input_layouts[2] = output_layout;  // 'grad'
      }
    }
  } else if (llvm::isa<mlir::TF::Conv2DBackpropInputOp,
                       mlir::TF::Conv3DBackpropInputV2Op,
                       mlir::TF::Conv2DBackpropFilterOp,
                       mlir::TF::Conv3DBackpropFilterV2Op>(op)) {
    // If a suggested output layout exists, request that layout for the
    // gradient ('out_backprop') operand.
    if (output_layouts.find(0) != output_layouts.end()) {
      input_layouts[2] = output_layouts.lookup(0);
      // Request inputs and filter_sizes to be replicated.
      input_layouts[0] =
          Layout::ReplicatedOnMesh(mesh, ValueRank(op->getOperand(0)));
      input_layouts[1] =
          Layout::ReplicatedOnMesh(mesh, ValueRank(op->getOperand(1)));
    }
  } else {
    return errors::InvalidArgument(
        llvm::formatv(
            "Layout propagation for unrecognized convolution op {0} not "
            "supported.",
            OpName(op))
            .str());
  }

  return input_layouts;
}

}  // namespace dtensor
}  // namespace tensorflow