/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Validates send/recv layout and mesh configurations. Among other things, this
// checks the constraints below.
// 1. Src/target layouts have a non-empty mesh.
// 2. Src/target layouts have the same hosts.
// 3. Src/target layouts are from different meshes.
// 4. One of the src/target layouts is from the host mesh cluster.
// 5. The CPU host cluster mesh has 1 device.
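// For example (illustrative only): sending a tensor that lives on a TPU mesh
// to that TPU mesh's matching CPU host mesh is a valid configuration, whereas
// TPU mesh -> TPU mesh or TPU mesh -> GPU mesh transfers are rejected below.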
Status ValidateSendRecvLayoutConfiguration(mlir::TF::DTensorSend dtensor_send,
                                           mlir::TF::DTensorRecv dtensor_recv) {
  // If either one of the send/recv ops has already been lowered, then the
  // send/recv configuration has already been verified.
  if (!dtensor_send || !dtensor_recv) return OkStatus();

  TF_ASSIGN_OR_RETURN(const absl::optional<Layout> send_layout_or_null,
                      ExtractLayoutFromOperand(dtensor_send.input()));

  if (!send_layout_or_null.has_value())
    return errors::InvalidArgument(
        "Input to DTensorSend must have specified layout.");

  const Layout& send_layout = send_layout_or_null.value();
  const Layout recv_layout = dtensor_recv.layout();

  const Mesh& send_mesh = send_layout.mesh();
  const Mesh& recv_mesh = recv_layout.mesh();

  // If either of the send/recv meshes is empty, return an error.
  if (send_mesh.IsEmpty() || recv_mesh.IsEmpty())
    return errors::InvalidArgument(
        "Found empty mesh when sending/receiving tensor across clusters.");

  // If the sending hosts do not match the receiving hosts, return an error.
  std::vector<std::string> send_hosts = send_layout.ReducedMesh().hosts();
  std::vector<std::string> recv_hosts = recv_layout.ReducedMesh().hosts();
  if (send_hosts != recv_hosts)
    return errors::InvalidArgument("Send and receive hosts don't match");

  // Check that shards in each sending host match those in the receiving host.
  const auto send_host_shard_map = send_layout.HostShardMap();
  const auto recv_host_shard_map = recv_layout.HostShardMap();
  for (const std::string& host : send_hosts) {
    const ShardVector& shards_in_send_host =
        send_host_shard_map.find(host)->second;
    ShardVector shards_in_recv_host = recv_host_shard_map.find(host)->second;
    if (shards_in_send_host != shards_in_recv_host)
      return errors::InvalidArgument(
          "Send and receive host shard vectors don't match. Send shard_vector:",
          shards_in_send_host.ToString(),
          " / Recv host spec : ", shards_in_recv_host.ToString());
  }

  // Send/Recv meshes must be different.
  if (recv_mesh == send_mesh)
    return errors::InvalidArgument(
        "Found CopyToMesh op sending tensor to same mesh. Only use "
        "CopyToMesh to transfer data across different mesh cluster. For "
        "changing layout within the same mesh, use tf.Relayout op.");

  // One of the send/recv pair must be to/from a CPU mesh.
  // For example, TPU mesh -> GPU mesh or TPU mesh -> another TPU mesh
  // is disallowed.
  if (!send_mesh.is_cpu_mesh() && !recv_mesh.is_cpu_mesh())
    return errors::InvalidArgument(
        "tf.CopyToMesh op must be used to send data from/to host mesh.");

  return OkStatus();
}

// Returns whether to lower DTensorSend/DTensorRecv op to XLA backend ops.
// XLA backend ops are used when either the sending or receiving device uses
// the XLA compiler.
bool SendRecvOpUsesXla(const Mesh& send_mesh, const Mesh& recv_mesh) {
  assert(!(send_mesh.is_tpu_mesh() && recv_mesh.is_tpu_mesh()));
  return (send_mesh.is_tpu_mesh() || recv_mesh.is_tpu_mesh());
}
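// For instance, a TPU mesh <-> CPU host mesh transfer takes the XLA send/recv
// lowering path in the expanders below, while a CPU mesh <-> CPU mesh transfer
// is lowered to TF host send/recv ops instead.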

// Takes a relayout mask layout which may have kMatch dimensions and merges it
// with the target layout: each kMatch dimension adopts the corresponding
// sharding spec from target_layout, unless that mesh dimension is already used
// by the mask.
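// Example (hypothetical sharding specs): with used_mesh_dimensions = {"y"},
// mask_layout = ["match", "y"], and target_layout = ["x", "unsharded"], the
// result is ["x", "y"]: the first dimension adopts "x" since it is unused,
// while the second dimension keeps "y" from the mask.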
StatusOr<Layout> MergeLayouts(
    const absl::flat_hash_set<std::string>& used_mesh_dimensions,
    const Layout& mask_layout, const Layout& target_layout) {
  std::vector<std::string> sharding_specs(mask_layout.sharding_spec_strs());
  for (int i = 0; i < target_layout.rank(); ++i) {
    if (sharding_specs[i] == Layout::kMatch &&
        !used_mesh_dimensions.contains(target_layout.sharding_spec(i)))
      sharding_specs[i] = target_layout.sharding_spec(i);
  }
  return Layout::GetLayout(sharding_specs, target_layout.mesh());
}

// Given the layouts on one side of a Relayout op, computes the layouts for the
// other side. Note that this implies that we compute the same layout for the
// operand and the output.
StatusOr<llvm::DenseMap<int, Layout>> ComputeRelayoutLayout(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& layouts) {
  mlir::TF::RelayoutOp relayout = llvm::cast<mlir::TF::RelayoutOp>(op);
  mlir::StringRef layout_attr = relayout.layout();
  TF_ASSIGN_OR_RETURN(const Layout mask_layout,
                      Layout::FromString(layout_attr.str()));

  absl::flat_hash_set<std::string> used_dimensions;
  bool match_present = false;
  for (const std::string& sharding_spec : mask_layout.sharding_spec_strs()) {
    if (sharding_spec == Layout::kMatch)
      match_present = true;
    else if (Layout::IsShardedDimension(sharding_spec))
      used_dimensions.insert(sharding_spec);
  }
  if (!match_present) {
    return llvm::DenseMap<int, Layout>({{0, mask_layout}});
  }

  if (layouts.find(0) != layouts.end()) {
    TF_ASSIGN_OR_RETURN(
        Layout new_layout,
        MergeLayouts(used_dimensions, mask_layout, layouts.lookup(0)));
    return llvm::DenseMap<int, Layout>({{0, new_layout}});
  }
  return llvm::DenseMap<int, Layout>();
}
}  // namespace

StatusOr<mlir::Operation*> RelayoutSPMDExpander::ExpandOp(mlir::Operation* op) {
  mlir::TF::RelayoutOp relayout = mlir::cast<mlir::TF::RelayoutOp>(op);
  mlir::StringRef layout_attr = relayout.layout();
  TF_ASSIGN_OR_RETURN(const Layout target_layout,
                      Layout::FromString(layout_attr.str()));
  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(op));
  TF_ASSIGN_OR_RETURN(const Layout input_layout,
                      ExtractRequiredLayoutFromOperand(relayout.input()));
  bool match_present = false;
  for (const std::string& sharding_spec : target_layout.sharding_spec_strs())
    if (sharding_spec == Layout::kMatch) match_present = true;

  if (!match_present && output_layout != target_layout)
    return errors::Internal(
        "output layout of Relayout op after layout propagation does not match "
        "layout specified by Relayout op.");

  if (input_layout == output_layout) {
    // The input of RelayoutOp must be the output value of a DTensorLayout
    // operation, as layout propagation adds a DTensorLayout op for each tensor
    // value. Replace with an identity op.
    mlir::OpBuilder builder(relayout);
    mlir::TF::IdentityOp op = builder.create<mlir::TF::IdentityOp>(
        relayout.getLoc(), relayout.input().getType(), relayout.input());
    relayout.output().replaceAllUsesWith(op.output());
    relayout.erase();
    return op.getOperation();
  }

  auto value_or_status =
      EmitRelayout(relayout.input(), input_layout, output_layout);
  if (!value_or_status.ok())
    return errors::InvalidArgument(
        llvm::formatv("Unsupported layout received for tf.Relayout op. Trying "
                      "to set tensor "
                      "to layout : {0}. Found error {1}",
                      layout_attr.str(),
                      value_or_status.status().error_message())
            .str());
  mlir::Value output = value_or_status.ValueOrDie();
  relayout.output().replaceAllUsesWith(output);
  relayout.erase();
  return output.getDefiningOp();
}

StatusOr<llvm::DenseMap<int, Layout>>
RelayoutSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  return ComputeRelayoutLayout(op, input_layouts);
}

StatusOr<llvm::DenseMap<int, Layout>>
RelayoutSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeRelayoutLayout(op, output_layouts);
}

namespace {

// Returns whether the send/recv layouts represent a send/recv of a tensor
// between the i-th TPU device and the i-th device of the host mesh. The host
// mesh consists of the CPU devices that are 1-to-1 mapped with the TPU mesh
// devices, having the same global and local device IDs.
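// For instance (illustrative only): a TPU mesh and its matching CPU host mesh
// with identical local device IDs and identical shard placements qualify, so
// each device can exchange data with its paired host device without any
// cross-host hop.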
bool IsOneToOneHostMeshTransfer(const Layout& send_layout,
                                const Layout& recv_layout) {
  const Mesh& send_mesh = send_layout.mesh();
  const Mesh& recv_mesh = recv_layout.mesh();

  // Check tensor is being transferred between CPU <-> TPU.
  if (!(send_mesh.is_tpu_mesh() && recv_mesh.is_cpu_mesh()) &&
      !(recv_mesh.is_tpu_mesh() && send_mesh.is_cpu_mesh()))
    return false;

  // Check tensor transfer is happening between TPU and its host mesh.
  if (!((send_mesh.is_tpu_mesh() &&
         send_mesh.tpu_host_mesh() == recv_mesh.ToString()) ||
        (recv_mesh.is_tpu_mesh() &&
         recv_mesh.tpu_host_mesh() == send_mesh.ToString())))
    return false;

  // Check local device IDs are fully matching so that there is no cross-host
  // transfer.
  if (send_mesh.local_device_ids() != recv_mesh.local_device_ids())
    return false;

  return send_layout.GetShardVector() == recv_layout.GetShardVector();
}

}  // namespace

StatusOr<mlir::Operation*> DTensorSendSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
  auto dtensor_send = llvm::cast<mlir::TF::DTensorSend>(op);

  TF_ASSIGN_OR_RETURN(mlir::Operation * recv_op,
                      GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorSend>(
                          module, dtensor_send));
  auto dtensor_recv = llvm::dyn_cast<mlir::TF::DTensorRecv>(recv_op);

  TF_RETURN_IF_ERROR(
      ValidateSendRecvLayoutConfiguration(dtensor_send, dtensor_recv));

  TF_ASSIGN_OR_RETURN(const Layout input_layout,
                      ExtractRequiredLayoutFromOperand(dtensor_send.input()));

  // If the tensor transfer is from a TPU mesh to its host mesh and the send
  // and recv layouts are identical, then the tensor from each source device is
  // sent to its target device asynchronously.
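  // In that case the op can be lowered directly to an XLA send, without first
  // gathering the tensor or gating on device ordinal 0 as done further below.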
  if (IsOneToOneHostMeshTransfer(input_layout, dtensor_send.target_layout())) {
    return LowerDTensorSendToXlaOp(input_layout, dtensor_send.input(),
                                   dtensor_send,
                                   /*send_from_device_zero=*/false);
  }

  // Calculate the input tensor layout of the data to send and the target fully
  // replicated layout. For now, we ensure that all data transfers happen with
  // fully replicated tensors.
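  // For example (illustrative only): an input sharded over mesh dimension "x"
  // is first all-gathered into a fully replicated tensor before being sent
  // across meshes.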
  const int rank = ValueRank(dtensor_send.input());
  const Layout target_layout =
      Layout::ReplicatedOnMesh(input_layout.mesh(), rank);

  // Convert tensor to send to replicated layout.
  mlir::OpBuilder builder(dtensor_send);
  TF_ASSIGN_OR_RETURN(mlir::Value send_input,
                      EmitAllGather(builder, dtensor_send.input(), input_layout,
                                    target_layout));

  // Insert control flow such that only the device with device ordinal == 0
  // sends the tensor data across the mesh.
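  // The generated IR has roughly the following shape (sketch, not verbatim):
  //   %pred = tf.Equal(%device_ordinal, 0)
  //   tf.IfRegion(%pred) {
  //     then: DTensorSend(%send_input); tf.Yield
  //     else: tf.Yield
  //   }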
  auto send_cluster =
      dtensor_send->getParentOfType<mlir::tf_device::ClusterOp>();
  TF_ASSIGN_OR_RETURN(absl::optional<Mesh> mesh,
                      ExtractDeviceMeshFromOp(send_cluster));
  if (!mesh.has_value())
    return errors::InvalidArgument(
        "failed to lower DTensor CopyToMesh op as sending side mesh is not "
        "specified.");

  mlir::Location loc = dtensor_send.getLoc();
  TF_ASSIGN_OR_RETURN(
      mlir::Value device_ordinal,
      GetDeviceOrdinal(*mesh, loc,
                       send_cluster->getParentOfType<mlir::func::FuncOp>(),
                       &builder));
  mlir::Value predicate = builder.create<mlir::TF::EqualOp>(
      loc, device_ordinal, CreateIntScalarConst(0, builder, loc),
      /*incompatible_shape_error=*/builder.getBoolAttr(true));

  auto send_if = builder.create<mlir::TF::IfRegionOp>(
      loc, llvm::SmallVector<mlir::Type, 4>{}, predicate,
      /*is_stateless=*/builder.getBoolAttr(true),
      GetUniqueControlflowFnName("copy_to_mesh_send_if_then", builder),
      GetUniqueControlflowFnName("copy_to_mesh_send_if_else", builder));

  // Create empty else branch region.
  auto& else_branch = send_if.else_branch();
  else_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<mlir::TF::YieldOp>(loc,
                                    /*operands=*/llvm::ArrayRef<mlir::Value>{});

  // Create then branch region with DTensorSend op.
  auto& then_branch = send_if.then_branch();
  then_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  auto yield = builder.create<mlir::TF::YieldOp>(
      loc, /*operands=*/llvm::ArrayRef<mlir::Value>{});
  dtensor_send->moveBefore(yield);

  // Lower DTensorSend op to actual TF op.
  TF_ASSIGN_OR_RETURN(const Mesh recv_mesh,
                      ExtractDeviceMeshEnclosingCluster(recv_op));
  mlir::Operation* lowered_send;
  if (SendRecvOpUsesXla(input_layout.mesh(), recv_mesh)) {
    // Lower DTensorSend op to Xla Send ops.
    TF_ASSIGN_OR_RETURN(
        lowered_send,
        LowerDTensorSendToXlaOp(input_layout, send_input, dtensor_send,
                                /*send_from_device_zero=*/true));
  } else if (input_layout.mesh().is_cpu_mesh() &&
             target_layout.mesh().is_cpu_mesh()) {
    // Lower DTensorSend op to TF Host Send op.
    TF_ASSIGN_OR_RETURN(
        lowered_send,
        LowerDTensorSendFromCPUToTFOp(input_layout, send_input, dtensor_send));
  } else {
    // TODO(hongjunchoi): Implement SPMD transformation lowering that lowers
    // DTensorSend to vanilla TF Send op.
    return errors::Unimplemented(
        "CopyToMesh between CPU/GPU not implemented yet.");
  }

  return lowered_send;
}

// DTensorSend op respects the input layout from its input operation and does
// not set any preferred input layouts. During SPMD expansion, however, tensor
// values are changed to a replicated layout before transferring data across
// mesh clusters.
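// As a result, both layout propagation hooks below intentionally return empty
// layout maps.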
StatusOr<llvm::DenseMap<int, Layout>>
DTensorSendSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  return llvm::DenseMap<int, Layout>();
}

StatusOr<llvm::DenseMap<int, Layout>>
DTensorSendSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return llvm::DenseMap<int, Layout>();
}

StatusOr<mlir::Operation*> DTensorRecvSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
  auto dtensor_recv = llvm::cast<mlir::TF::DTensorRecv>(op);

  TF_ASSIGN_OR_RETURN(mlir::Operation * send_op,
                      GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorRecv>(
                          module, dtensor_recv));
  auto dtensor_send = llvm::dyn_cast<mlir::TF::DTensorSend>(send_op);

  TF_RETURN_IF_ERROR(
      ValidateSendRecvLayoutConfiguration(dtensor_send, dtensor_recv));

  TF_ASSIGN_OR_RETURN(const Layout send_layout,
                      ExtractRequiredLayoutFromOperand(send_op->getOperand(0)));

  TF_ASSIGN_OR_RETURN(const Mesh send_mesh,
                      ExtractDeviceMeshEnclosingCluster(send_op));

  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(op));

  mlir::Operation* lowered_recv;
  const Layout recv_layout = dtensor_recv.layout();
  const Mesh& recv_mesh = recv_layout.mesh();
  mlir::OpBuilder builder(dtensor_recv);

  if (SendRecvOpUsesXla(send_mesh, recv_mesh)) {
    if (recv_mesh.is_cpu_mesh() ||
        IsOneToOneHostMeshTransfer(send_layout, recv_layout)) {
      // Recv can be lowered directly for a 1-to-1 transfer between host and
      // device.
      TF_ASSIGN_OR_RETURN(mlir::TensorType local_output_type,
                          LocalTypeFromGlobalType(
                              dtensor_recv.layout(),
                              dtensor_recv.getType().cast<mlir::TensorType>()));
      TF_ASSIGN_OR_RETURN(lowered_recv, LowerDTensorRecvToXlaOp(
                                            dtensor_recv, local_output_type));
      dtensor_recv->replaceAllUsesWith(lowered_recv);
      dtensor_recv.erase();
    } else {
      // For other send/recv layouts, the tensor needs to be replicated.
      if (!dtensor_recv.layout().IsFullyReplicated()) {
        return errors::InvalidArgument(
            "CopyToMesh where target mesh is TPU requires target layout to be "
            "replicated.");
      }

      // When receiving at TPU, only the device with device ordinal 0 receives
      // the tensor.
      auto recv_cluster =
          dtensor_recv->getParentOfType<mlir::tf_device::ClusterOp>();
      mlir::Location loc = dtensor_recv.getLoc();
      TF_ASSIGN_OR_RETURN(
          mlir::Value device_ordinal,
          GetDeviceOrdinal(recv_mesh, loc,
                           recv_cluster->getParentOfType<mlir::func::FuncOp>(),
                           &builder));
      mlir::Value predicate = builder.create<mlir::TF::EqualOp>(
          loc, device_ordinal, CreateIntScalarConst(0, builder, loc),
          /*incompatible_shape_error=*/builder.getBoolAttr(true));

      auto recv_if = builder.create<mlir::TF::IfRegionOp>(
          loc, llvm::SmallVector<mlir::Type, 4>{dtensor_recv.getType()},
          predicate,
          /*is_stateless=*/builder.getBoolAttr(true),
          GetUniqueControlflowFnName("copy_to_mesh_recv_if_then", builder),
          GetUniqueControlflowFnName("copy_to_mesh_recv_if_else", builder));

      // Create empty else branch region that outputs zeros.
      auto& else_branch = recv_if.else_branch();
      else_branch.push_back(new mlir::Block);
      builder.setInsertionPointToEnd(&else_branch.front());

      // Create a zero constant.
      mlir::Attribute const_attr;
      if (dtensor_recv.getType().getElementType().isIntOrIndex()) {
        const_attr = mlir::DenseIntElementsAttr::get(
            dtensor_recv.getType(), llvm::SmallVector<int32_t>{0});
      } else {
        const_attr = mlir::DenseFPElementsAttr::get(
            dtensor_recv.getType(), llvm::SmallVector<float>{0.0});
      }

      mlir::Value zeros = builder.create<mlir::TF::ConstOp>(loc, const_attr);
      builder.create<mlir::TF::YieldOp>(
          loc, /*operands=*/llvm::ArrayRef<mlir::Value>{zeros});

      // Create then branch region with DTensorRecv op.
      auto& then_branch = recv_if.then_branch();
      then_branch.push_back(new mlir::Block);
      builder.setInsertionPointToEnd(&then_branch.front());
      dtensor_recv->moveBefore(&then_branch.front(), then_branch.front().end());

      TF_ASSIGN_OR_RETURN(mlir::Operation * xla_recv,
                          LowerDTensorRecvToXlaOp(dtensor_recv));
      builder.create<mlir::TF::YieldOp>(
          loc,
          /*operands=*/llvm::ArrayRef<mlir::Value>{xla_recv->getResult(0)});

      // Broadcast the received output to all TPU cores.
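      // Device 0 contributes the received value while every other device
      // contributes zeros, so an all-reduce (add) over every mesh dimension
      // leaves each core with a copy of the received tensor.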
      mlir::Value if_output = recv_if->getResult(0);
      builder.setInsertionPointAfterValue(if_output);
      absl::flat_hash_set<std::string> reduced_dims;
      for (const auto& mesh_dim : recv_mesh.dims())
        reduced_dims.insert(mesh_dim.name);

      TF_ASSIGN_OR_RETURN(lowered_recv,
                          EmitAllReduce(builder, recv_layout, reduced_dims,
                                        recv_if, kReduceOpAdd));

      // Replaces usages of DTensorRecv op with the broadcasted value.
      dtensor_recv.output().replaceUsesWithIf(
          lowered_recv->getResult(0), [&](mlir::OpOperand& operand) {
            return !recv_if->isProperAncestor(operand.getOwner());
          });
      dtensor_recv.erase();
    }
  } else if (dtensor_recv.layout().mesh().is_cpu_mesh() &&
             send_mesh.is_cpu_mesh()) {
    // Lower DTensorRecv op to TF Host Recv op.
    TF_ASSIGN_OR_RETURN(lowered_recv,
                        LowerDTensorRecvFromCPUToTFOp(send_mesh, dtensor_recv));
  } else {
    // TODO(hongjunchoi): Implement SPMD transformation lowering that lowers
    // DTensorRecv to vanilla TF Recv op.
    return errors::Unimplemented(
        "CopyToMesh between CPU/GPU not implemented yet.");
  }

  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
  builder.setInsertionPointAfter(lowered_recv);
  TF_ASSIGN_OR_RETURN(
      mlir::Value recv_output,
      EmitAllScatter(builder, lowered_recv->getResult(0), recv_layout,
                     output_layout, &newly_created_ops));
  lowered_recv->getResult(0).replaceAllUsesExcept(recv_output,
                                                  newly_created_ops);
  return recv_output.getDefiningOp();
}

// DTensorRecv always returns tensors with the layout attached to the op.
StatusOr<llvm::DenseMap<int, Layout>>
DTensorRecvSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  mlir::TF::DTensorRecv dtensor_recv =
      mlir::dyn_cast<mlir::TF::DTensorRecv>(op);
  if (!dtensor_recv) {
    return errors::InvalidArgument(
        llvm::formatv("Expecting DTensorRecvOp but got {0}", OpName(op)).str());
  }
  return llvm::DenseMap<int, Layout>({{0, dtensor_recv.layout()}});
}

StatusOr<llvm::DenseMap<int, Layout>>
DTensorRecvSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return llvm::DenseMap<int, Layout>();
}

}  // namespace dtensor
}  // namespace tensorflow