/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/conv_spmd_expander.h"

#include <cmath>
#include <string>

#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {

template <typename ConvOp>
Status VerifyConvLayout(const Layout& input_layout, const Layout& filter_layout,
                        ConvOp conv_op) {
  if (!filter_layout.IsFullyReplicated())
    return errors::InvalidArgument(
        "Filter for convolution must have fully replicated layout.");

  // Channel dimension index: 1 for data formats "NCHW" and "NCDHW" (the
  // default below), 3 for "NHWC", and 4 for "NDHWC".
  int channel_dim = 1;
  if (conv_op.data_format() == "NHWC")
    channel_dim = 3;
  else if (conv_op.data_format() == "NDHWC")
    channel_dim = 4;

  if (input_layout.sharding_spec(channel_dim) != Layout::kUnshardedDim)
    return errors::InvalidArgument(
        "Conv input's channel dimension must be replicated.");

  if (input_layout.IsBatchParallel())
    // No further checks needed for replicated case.
    return OkStatus();

  if (conv_op.padding() == "EXPLICIT")
    return errors::InvalidArgument(
        "Explicit padding not supported for convolution with spatial "
        "partitions.");

  const int num_non_default_dilations =
      llvm::count_if(conv_op.dilations(), [](mlir::Attribute dilation) {
        return dilation.cast<mlir::IntegerAttr>().getInt() != 1;
      });
  if (num_non_default_dilations > 0)
    return errors::InvalidArgument(
        "Only dilation rate 1 is supported for convolution with spatial "
        "partitions.");

  // TODO(b/208700444): support convolution with strides greater than 1.
  const int num_non_default_strides =
      llvm::count_if(conv_op.strides(), [](mlir::Attribute stride) {
        return stride.cast<mlir::IntegerAttr>().getInt() != 1;
      });
  if (num_non_default_strides > 0)
    return errors::InvalidArgument(
        "Only stride 1 is supported for convolution with spatial partitions.");

  mlir::Value input = conv_op.input();
  auto input_type = input.getType().dyn_cast<mlir::RankedTensorType>();
  if (!input_type || !input_type.hasStaticShape())
    return errors::InvalidArgument(
        "Input must have static shapes for convolution with spatial "
        "partitions.");

  mlir::Value filter = conv_op.filter();
  auto filter_type = filter.getType().dyn_cast<mlir::RankedTensorType>();
  if (!filter_type || !filter_type.hasStaticShape())
    return errors::InvalidArgument(
        "Filter must have static shapes for convolution with spatial "
        "partitions.");

  llvm::ArrayRef<int64_t> filter_shape = filter_type.getShape();
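  // The last two filter dimensions are the input/output channels; all leading
  // (spatial) dimensions must be odd so that the padding and halos applied
  // below are symmetric: a filter of size k contributes (k - 1) / 2 elements
  // on each side.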
  for (auto it = filter_shape.begin(); it != filter_shape.end() - 2; ++it) {
    if (*it % 2 != 1)
      return errors::InvalidArgument(
          "Filter dimensions must be odd numbers for convolution with "
          "spatial partitions.");
  }

  return OkStatus();
}

mlir::Value PadInputOnUnshardedDim(mlir::OpBuilder& builder,
                                   mlir::Location location,
                                   mlir::Value input_tensor, int curr_input_dim,
                                   int64_t curr_filter_dim_size) {
  auto input_tensor_type =
      input_tensor.getType().dyn_cast<mlir::RankedTensorType>();
  auto input_tensor_shape = input_tensor_type.getShape();

  const size_t paddings_flat_length = input_tensor_type.getRank() * 2;
  llvm::SmallVector<int64_t, 4> paddings_flat_vec(paddings_flat_length, 0);
  int64_t padding_size = curr_filter_dim_size - 1;
  paddings_flat_vec[2 * curr_input_dim] = padding_size / 2;
  paddings_flat_vec[2 * curr_input_dim + 1] = padding_size / 2;
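  // E.g., for a rank-4 input with curr_input_dim == 1 and a filter of size 3
  // on that dimension (padding_size == 2), the flat paddings are
  // {0, 0, 1, 1, 0, 0, 0, 0}, reshaped below into
  // [[0, 0], [1, 1], [0, 0], [0, 0]].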

  llvm::SmallVector<int64_t, 4> paddings_shape(input_tensor_shape.begin(),
                                               input_tensor_shape.end());
  paddings_shape[curr_input_dim] += padding_size;
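  // Note that paddings_shape is the shape of the *padded result* (used as the
  // result type of the PadOp below), not the shape of the paddings operand.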

  mlir::Value paddings_flat = Int64Const(builder, location, paddings_flat_vec);
  mlir::RankedTensorType paddings_type = mlir::RankedTensorType::get(
      paddings_shape, input_tensor_type.getElementType());
  mlir::Value paddings = builder.create<mlir::TF::ReshapeOp>(
      location, paddings_flat,
      Int64Const(builder, location, {input_tensor_type.getRank(), 2}));
  return builder.create<mlir::TF::PadOp>(location, paddings_type, input_tensor,
                                         paddings);
}

template <typename ConvOp>
StatusOr<mlir::Operation*> HandleConv(ConvOp conv_op) {
  mlir::OpBuilder builder(conv_op);
  TF_ASSIGN_OR_RETURN(const Layout input_layout,
                      ExtractRequiredLayoutFromOperand(conv_op.input()));
  TF_ASSIGN_OR_RETURN(const Layout filter_layout,
                      ExtractRequiredLayoutFromOperand(conv_op.filter()));
  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(conv_op));

  TF_RETURN_IF_ERROR(VerifyConvLayout(input_layout, filter_layout, conv_op));

  if (input_layout.IsBatchParallel())
    // No special handling needed for replicated case.
    return InferSPMDExpandedLocalShape(conv_op);

  mlir::tf_device::ClusterOp cluster =
      conv_op->template getParentOfType<mlir::tf_device::ClusterOp>();
  TF_ASSIGN_OR_RETURN(mlir::Value mesh_coordinates,
                      GetMeshCoordinatesFromCluster(cluster));
  const Mesh& mesh = input_layout.mesh();
  mlir::Location location = conv_op->getLoc();

  const std::vector<std::string> input_sharding_spec =
      input_layout.sharding_spec_strs();
  const std::vector<std::string> output_sharding_spec =
      output_layout.sharding_spec_strs();
  llvm::StringRef format = conv_op.data_format();
  llvm::StringRef padding = conv_op.padding();

  const auto input_num_shards = input_layout.num_shards();
  const auto output_num_shards = output_layout.num_shards();

  auto filter_type =
      conv_op.filter().getType().template dyn_cast<mlir::RankedTensorType>();
  auto filter_shape = filter_type.getShape();

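  // First and last spatial dimension indices of the input, determined by the
  // data format (e.g., "NHWC" has spatial dims 1..2, "NCDHW" has 2..4).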
  int begin_input_dim = -1, end_input_dim = -1;
  if (format == "NCHW") {
    begin_input_dim = 2;
    end_input_dim = 3;
  } else if (format == "NHWC") {
    begin_input_dim = 1;
    end_input_dim = 2;
  } else if (format == "NCDHW") {
    begin_input_dim = 2;
    end_input_dim = 4;
  } else if (format == "NDHWC") {
    begin_input_dim = 1;
    end_input_dim = 3;
  }

  // For non-batch, non-channel dimension sharding, conduct halo exchange.
  for (int curr_input_dim = begin_input_dim; curr_input_dim <= end_input_dim;
       ++curr_input_dim) {
    int curr_filter_dim = curr_input_dim - begin_input_dim;

    auto input_type =
        conv_op.input().getType().template dyn_cast<mlir::RankedTensorType>();
    auto input_shape = input_type.getShape();

    if (input_sharding_spec[curr_input_dim] == Layout::kUnshardedDim) {
      if (padding == "SAME") {
        // Since we always emit a Conv op with "VALID" padding, we need to
        // manually pad the input tensor.
        conv_op->setOperand(
            0, PadInputOnUnshardedDim(builder, location, conv_op.input(),
                                      curr_input_dim,
                                      filter_shape[curr_filter_dim]));
      }
      // No halo exchange is needed for unsharded dims.
      continue;
    }

    TF_ASSIGN_OR_RETURN(const int mesh_dim_index,
                        mesh.idx_for_dim(input_sharding_spec[curr_input_dim]));
    TF_ASSIGN_OR_RETURN(mlir::Value scalar_mesh_coordinate,
                        SelectScalarValueFromArray(builder, mesh_dim_index,
                                                   location, mesh_coordinates));

    int halo_size;
    if (padding == "SAME") {
      halo_size = std::floor(filter_shape[curr_filter_dim] / 2);
    } else if (padding == "VALID") {
      int input_local_size = input_shape[curr_input_dim];
      int input_size = input_local_size * input_num_shards[curr_input_dim];
      int output_size = input_size - (filter_shape[curr_filter_dim] - 1);
      int output_local_size = output_size / output_num_shards[curr_input_dim];
      halo_size = output_local_size + (filter_shape[curr_filter_dim] - 1) -
                  input_local_size;
    } else {
      return errors::Unimplemented(
          "Spatially partitioned convolution with padding \"", padding.str(),
          "\" is not supported.");
    }
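    // E.g., "SAME" with a filter of size 5 gives halo_size == 2. "VALID" with
    // a local input of 8 on 2 shards (global 16) and a filter of size 3 gives
    // a global output of 14, a local output of 7, and
    // halo_size == 7 + 2 - 8 == 1.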

    if (halo_size == 0)
      // No exchange is needed for empty halos.
      continue;

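    // Exchange halo_size elements with the neighboring shards along this mesh
    // dimension, so that each local shard also holds the boundary inputs its
    // output window needs.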
    builder.setInsertionPoint(conv_op);
    TF_ASSIGN_OR_RETURN(
        mlir::Value halo_exchanged_input,
        EmitHaloExchange(builder, halo_size,
                         input_sharding_spec[curr_input_dim], input_layout,
                         mesh_coordinates, cluster, location, conv_op.input()));

    if (padding == "SAME") {
      conv_op->setOperand(0, halo_exchanged_input);
    } else if (padding == "VALID") {
      // Slice the halo exchanged tensor to the desired size based on the index
      // of the shard on the current dimension.
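      // E.g., with 4 shards on this dimension and halo_size == 3, the
      // increment is 3 / (4 - 1) == 1, so shard i starts its slice at offset
      // 3 - i: shard 0 keeps the leading halo and shard 3 starts at 0.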

      llvm::SmallVector<int32_t, 4> halo_sizes(input_layout.rank(), 0);
      halo_sizes[curr_input_dim] = halo_size;
      mlir::Value halo_sizes_const = IntConst(builder, location, halo_sizes);

      llvm::SmallVector<int32_t, 4> halo_increments(input_layout.rank(), 0);
      halo_increments[curr_input_dim] =
          halo_size / (input_num_shards[curr_input_dim] - 1);
      mlir::Value halo_increments_const =
          IntConst(builder, location, halo_increments);

      mlir::Value offset = builder.create<mlir::TF::MulOp>(
          location, halo_increments_const.getType(), scalar_mesh_coordinate,
          halo_increments_const);
      mlir::Value slice_begin =
          builder.create<mlir::TF::SubOp>(location, halo_sizes_const, offset);

      llvm::SmallVector<int64_t, 4> slice_size(input_shape.begin(),
                                               input_shape.end());
      slice_size[curr_input_dim] += halo_size;
      mlir::Value slice_size_const = Int64Const(builder, location, slice_size);

      mlir::RankedTensorType sliced_input_type =
          mlir::RankedTensorType::get(slice_size, input_type.getElementType());
      mlir::Value sliced_input = builder.create<mlir::TF::SliceOp>(
          location, sliced_input_type, /*input=*/halo_exchanged_input,
          /*begin=*/slice_begin, /*size=*/slice_size_const);
      conv_op->setOperand(0, sliced_input);
    }

    // Spatially partitioned convolution always uses VALID padding after halo
    // exchange.
    conv_op.paddingAttr(builder.getStringAttr("VALID"));
  }

  return InferSPMDExpandedLocalShape(conv_op);
}

template <typename ConvBackpropInputOp>
StatusOr<mlir::Operation*> HandleConvBackpropInput(
    const Layout& output_layout, ConvBackpropInputOp conv_op) {
  llvm::SmallVector<int64_t, 4> global_shape;
  Status extract_status =
      ExtractConstVectorFromValue(conv_op.input_sizes(), &global_shape);

  // Recover local shape in SPMD expansion.
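  // E.g., a global input_sizes of [8, 32, 32, 3] with the batch dimension
  // sharded over 2 devices becomes a local shape of [4, 32, 32, 3].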
  if (extract_status.ok()) {
    auto local_shape = output_layout.LocalShapeFromGlobalShape(global_shape);
    mlir::OpBuilder builder(conv_op->getBlock(), conv_op->getBlock()->begin());
    auto new_const = IntConst(
        builder, conv_op->getLoc(),
        llvm::SmallVector<int32_t, 4>(local_shape.begin(), local_shape.end()));
    conv_op.input_sizesMutable().assign(new_const);
  }

  return InferSPMDExpandedLocalShape(conv_op);
}

template <typename ConvBackpropFilterOp>
StatusOr<mlir::Operation*> HandleConvBackpropFilter(
    const Layout& output_layout, ConvBackpropFilterOp conv_op) {
  TF_ASSIGN_OR_RETURN(Layout input_layout,
                      ExtractRequiredLayoutFromOperand(conv_op.input()));

  TF_ASSIGN_OR_RETURN(
      Layout out_backprop_layout,
      ExtractRequiredLayoutFromOperand(conv_op.out_backprop()));
  // Perform a split on the batch dimension so that each local device performs
  // a local operation.
  // TODO(hthu): Make this work on input with rank higher than 4.
  if (input_layout.IsBatchParallel()) {
    mlir::OpBuilder builder(conv_op);
    if (out_backprop_layout.IsFullyReplicated()) {
      TF_ASSIGN_OR_RETURN(const mlir::Value batch_sharded,
                          EmitAllScatter(builder, conv_op.out_backprop(),
                                         out_backprop_layout, input_layout));
      conv_op.out_backpropMutable().assign(batch_sharded);
    }

    // Perform all reduce over batch dim.
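    // Each device computed filter gradients from its local batch shard only;
    // summing them across the batch mesh dimension yields the same filter
    // gradient as the unsharded computation would.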
    builder.setInsertionPointAfter(conv_op);
    return DT_CTX(EmitAllReduce(builder, output_layout,
                                {input_layout.sharding_spec(0)}, conv_op,
                                kReduceOpAdd));
  } else {
    return errors::InvalidArgument(
        "Convolution backprop for spatially partitioned input not supported.");
  }
  return InferSPMDExpandedLocalShape(conv_op);
}

StatusOr<mlir::Operation*> HandleMaxPoolGradOp(
    const Layout& output_layout, mlir::TF::MaxPoolGradOp max_pool_grad_op) {
  // MaxPoolGrad has 3 inputs: the original input to MaxPool, the output of
  // MaxPool, and the gradients.
  assert(max_pool_grad_op->getOpOperands().size() == 3);

  // Relayout the gradient input to match the layout of the MaxPool output.
  mlir::OpOperand& max_pool_output = max_pool_grad_op->getOpOperand(1);
  TF_ASSIGN_OR_RETURN(Layout max_pool_output_layout,
                      ExtractRequiredLayoutFromOperand(max_pool_output.get()));

  mlir::OpOperand& grad_input = max_pool_grad_op->getOpOperand(2);
  TF_ASSIGN_OR_RETURN(Layout grad_input_layout,
                      ExtractRequiredLayoutFromOperand(grad_input.get()));
  TF_ASSIGN_OR_RETURN(mlir::Value new_grad_input,
                      EmitRelayout(grad_input.get(), grad_input_layout,
                                   max_pool_output_layout));
  grad_input.set(new_grad_input);

  return InferSPMDExpandedLocalShape(max_pool_grad_op);
}

}  // namespace

StatusOr<mlir::Operation*> ConvSPMDExpander::ExpandOp(mlir::Operation* op) {
  // The first argument to Conv2DBackpropInputOp is the shape of the input we
  // are generating. Since this is almost always the output of a call to
  // `shape`, we lose the ability to infer the original input layout (cf. if
  // Conv2DBackpropInput accepted the input _tensor_ instead of the shape).
  // Since in eager execution we cannot look ahead at consumer operations, we
  // instead attach the original input layout as a secondary attribute on the
  // output of the shape operation, and use this to infer the desired layout
  // for this op.

  TF_ASSIGN_OR_RETURN(const auto output_layout, ExtractSingleLayoutFromOp(op));

  // Forward prop ops.
  if (llvm::isa<mlir::TF::Conv2DOp>(op))
    return HandleConv<>(llvm::cast<mlir::TF::Conv2DOp>(op));
  if (llvm::isa<mlir::TF::Conv3DOp>(op))
    return HandleConv<>(llvm::cast<mlir::TF::Conv3DOp>(op));

  // Backward prop input ops.
  if (llvm::isa<mlir::TF::Conv2DBackpropInputOp>(op))
    return HandleConvBackpropInput<>(
        *output_layout, llvm::cast<mlir::TF::Conv2DBackpropInputOp>(op));
  if (llvm::isa<mlir::TF::Conv3DBackpropInputV2Op>(op))
    return HandleConvBackpropInput<>(
        *output_layout, llvm::cast<mlir::TF::Conv3DBackpropInputV2Op>(op));

  // Backward prop filter ops.
  if (llvm::isa<mlir::TF::Conv2DBackpropFilterOp>(op))
    return HandleConvBackpropFilter<>(
        *output_layout, llvm::cast<mlir::TF::Conv2DBackpropFilterOp>(op));
  if (llvm::isa<mlir::TF::Conv3DBackpropFilterV2Op>(op))
    return HandleConvBackpropFilter<>(
        *output_layout, llvm::cast<mlir::TF::Conv3DBackpropFilterV2Op>(op));

  // For all other ops, only batch sharded or fully replicated sharding is
  // supported for now.
  if (!output_layout->IsFullyReplicated() && !output_layout->IsBatchParallel())
    return errors::Unimplemented(
        llvm::formatv(
            "Only replicated or batch parallel layout is supported in "
            "expansion of {0}, but got output layout: {1}",
            op->getName().getStringRef().str(), output_layout->ToString())
            .str());

  if (auto max_pool_grad = mlir::dyn_cast<mlir::TF::MaxPoolGradOp>(op))
    return HandleMaxPoolGradOp(*output_layout, max_pool_grad);

  // Local expansion only for all other ops.
  return InferSPMDExpandedLocalShape(op);
}

StatusOr<llvm::DenseMap<int, Layout>> ConvSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  llvm::DenseMap<int, Layout> output_layouts(op->getNumResults());

  if (llvm::isa<mlir::TF::Conv2DOp, mlir::TF::Conv3DOp, mlir::TF::MaxPoolOp,
                mlir::TF::MaxPoolGradOp>(op)) {
    // Conv2D/Conv3D and MaxPool ops are grouped together as they all try to
    // propagate the layout from the input image (operand 0).

    // If a layout was requested for 'input', request the same layout for the
    // output.
    if (input_layouts.find(0) != input_layouts.end()) {
      output_layouts[0] = input_layouts.lookup(0);
    } else {
      // For MaxPoolGrad, request the same layout as 'orig_output' or 'grad',
      // whichever is present.
      if (llvm::isa<mlir::TF::MaxPoolGradOp>(op)) {
        if (input_layouts.find(1) != input_layouts.end())
          output_layouts[0] = input_layouts.lookup(1);
        else if (input_layouts.find(2) != input_layouts.end())
          output_layouts[0] = input_layouts.lookup(2);
      }
    }
  } else if (llvm::isa<mlir::TF::Conv2DBackpropInputOp,
                       mlir::TF::Conv3DBackpropInputV2Op,
                       mlir::TF::Conv2DBackpropFilterOp,
                       mlir::TF::Conv3DBackpropFilterV2Op>(op)) {
    // Conv backprop ops usually take the layout from the gradient, for both
    // inputs and filters.

    // 'grad' layout
    if (input_layouts.find(2) != input_layouts.end()) {
      if (llvm::isa<mlir::TF::Conv2DBackpropInputOp,
                    mlir::TF::Conv3DBackpropInputV2Op>(op)) {
        // Backprop ops try to respect the layout from gradients for inputs.
        output_layouts[0] = input_layouts.lookup(2);
      }

      // For filters, we currently only request a replicated output layout.
      if (llvm::isa<mlir::TF::Conv2DBackpropFilterOp,
                    mlir::TF::Conv3DBackpropFilterV2Op>(op)) {
        output_layouts[0] =
            Layout::ReplicatedOnMesh(mesh, ValueRank(op->getOpResult(0)));
      }
    }
  } else {
    return errors::InvalidArgument(
        llvm::formatv(
            "Layout propagation for unrecognized convolution op {0} not "
            "supported.",
            OpName(op))
            .str());
  }

  return output_layouts;
}

StatusOr<llvm::DenseMap<int, Layout>> ConvSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());

  if (llvm::isa<mlir::TF::Conv2DOp, mlir::TF::Conv3DOp, mlir::TF::MaxPoolOp,
                mlir::TF::MaxPoolGradOp>(op)) {
    // If a layout is suggested for the output, request the same layout for the
    // input image so that all computation stays local.
    if (output_layouts.find(0) != output_layouts.end()) {
      const Layout output_layout = output_layouts.lookup(0);

      input_layouts[0] = output_layout;

      // Request a replicated filter input for Conv2D/Conv3D.
      if (llvm::isa<mlir::TF::Conv2DOp, mlir::TF::Conv3DOp>(op)) {
        input_layouts[1] =
            Layout::ReplicatedOnMesh(mesh, ValueRank(op->getOperand(1)));
      }
      if (llvm::isa<mlir::TF::MaxPoolGradOp>(op)) {
        input_layouts[1] = output_layout;  // 'orig_output'
        input_layouts[2] = output_layout;  // 'grad'
      }
    }
  } else if (llvm::isa<mlir::TF::Conv2DBackpropInputOp,
                       mlir::TF::Conv3DBackpropInputV2Op,
                       mlir::TF::Conv2DBackpropFilterOp,
                       mlir::TF::Conv3DBackpropFilterV2Op>(op)) {
    // If a layout is suggested for the output, request the same layout for
    // 'grad'.
    if (output_layouts.find(0) != output_layouts.end()) {
      input_layouts[2] = output_layouts.lookup(0);
      // Request the remaining operands (input/input_sizes and
      // filter/filter_sizes) to be replicated.
      input_layouts[0] =
          Layout::ReplicatedOnMesh(mesh, ValueRank(op->getOperand(0)));
      input_layouts[1] =
          Layout::ReplicatedOnMesh(mesh, ValueRank(op->getOperand(1)));
    }
  } else {
    return errors::InvalidArgument(
        llvm::formatv(
            "Layout propagation for unrecognized convolution op {0} not "
            "supported.",
            OpName(op))
            .str());
  }

  return input_layouts;
}

}  // namespace dtensor
}  // namespace tensorflow