/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Validates send/recv layout and mesh configurations. Among other things, this
// checks the following constraints:
// 1. Src/target layouts have non-empty meshes.
// 2. Src/target layouts have the same hosts.
// 3. Src/target layouts are from different meshes.
// 4. One of the src/target layouts is from the host mesh cluster.
// 5. The CPU host cluster mesh has 1 device.
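// For example (illustrative only): sending a tensor laid out on a TPU mesh to
// the matching layout on that TPU mesh's CPU host mesh satisfies these
// constraints, whereas a TPU mesh -> TPU mesh transfer or a transfer within
// the same mesh is rejected.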
Status ValidateSendRecvLayoutConfiguration(mlir::TF::DTensorSend dtensor_send,
                                           mlir::TF::DTensorRecv dtensor_recv) {
  // If either one of the send/recv ops has already been lowered, then the
  // send/recv configuration has already been verified.
  if (!dtensor_send || !dtensor_recv) return OkStatus();

  TF_ASSIGN_OR_RETURN(const absl::optional<Layout> send_layout_or_null,
                      ExtractLayoutFromOperand(dtensor_send.input()));

  if (!send_layout_or_null.has_value())
    return errors::InvalidArgument(
        "Input to DTensorSend must have specified layout.");

  const Layout& send_layout = send_layout_or_null.value();
  const Layout recv_layout = dtensor_recv.layout();

  const Mesh& send_mesh = send_layout.mesh();
  const Mesh& recv_mesh = recv_layout.mesh();

  // If either of the send/recv meshes is empty, return an error.
  if (send_mesh.IsEmpty() || recv_mesh.IsEmpty())
    return errors::InvalidArgument(
        "Found empty mesh when sending/receiving tensor across clusters.");

  // If the send hosts don't match the receive hosts, return an error.
  std::vector<std::string> send_hosts = send_layout.ReducedMesh().hosts();
  std::vector<std::string> recv_hosts = recv_layout.ReducedMesh().hosts();
  if (send_hosts != recv_hosts)
    return errors::InvalidArgument("Send and receive hosts don't match");

  // Check that the shards in each sending host match those in the receiving
  // host.
  const auto send_host_shard_map = send_layout.HostShardMap();
  const auto recv_host_shard_map = recv_layout.HostShardMap();
  for (const std::string& host : send_hosts) {
    const ShardVector& shards_in_send_host =
        send_host_shard_map.find(host)->second;
    ShardVector shards_in_recv_host = recv_host_shard_map.find(host)->second;
    if (shards_in_send_host != shards_in_recv_host)
      return errors::InvalidArgument(
          "Send and receive host shard vectors don't match. Send shard_vector:",
          shards_in_send_host.ToString(),
          " / Recv host spec : ", shards_in_recv_host.ToString());
  }

  // Send/Recv meshes must be different.
  if (recv_mesh == send_mesh)
    return errors::InvalidArgument(
        "Found CopyToMesh op sending tensor to same mesh. Only use "
        "CopyToMesh to transfer data across different mesh cluster. For "
        "changing layout within the same mesh, use tf.Relayout op.");

  // One of the send/recv pair must be to/from a CPU mesh.
  // For example, TPU mesh -> GPU mesh or TPU mesh -> another TPU mesh
  // is disallowed.
  if (!send_mesh.is_cpu_mesh() && !recv_mesh.is_cpu_mesh())
    return errors::InvalidArgument(
        "tf.CopyToMesh op must be used to send data from/to host mesh.");

  return OkStatus();
}

// Returns whether to lower DTensorSend/DTensorRecv ops to XLA backend ops.
// XLA backend ops are used when either the sending or the receiving device
// uses the XLA compiler.
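// For example (illustrative): a TPU mesh <-> CPU host mesh transfer uses XLA
// send/recv, while a CPU mesh <-> CPU mesh transfer does not. A TPU mesh <->
// TPU mesh transfer is disallowed (see the assert below).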
bool SendRecvOpUsesXla(const Mesh& send_mesh, const Mesh& recv_mesh) {
  assert(!(send_mesh.is_tpu_mesh() && recv_mesh.is_tpu_mesh()));
  return (send_mesh.is_tpu_mesh() || recv_mesh.is_tpu_mesh());
}

// Takes a Relayout mask layout, which may contain kMatch dimensions, and uses
// `target_layout` to fill them in: each kMatch dimension is replaced by the
// corresponding sharding spec of `target_layout`, unless that mesh dimension
// already appears in `used_mesh_dimensions`.
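// Example (illustrative): with mask layout [kMatch, "x"], target layout
// ["y", "unsharded"] and used_mesh_dimensions = {"x"}, the merged layout is
// ["y", "x"] on the target layout's mesh.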
StatusOr<Layout> MergeLayouts(
    const absl::flat_hash_set<std::string>& used_mesh_dimensions,
    const Layout& mask_layout, const Layout& target_layout) {
  std::vector<std::string> sharding_specs(mask_layout.sharding_spec_strs());
  for (int i = 0; i < target_layout.rank(); ++i) {
    if (sharding_specs[i] == Layout::kMatch &&
        !used_mesh_dimensions.contains(target_layout.sharding_spec(i)))
      sharding_specs[i] = target_layout.sharding_spec(i);
  }
  return Layout::GetLayout(sharding_specs, target_layout.mesh());
}

// Given the layouts on one side of a Relayout op, computes the layouts for the
// other side. Note that this implies we compute the same layout for the
// operand and the output.
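// Example (illustrative): if the Relayout op carries the layout attribute
// [kMatch, "x"] and the known operand (or output) layout is
// ["y", "unsharded"], the computed layout for the other side is ["y", "x"].
// If the attribute contains no kMatch entries, the attribute layout itself is
// returned unchanged.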
StatusOr<llvm::DenseMap<int, Layout>> ComputeRelayoutLayout(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& layouts) {
  mlir::TF::RelayoutOp relayout = llvm::cast<mlir::TF::RelayoutOp>(op);
  mlir::StringRef layout_attr = relayout.layout();
  TF_ASSIGN_OR_RETURN(const Layout mask_layout,
                      Layout::FromString(layout_attr.str()));

  absl::flat_hash_set<std::string> used_dimensions;
  bool match_present = false;
  for (const std::string& sharding_spec : mask_layout.sharding_spec_strs()) {
    if (sharding_spec == Layout::kMatch)
      match_present = true;
    else if (Layout::IsShardedDimension(sharding_spec))
      used_dimensions.insert(sharding_spec);
  }
  if (!match_present) {
    return llvm::DenseMap<int, Layout>({{0, mask_layout}});
  }

  if (layouts.find(0) != layouts.end()) {
    TF_ASSIGN_OR_RETURN(
        Layout new_layout,
        MergeLayouts(used_dimensions, mask_layout, layouts.lookup(0)));
    return llvm::DenseMap<int, Layout>({{0, new_layout}});
  }
  return llvm::DenseMap<int, Layout>();
}
}  // namespace

StatusOr<mlir::Operation*> RelayoutSPMDExpander::ExpandOp(mlir::Operation* op) {
  mlir::TF::RelayoutOp relayout = mlir::cast<mlir::TF::RelayoutOp>(op);
  mlir::StringRef layout_attr = relayout.layout();
  TF_ASSIGN_OR_RETURN(const Layout target_layout,
                      Layout::FromString(layout_attr.str()));
  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(op));
  TF_ASSIGN_OR_RETURN(const Layout input_layout,
                      ExtractRequiredLayoutFromOperand(relayout.input()));
  bool match_present = false;
  for (const std::string& sharding_spec : target_layout.sharding_spec_strs())
    if (sharding_spec == Layout::kMatch) match_present = true;

  if (!match_present && output_layout != target_layout)
    return errors::Internal(
        "output layout of Relayout op after layout propagation does not match "
        "layout specified by Relayout op.");

  if (input_layout == output_layout) {
    // The input of RelayoutOp must be the output value of a DTensorLayout
    // operation, as layout propagation adds a DTensorLayout op for each tensor
    // value. Replace with an identity op.
    mlir::OpBuilder builder(relayout);
    mlir::TF::IdentityOp op = builder.create<mlir::TF::IdentityOp>(
        relayout.getLoc(), relayout.input().getType(), relayout.input());
    relayout.output().replaceAllUsesWith(op.output());
    relayout.erase();
    return op.getOperation();
  }

  auto value_or_status =
      EmitRelayout(relayout.input(), input_layout, output_layout);
  if (!value_or_status.ok())
    return errors::InvalidArgument(
        llvm::formatv("Unsupported layout received for tf.Relayout op. Trying "
                      "to set tensor to layout : {0}. Found error {1}",
                      layout_attr.str(),
                      value_or_status.status().error_message())
            .str());
  mlir::Value output = value_or_status.ValueOrDie();
  relayout.output().replaceAllUsesWith(output);
  relayout.erase();
  return output.getDefiningOp();
}

StatusOr<llvm::DenseMap<int, Layout>>
RelayoutSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  return ComputeRelayoutLayout(op, input_layouts);
}

StatusOr<llvm::DenseMap<int, Layout>>
RelayoutSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeRelayoutLayout(op, output_layouts);
}

namespace {

// Returns whether the send/recv layouts represent a send/recv of a tensor
// between the i-th TPU device and the i-th device of the host mesh. The host
// mesh consists of the CPU devices that are 1-to-1 mapped with the TPU mesh
// devices, having the same global and local device IDs.
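// Example (illustrative): a send layout on a TPU mesh paired with a recv
// layout on that TPU mesh's registered host mesh, with the same local device
// IDs and identical shard vectors, is a 1-to-1 host mesh transfer; a transfer
// to any other mesh, or one whose shard vectors differ, is not.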
bool IsOneToOneHostMeshTransfer(const Layout& send_layout,
                                const Layout& recv_layout) {
  const Mesh& send_mesh = send_layout.mesh();
  const Mesh& recv_mesh = recv_layout.mesh();

  // Check tensor is being transferred between CPU <-> TPU.
  if (!(send_mesh.is_tpu_mesh() && recv_mesh.is_cpu_mesh()) &&
      !(recv_mesh.is_tpu_mesh() && send_mesh.is_cpu_mesh()))
    return false;

  // Check tensor transfer is happening between TPU and its host mesh.
  if (!((send_mesh.is_tpu_mesh() &&
         send_mesh.tpu_host_mesh() == recv_mesh.ToString()) ||
        (recv_mesh.is_tpu_mesh() &&
         recv_mesh.tpu_host_mesh() == send_mesh.ToString())))
    return false;

  // Check local device IDs are fully matching so that there is no cross-host
  // transfer.
  if (send_mesh.local_device_ids() != recv_mesh.local_device_ids())
    return false;

  return send_layout.GetShardVector() == recv_layout.GetShardVector();
}

}  // namespace

StatusOr<mlir::Operation*> DTensorSendSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
  auto dtensor_send = llvm::cast<mlir::TF::DTensorSend>(op);

  TF_ASSIGN_OR_RETURN(mlir::Operation * recv_op,
                      GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorSend>(
                          module, dtensor_send));
  auto dtensor_recv = llvm::dyn_cast<mlir::TF::DTensorRecv>(recv_op);

  TF_RETURN_IF_ERROR(
      ValidateSendRecvLayoutConfiguration(dtensor_send, dtensor_recv));

  TF_ASSIGN_OR_RETURN(const Layout input_layout,
                      ExtractRequiredLayoutFromOperand(dtensor_send.input()));

  // If the tensor transfer is from a TPU mesh to its host mesh and the send
  // and recv layouts are identical, then the tensor from each source device is
  // sent to the corresponding target device asynchronously.
  if (IsOneToOneHostMeshTransfer(input_layout, dtensor_send.target_layout())) {
    return LowerDTensorSendToXlaOp(input_layout, dtensor_send.input(),
                                   dtensor_send,
                                   /*send_from_device_zero=*/false);
  }

  // Calculate the input tensor layout of the data to send and the target fully
  // replicated layout. For now, we ensure that all data transfers happen with
  // fully replicated tensors.
  const int rank = ValueRank(dtensor_send.input());
  const Layout target_layout =
      Layout::ReplicatedOnMesh(input_layout.mesh(), rank);

  // Convert the tensor to send to a replicated layout.
  mlir::OpBuilder builder(dtensor_send);
  TF_ASSIGN_OR_RETURN(mlir::Value send_input,
                      EmitAllGather(builder, dtensor_send.input(), input_layout,
                                    target_layout));

  // Insert control flow such that only the device with device ordinal == 0
  // sends the tensor data across the mesh.
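  // The generated control flow looks roughly like the following pseudo-IR
  // (illustrative only, argument/attribute details omitted):
  //   %pred = "tf.Equal"(%device_ordinal, 0)
  //   "tf.IfRegion"(%pred) ({   // then branch
  //     ... DTensorSend moved here ...
  //     "tf.Yield"()
  //   }, {                      // else branch
  //     "tf.Yield"()
  //   })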
  auto send_cluster =
      dtensor_send->getParentOfType<mlir::tf_device::ClusterOp>();
  TF_ASSIGN_OR_RETURN(absl::optional<Mesh> mesh,
                      ExtractDeviceMeshFromOp(send_cluster));
  if (!mesh.has_value())
    return errors::InvalidArgument(
        "failed to lower DTensor CopyToMesh op as sending side mesh is not "
        "specified.");

  mlir::Location loc = dtensor_send.getLoc();
  TF_ASSIGN_OR_RETURN(
      mlir::Value device_ordinal,
      GetDeviceOrdinal(*mesh, loc,
                       send_cluster->getParentOfType<mlir::func::FuncOp>(),
                       &builder));
  mlir::Value predicate = builder.create<mlir::TF::EqualOp>(
      loc, device_ordinal, CreateIntScalarConst(0, builder, loc),
      /*incompatible_shape_error=*/builder.getBoolAttr(true));

  auto send_if = builder.create<mlir::TF::IfRegionOp>(
      loc, llvm::SmallVector<mlir::Type, 4>{}, predicate,
      /*is_stateless=*/builder.getBoolAttr(true),
      GetUniqueControlflowFnName("copy_to_mesh_send_if_then", builder),
      GetUniqueControlflowFnName("copy_to_mesh_send_if_else", builder));

  // Create an empty else branch region.
  auto& else_branch = send_if.else_branch();
  else_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<mlir::TF::YieldOp>(loc,
                                    /*operands=*/llvm::ArrayRef<mlir::Value>{});

  // Create the then branch region with the DTensorSend op.
  auto& then_branch = send_if.then_branch();
  then_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  auto yield = builder.create<mlir::TF::YieldOp>(
      loc, /*operands=*/llvm::ArrayRef<mlir::Value>{});
  dtensor_send->moveBefore(yield);

  // Lower the DTensorSend op to an actual TF op.
  TF_ASSIGN_OR_RETURN(const Mesh recv_mesh,
                      ExtractDeviceMeshEnclosingCluster(recv_op));
  mlir::Operation* lowered_send;
  if (SendRecvOpUsesXla(input_layout.mesh(), recv_mesh)) {
    // Lower the DTensorSend op to XLA Send ops.
    TF_ASSIGN_OR_RETURN(
        lowered_send,
        LowerDTensorSendToXlaOp(input_layout, send_input, dtensor_send,
                                /*send_from_device_zero=*/true));
  } else if (input_layout.mesh().is_cpu_mesh() &&
             target_layout.mesh().is_cpu_mesh()) {
    // Lower the DTensorSend op to a TF Host Send op.
    TF_ASSIGN_OR_RETURN(
        lowered_send,
        LowerDTensorSendFromCPUToTFOp(input_layout, send_input, dtensor_send));
  } else {
    // TODO(hongjunchoi): Implement SPMD transformation lowering that lowers
    // DTensorSend to a vanilla TF Send op.
    return errors::Unimplemented(
        "CopyToMesh between CPU/GPU not implemented yet.");
  }

  return lowered_send;
}

// The DTensorSend op respects the input layout from its input operations and
// does not set any preferred input layouts. During SPMD expansion, however,
// tensor values are changed to a replicated layout before transferring data
// across the mesh cluster.
StatusOr<llvm::DenseMap<int, Layout>>
DTensorSendSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  return llvm::DenseMap<int, Layout>();
}

StatusOr<llvm::DenseMap<int, Layout>>
DTensorSendSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return llvm::DenseMap<int, Layout>();
}

StatusOr<mlir::Operation*> DTensorRecvSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
  auto dtensor_recv = llvm::cast<mlir::TF::DTensorRecv>(op);

  TF_ASSIGN_OR_RETURN(mlir::Operation * send_op,
                      GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorRecv>(
                          module, dtensor_recv));
  auto dtensor_send = llvm::dyn_cast<mlir::TF::DTensorSend>(send_op);

  TF_RETURN_IF_ERROR(
      ValidateSendRecvLayoutConfiguration(dtensor_send, dtensor_recv));

  TF_ASSIGN_OR_RETURN(const Layout send_layout,
                      ExtractRequiredLayoutFromOperand(send_op->getOperand(0)));

  TF_ASSIGN_OR_RETURN(const Mesh send_mesh,
                      ExtractDeviceMeshEnclosingCluster(send_op));

  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(op));

  mlir::Operation* lowered_recv;
  const Layout recv_layout = dtensor_recv.layout();
  const Mesh& recv_mesh = recv_layout.mesh();
  mlir::OpBuilder builder(dtensor_recv);

  if (SendRecvOpUsesXla(send_mesh, recv_mesh)) {
    if (recv_mesh.is_cpu_mesh() ||
        IsOneToOneHostMeshTransfer(send_layout, recv_layout)) {
      // Recv can be lowered directly for a 1-to-1 transfer between host and
      // device.
      TF_ASSIGN_OR_RETURN(mlir::TensorType local_output_type,
                          LocalTypeFromGlobalType(
                              dtensor_recv.layout(),
                              dtensor_recv.getType().cast<mlir::TensorType>()));
      TF_ASSIGN_OR_RETURN(lowered_recv, LowerDTensorRecvToXlaOp(
                                            dtensor_recv, local_output_type));
      dtensor_recv->replaceAllUsesWith(lowered_recv);
      dtensor_recv.erase();
    } else {
      // For other send/recv layouts, the tensor needs to be replicated.
      if (!dtensor_recv.layout().IsFullyReplicated()) {
        return errors::InvalidArgument(
            "CopyToMesh where target mesh is TPU requires target layout to be "
            "replicated.");
      }

      // When receiving at TPU, only the device with device ordinal 0 receives.
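      // The generated IR is roughly the following pseudo-IR (illustrative
      // only, details omitted):
      //   %pred = "tf.Equal"(%device_ordinal, 0)
      //   %recv = "tf.IfRegion"(%pred) ({   // then: the actual XLA recv
      //     "tf.Yield"(%xla_recv)
      //   }, {                              // else: zeros of the same type
      //     "tf.Yield"(%zeros)
      //   })
      //   ... an all-reduce over every mesh dimension then broadcasts %recv ...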
      auto recv_cluster =
          dtensor_recv->getParentOfType<mlir::tf_device::ClusterOp>();
      mlir::Location loc = dtensor_recv.getLoc();
      TF_ASSIGN_OR_RETURN(
          mlir::Value device_ordinal,
          GetDeviceOrdinal(recv_mesh, loc,
                           recv_cluster->getParentOfType<mlir::func::FuncOp>(),
                           &builder));
      mlir::Value predicate = builder.create<mlir::TF::EqualOp>(
          loc, device_ordinal, CreateIntScalarConst(0, builder, loc),
          /*incompatible_shape_error=*/builder.getBoolAttr(true));

      auto recv_if = builder.create<mlir::TF::IfRegionOp>(
          loc, llvm::SmallVector<mlir::Type, 4>{dtensor_recv.getType()},
          predicate,
          /*is_stateless=*/builder.getBoolAttr(true),
          GetUniqueControlflowFnName("copy_to_mesh_recv_if_then", builder),
          GetUniqueControlflowFnName("copy_to_mesh_recv_if_else", builder));

      // Create an empty else branch region that outputs zeros.
      auto& else_branch = recv_if.else_branch();
      else_branch.push_back(new mlir::Block);
      builder.setInsertionPointToEnd(&else_branch.front());

      // Create a zero constant.
      mlir::Attribute const_attr;
      if (dtensor_recv.getType().getElementType().isIntOrIndex()) {
        const_attr = mlir::DenseIntElementsAttr::get(
            dtensor_recv.getType(), llvm::SmallVector<int32_t>{0});
      } else {
        const_attr = mlir::DenseFPElementsAttr::get(
            dtensor_recv.getType(), llvm::SmallVector<float>{0.0});
      }

      mlir::Value zeros = builder.create<mlir::TF::ConstOp>(loc, const_attr);
      builder.create<mlir::TF::YieldOp>(
          loc, /*operands=*/llvm::ArrayRef<mlir::Value>{zeros});

      // Create the then branch region with the DTensorRecv op.
      auto& then_branch = recv_if.then_branch();
      then_branch.push_back(new mlir::Block);
      builder.setInsertionPointToEnd(&then_branch.front());
      dtensor_recv->moveBefore(&then_branch.front(), then_branch.front().end());

      TF_ASSIGN_OR_RETURN(mlir::Operation * xla_recv,
                          LowerDTensorRecvToXlaOp(dtensor_recv));
      builder.create<mlir::TF::YieldOp>(
          loc,
          /*operands=*/llvm::ArrayRef<mlir::Value>{xla_recv->getResult(0)});

      // Broadcast the received output to all TPU cores. Since every core other
      // than device ordinal 0 yields zeros, an all-reduce sum over all mesh
      // dimensions is equivalent to broadcasting the received value.
      mlir::Value if_output = recv_if->getResult(0);
      builder.setInsertionPointAfterValue(if_output);
      absl::flat_hash_set<std::string> reduced_dims;
      for (const auto& mesh_dim : recv_mesh.dims())
        reduced_dims.insert(mesh_dim.name);

      TF_ASSIGN_OR_RETURN(lowered_recv,
                          EmitAllReduce(builder, recv_layout, reduced_dims,
                                        recv_if, kReduceOpAdd));

      // Replace usages of the DTensorRecv op with the broadcasted value.
      dtensor_recv.output().replaceUsesWithIf(
          lowered_recv->getResult(0), [&](mlir::OpOperand& operand) {
            return !recv_if->isProperAncestor(operand.getOwner());
          });
      dtensor_recv.erase();
    }
  } else if (dtensor_recv.layout().mesh().is_cpu_mesh() &&
             send_mesh.is_cpu_mesh()) {
    // Lower the DTensorRecv op to a TF Host Recv op.
    TF_ASSIGN_OR_RETURN(lowered_recv,
                        LowerDTensorRecvFromCPUToTFOp(send_mesh, dtensor_recv));
  } else {
    // TODO(hongjunchoi): Implement SPMD transformation lowering that lowers
    // DTensorRecv to a vanilla TF Recv op.
    return errors::Unimplemented(
        "CopyToMesh between CPU/GPU not implemented yet.");
  }

  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
  builder.setInsertionPointAfter(lowered_recv);
  TF_ASSIGN_OR_RETURN(
      mlir::Value recv_output,
      EmitAllScatter(builder, lowered_recv->getResult(0), recv_layout,
                     output_layout, &newly_created_ops));
  lowered_recv->getResult(0).replaceAllUsesExcept(recv_output,
                                                  newly_created_ops);
  return recv_output.getDefiningOp();
}

// The DTensorRecv op always outputs a tensor with the layout specified by its
// `layout` attribute.
StatusOr<llvm::DenseMap<int, Layout>>
DTensorRecvSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  mlir::TF::DTensorRecv dtensor_recv =
      mlir::dyn_cast<mlir::TF::DTensorRecv>(op);
  if (!dtensor_recv) {
    return errors::InvalidArgument(
        llvm::formatv("Expecting DTensorRecvOp but got {0}", OpName(op)).str());
  }
  return llvm::DenseMap<int, Layout>({{0, dtensor_recv.layout()}});
}

StatusOr<llvm::DenseMap<int, Layout>>
DTensorRecvSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return llvm::DenseMap<int, Layout>();
}

}  // namespace dtensor
}  // namespace tensorflow