/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect's communication
// ops (TF/XLA) to the HLO dialect.

#include <memory>
#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/tf_xla_passes_detail.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/side_effect_util.h"

namespace mlir {

using func::FuncOp;

namespace mhlo {

namespace {
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
// TPU core that sends to and receives from host.
constexpr int64_t kShardingTpuCore = 0;

// A pass that legalizes TF/XLA communication ops, propagates their respective
// tokens (for ordering), and rewrites their respective functions and control
// flow ops when necessary.
// Note, this currently does not handle nested modules/functions or region
// based ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
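//
// As a rough sketch (illustrative only; channel, sharding, and frontend
// attributes are elided), a private function with callers such as
//
//   func.func private @send(%arg0: tensor<f32>) {
//     "tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor<f32>) -> ()
//     return
//   }
//
// is rewritten so that an `!mhlo.token` is threaded through it for ordering:
//
//   func.func private @send(%arg0: tensor<f32>, %arg1: !mhlo.token)
//       -> !mhlo.token {
//     %0 = "mhlo.send"(%arg0, %arg1) {...} : (tensor<f32>, !mhlo.token)
//         -> !mhlo.token
//     return %0 : !mhlo.token
//   }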
class LegalizeTFCommunication
    : public LegalizeTFCommunicationPassBase<LegalizeTFCommunication> {
  void runOnOperation() override;
};

// Checks if an op is a TF/XLA communication op.
bool IsCommunicationOp(Operation* op) {
  return isa<TF::_XlaHostComputeMlirOp, TF::XlaSendToHostOp,
             TF::XlaRecvFromHostOp>(op);
}

// Checks if an op is a supported HLO control flow op.
bool IsControlFlowOp(Operation* op) { return isa<IfOp, WhileOp>(op); }

// Collects control flow op ancestors of a given op, up until FuncOp. If any
// ancestor is not a control flow op or a FuncOp, or is not in a single block
// region, an error will be returned.
LogicalResult GetControlFlowAncestors(
    Operation* op, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
    llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
  Block* block = op->getBlock();
  Operation* parent = block->getParentOp();
  while (block && parent && !isa<func::FuncOp>(parent)) {
    if (!IsControlFlowOp(parent))
      return op->emitOpError()
             << "expects ancestor(s) to be of ['" << IfOp::getOperationName()
             << "', '" << WhileOp::getOperationName() << "', '"
             << func::FuncOp::getOperationName() << "']";

    if (!llvm::hasSingleElement(block->getParent()->getBlocks()))
      return op->emitOpError() << "expects single block region ancestor(s)";

    control_flow_ops.insert(parent);
    control_flow_blocks.insert(block);

    parent = block->getParentOp();
    block = parent->getBlock();
  }
  return success();
}

// Finds communication ops in a function. `control_flow_ops` and
// `control_flow_blocks` will be populated with control flow op ancestors for
// every communication op.
LogicalResult FindCommunicationOps(
    func::FuncOp func, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
    llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
    bool& has_communication_ops) {
  auto result = func.walk([&](Operation* op) {
    if (!IsCommunicationOp(op)) return WalkResult::advance();
    has_communication_ops = true;
    if (failed(
            GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return failure(result.wasInterrupted());
}

// Helper struct holding a function to be rewritten, its control flow ops that
// lead to a communication op or function call with a communication op
// (transitively), and an optional clone of itself. If `clone` is set, function
// calls to `original` will be replaced with `clone`.
struct FuncToRewrite {
  func::FuncOp original;
  llvm::SmallPtrSet<Operation*, 4> control_flow_ops;
  llvm::SmallPtrSet<Block*, 4> control_flow_blocks;
  func::FuncOp clone;
};

// Finds all functions that need to be rewritten with communication ops and
// associated tokens.
LogicalResult GetFunctionsToRewrite(
    ModuleOp module,
    llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
  // Find functions containing communication ops.
  SmallVector<func::FuncOp, 4> funcs_to_visit;
  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
    FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{},
                                  /*control_flow_blocks=*/{},
                                  /*clone=*/nullptr};
    bool has_communication_ops = false;
    if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops,
                                    func_to_rewrite.control_flow_blocks,
                                    has_communication_ops)))
      return failure();

    if (!has_communication_ops) continue;
    funcs_to_rewrite.insert({func.getName(), func_to_rewrite});
    funcs_to_visit.push_back(func);
  }

  // Find functions that call functions with communication ops, transitively.
  while (!funcs_to_visit.empty()) {
    SmallVector<func::FuncOp, 4> new_funcs_to_visit;
    for (func::FuncOp& func : funcs_to_visit) {
      auto uses = func.getSymbolUses(module);
      if (!uses) continue;
      for (auto& use : *uses) {
        // Only `mlir::func::CallOp` is supported as this requires knowing how
        // to rewrite arguments and results to a function.
        if (!isa<mlir::func::CallOp>(use.getUser())) continue;
        auto caller_parent_func =
            use.getUser()->getParentOfType<func::FuncOp>();
        if (!caller_parent_func) continue;

        FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func,
                                      /*control_flow_ops=*/{},
                                      /*control_flow_blocks=*/{},
                                      /*clone=*/nullptr};
        if (failed(GetControlFlowAncestors(
                use.getUser(), func_to_rewrite.control_flow_ops,
                func_to_rewrite.control_flow_blocks)))
          return failure();

        auto it = funcs_to_rewrite.insert(
            {caller_parent_func.getName(), func_to_rewrite});
        if (it.second) {
          new_funcs_to_visit.push_back(caller_parent_func);
        } else {
          it.first->getSecond().control_flow_ops.insert(
              func_to_rewrite.control_flow_ops.begin(),
              func_to_rewrite.control_flow_ops.end());
          it.first->getSecond().control_flow_blocks.insert(
              func_to_rewrite.control_flow_blocks.begin(),
              func_to_rewrite.control_flow_blocks.end());
        }
      }
    }

    funcs_to_visit.swap(new_funcs_to_visit);
  }

  // Clone public functions that need to be rewritten. Function calls to these
  // functions will be replaced with calls to the cloned function.
  SymbolTable symbol_table(module);
  for (auto& func : funcs_to_rewrite) {
    if (func.getSecond().original.isPublic() &&
        !func.getSecond().original.symbolKnownUseEmpty(module)) {
      auto clone = func.getSecond().original.clone();
      clone.setPrivate();
      symbol_table.insert(clone);
      func.getSecond().clone = clone;
    }
  }

  return success();
}

// Assigns op sharding to full tensor on `kShardingTpuCore`.
void SetOpSharding(Operation* op) {
  std::string sharding_serialized =
      ::xla::sharding_builder::AssignDevice(kShardingTpuCore)
          .SerializeAsString();
  op->setAttr(kShardingAttr,
              StringAttr::get(op->getContext(), sharding_serialized));
}

// Assigns frontend attributes holding information about data type and
// TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name
// is handled differently as individual names are used per data send and receive.
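//
// A sketch of the resulting dictionary attribute (keys shown symbolically;
// the actual strings come from the `xla::kXlaHostTransfer*` constants), e.g.
// for a device-to-host f32 transfer with key "send_key" and index 0:
//
//   mhlo.frontend_attributes = {
//     <kXlaHostTransferRendezvousNameAttr> = "send_key_dtoh_0",
//     <kXlaHostTransferOriginalTypeAttr>   = "f32",
//     <kXlaHostTransferHandlerNameAttr>    = <host_handler_name>
//   }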
void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
                           Type type, bool device_to_host,
                           StringRef host_handler_name) {
  MLIRContext* context = op->getContext();

  std::string formatted_key =
      device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
                     : llvm::formatv("{0}_htod_{1}", key, index).str();

  auto rendezvous_name = StringAttr::get(context, formatted_key);
  auto rendezvous_name_attr = NamedAttribute(
      StringAttr::get(context, xla::kXlaHostTransferRendezvousNameAttr),
      rendezvous_name);

  auto element_type = getElementTypeOrSelf(type);
  auto xla_element_type = ::xla::TypeToPrimitiveType(element_type);
  const std::string& xla_element_type_str =
      ::xla::primitive_util::LowercasePrimitiveTypeName(xla_element_type);
  auto original_type = StringAttr::get(context, xla_element_type_str);
  auto original_type_attr = NamedAttribute(
      StringAttr::get(context, xla::kXlaHostTransferOriginalTypeAttr),
      original_type);

  auto host_handler_name_value =
      StringAttr::get(context, host_handler_name.str());
  auto host_handler_name_attr = NamedAttribute(
      StringAttr::get(context, xla::kXlaHostTransferHandlerNameAttr),
      host_handler_name_value);

  auto frontend_attributes = DictionaryAttr::get(
      context,
      ArrayRef<NamedAttribute>{rendezvous_name_attr, original_type_attr,
                               host_handler_name_attr});
  op->setAttr(kFrontendAttributesAttr, frontend_attributes);
}

// Creates a `mhlo.send` op for sending value `operand`.
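//
// Sketch of the generated IR (illustrative; the sharding and frontend
// attribute payloads are abbreviated and spellings are approximate):
//
//   %token_out = "mhlo.send"(%operand, %token) {
//       channel_handle = #mhlo.channel_handle<handle = 1, type = 2>,
//       is_host_transfer = true,
//       mhlo.frontend_attributes = {...}, mhlo.sharding = "..."}
//     : (tensor<f32>, !mhlo.token) -> !mhlo.token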
Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
                   Value operand, StringRef key, size_t index, Value token,
                   StringRef host_handler_name) {
  // type 2 == DEVICE_TO_HOST
  auto channel_handle = ChannelHandleAttr::get(builder.getContext(),
                                               /*handle=*/channel_id++,
                                               /*type=*/2);
  auto send = builder.create<SendOp>(
      loc, token.getType(), operand, token, channel_handle,
      /*is_host_transfer=*/builder.getBoolAttr(true));

  SetFrontendAttributes(send, index, key, operand.getType(),
                        /*device_to_host=*/true, host_handler_name);

  SetOpSharding(send);

  return send.getResult();
}

// Creates a `mhlo.recv` op for receiving a value.
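//
// Sketch of the generated IR (illustrative; attributes abbreviated). The
// first result replaces all uses of `result`, and the trailing token is
// returned for chaining:
//
//   %data, %token_out = "mhlo.recv"(%token) {
//       channel_handle = #mhlo.channel_handle<handle = 2, type = 3>,
//       is_host_transfer = true, ...}
//     : (!mhlo.token) -> (tensor<f32>, !mhlo.token)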
Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
                   Value result, StringRef key, size_t index, Value token,
                   StringRef host_handler_name) {
  // type 3 == HOST_TO_DEVICE
  auto channel_handle = ChannelHandleAttr::get(builder.getContext(),
                                               /*handle=*/channel_id++,
                                               /*type=*/3);
  auto result_type = result.getType();
  SmallVector<Type, 2> recv_result_type = {result_type, token.getType()};
  auto recv =
      builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
                             /*is_host_transfer=*/builder.getBoolAttr(true));

  SetFrontendAttributes(recv, index, key, result_type,
                        /*device_to_host=*/false, host_handler_name);

  SetOpSharding(recv);

  result.replaceAllUsesWith(recv.getResult(0));

  return recv.getResult(1);
}

// Creates a new token if necessary, acting as a sink to previous tokens. If
// there is only one token in `tokens`, the only token is returned. If `tokens`
// is empty, `original_token` is returned instead.
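//
// For example (illustrative), two send tokens %t0 and %t1 are joined as:
//
//   %sink = "mhlo.after_all"(%t0, %t1)
//     : (!mhlo.token, !mhlo.token) -> !mhlo.token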
Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef<Value> tokens,
                      Value original_token) {
  if (tokens.empty()) {
    return original_token;
  } else if (llvm::hasSingleElement(tokens)) {
    return tokens[0];
  } else {
    return builder.create<AfterAllOp>(loc, original_token.getType(), tokens)
        .getResult();
  }
}

// Replaces `tf._XlaHostComputeMlir` with individual `mhlo.send` and
// `mhlo.recv` ops per operand and result. Unique channel IDs are assigned per
// transfer. Sink tokens are created across all `mhlo.send` ops first and then
// across all `mhlo.recv` ops.
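//
// A rough before/after sketch (illustrative; send/recv attributes elided):
//
//   %0 = "tf._XlaHostComputeMlir"(%arg0)
//       {send_key = "send_key", recv_key = "recv_key", ...}
//     : (tensor<f32>) -> tensor<f32>
//
// becomes, given an incoming token %token:
//
//   %send_token = "mhlo.send"(%arg0, %token) {...}
//     : (tensor<f32>, !mhlo.token) -> !mhlo.token
//   %recv:2 = "mhlo.recv"(%send_token) {...}
//     : (!mhlo.token) -> (tensor<f32>, !mhlo.token)
//
// with %recv#0 replacing uses of %0 and %recv#1 as the token for later ops.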
Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id,
                           TF::_XlaHostComputeMlirOp host_compute,
                           Value token) {
  builder.setInsertionPoint(host_compute);
  Location loc = host_compute.getLoc();

  SmallVector<Value, 4> send_tokens;
  for (auto operand : llvm::enumerate(host_compute.inputs())) {
    auto send_token = CreateSendOp(
        builder, channel_id, loc, operand.value(), host_compute.send_key(),
        operand.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName);
    send_tokens.push_back(send_token);
  }
  token = CreateSinkToken(builder, loc, send_tokens, token);

  SmallVector<Value, 4> recv_tokens;
  for (auto result : llvm::enumerate(host_compute.outputs())) {
    auto recv_token = CreateRecvOp(
        builder, channel_id, loc, result.value(), host_compute.recv_key(),
        result.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName);
    recv_tokens.push_back(recv_token);
  }
  token = CreateSinkToken(builder, loc, recv_tokens, token);

  host_compute.erase();
  return token;
}

// Replaces `tf.XlaSendToHost` with a `mhlo.send`.
Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id,
                          TF::XlaSendToHostOp send_to_host, Value token) {
  builder.setInsertionPoint(send_to_host);
  token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
                       send_to_host.input(), send_to_host.key(),
                       /*index=*/0, token,
                       xla::kXlaHostTransferTfRendezvousHandlerName);

  send_to_host.erase();
  return token;
}

// Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`.
Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id,
                            TF::XlaRecvFromHostOp recv_from_host, Value token) {
  builder.setInsertionPoint(recv_from_host);
  token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
                       recv_from_host.output(), recv_from_host.key(),
                       /*index=*/0, token,
                       xla::kXlaHostTransferTfRendezvousHandlerName);

  recv_from_host.erase();
  return token;
}

// Replaces a `mlir::func::CallOp` with one that has an extra `!mhlo.token`
// operand and `!mhlo.token` result. If `new_symbol` is set, the new call will
// be updated to call the `new_symbol` instead.
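//
// For example (illustrative), with an incoming token %token:
//
//   %0 = call @callee(%arg0) : (tensor<f32>) -> tensor<f32>
//
// becomes
//
//   %0:2 = call @callee(%arg0, %token)
//     : (tensor<f32>, !mhlo.token) -> (tensor<f32>, !mhlo.token)
//
// where %0#1 is the token result forwarded to subsequent ops.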
Value RewriteCallOp(OpBuilder& builder, func::CallOp call,
                    const Optional<StringRef>& new_symbol, Value token) {
  builder.setInsertionPoint(call);
  auto new_operands = llvm::to_vector(call.getArgOperands());
  new_operands.push_back(token);
  auto new_result_types = llvm::to_vector(call.getResultTypes());
  new_result_types.push_back(token.getType());
  auto new_call = builder.create<func::CallOp>(
      call.getLoc(), new_result_types,
      new_symbol ? *new_symbol : call.getCallee(), new_operands);

  for (auto results : llvm::zip(call.getResults(), new_call.getResults()))
    std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
  call.erase();
  return new_call.getResults().back();
}

// Helper struct holding the state of which op to visit next. If `op` is in a
// control flow op region, `region_idx` will be set to the respective region
// index. `token` is the current token from the last visited communication op
// or control flow op with transitive communication ops.
struct OpVisitorState {
  Optional<unsigned> region_idx;
  Value token;
  Operation* op;
};

// Creates a tuple from a sequence of values.
Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef<Value> operands) {
  return builder.create<TupleOp>(loc, operands).getResult();
}

// Extends `values` with the value `token` attached. If `flatten_tuple` is
// false, `values` will have a single element, say `value`. If `value` is not a
// tuple, a new tuple is formed with `token`. If `value` is a tuple, it is
// extended instead. New tuple values created are cached.
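//
// For example (illustrative), with `flatten_tuple` false, a non-tuple value
// %val and token %token become:
//
//   %0 = "mhlo.tuple"(%val, %token)
//     : (tensor<f32>, !mhlo.token) -> tuple<tensor<f32>, !mhlo.token>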
SmallVector<Value> GetValueWithToken(
    OpBuilder& builder, ArrayRef<Value> values, Value token,
    llvm::SmallDenseMap<Value, Value>& rewritten_values, bool flatten_tuple) {
  if (flatten_tuple) {
    auto operands = llvm::to_vector(values);
    operands.push_back(token);
    return operands;
  }

  auto value = values[0];
  // If value with token already exists, reuse it.
  auto it = rewritten_values.find(value);
  if (it != rewritten_values.end()) return {it->getSecond()};

  auto create_tuple = [&](ArrayRef<Value> operands) {
    auto new_result = CreateTuple(builder, value.getLoc(), operands);
    rewritten_values.insert({value, new_result});
    return new_result;
  };

  auto tuple_type = value.getType().dyn_cast<TupleType>();
  // `value` is not a tuple, create a new tuple.
  if (!tuple_type) return {create_tuple({value, token})};

  // Extend tuple if `value` is a tuple.
  // If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack
  // the tuple.
  if (auto tuple_op = value.getDefiningOp<TupleOp>()) {
    auto tuple_operands = llvm::to_vector(tuple_op.getOperands());
    tuple_operands.push_back(token);
    return {create_tuple(tuple_operands)};
  }

  // `value` is not created via a `mhlo.tuple` directly, unpack individual
  // elements directly with `mhlo.get_tuple_element`.
  SmallVector<Value, 4> tuple_operands;
  for (auto idx : llvm::seq<int32_t>(0, tuple_type.getTypes().size()))
    tuple_operands.push_back(
        builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
            .getResult());

  tuple_operands.push_back(token);
  return {create_tuple(tuple_operands)};
}

// Extends `types` to include a `mhlo.token` type. If `flatten_tuple` is
// false, `types` will have a single element, say `type`. If `type` is a tuple
// type, it is extended with the `mhlo.token` type; otherwise a new tuple type
// of `type` and `mhlo.token` type is created instead.
SmallVector<Type> GetTypeWithToken(OpBuilder& builder, ArrayRef<Type> types,
                                   bool flatten_tuple) {
  SmallVector<Type> new_result_types;
  auto token_type = TokenType::get(builder.getContext());

  if (flatten_tuple) {
    auto result_types = llvm::to_vector(types);
    result_types.push_back(token_type);
    return result_types;
  }

  auto type = types[0];
  if (auto tuple_type = type.dyn_cast<TupleType>()) {
    auto result_types = llvm::to_vector(tuple_type.getTypes());
    result_types.push_back(token_type);
    return {builder.getTupleType(result_types)};
  }

  return {builder.getTupleType({type, token_type})};
}

// Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0
// to `end`, exclusive.
Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) {
  SmallVector<Value, 4> tuple_operands;
  for (auto idx : llvm::seq<int32_t>(0, end))
    tuple_operands.push_back(
        builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
            .getResult());

  return CreateTuple(builder, value.getLoc(), tuple_operands);
}

// Replaces uses of `values` with `replacements`. If `flatten_tuple` is false,
// `values` will have a single element, say `value`. If `value` is not a tuple
// type, an explicit `mhlo.get_tuple_element` is created to unpack the tuple and
// return the first element. Otherwise, `mhlo.get_tuple_element` users are
// simply updated with `replacement`, and all other users are updated with a
// slice of `replacement`.
void ReplaceWithTupleResult(OpBuilder& builder, ArrayRef<Value> values,
                            ArrayRef<Value> replacements, bool flatten_tuple) {
  if (flatten_tuple) {
    for (size_t result_index = 0; result_index < values.size(); result_index++)
      values[result_index].replaceAllUsesWith(replacements[result_index]);
    return;
  }

  auto value = values[0];
  auto replacement = replacements[0];
  auto tuple_type = value.getType().dyn_cast<TupleType>();
  if (!tuple_type) {
    if (!value.use_empty()) {
      auto new_element = builder.create<GetTupleElementOp>(replacement.getLoc(),
                                                           replacement, 0);
      value.replaceAllUsesWith(new_element.getResult());
    }
    return;
  }

  Value sub_tuple;
  for (auto& use : llvm::make_early_inc_range(value.getUses())) {
    if (isa<GetTupleElementOp>(use.getOwner())) {
      use.set(replacement);
      continue;
    }

    if (!sub_tuple)
      sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size());

    use.set(sub_tuple);
  }
}

// Replaces control flow op block arguments with new block arguments
// of types `types`. The last element of the new block arguments (token) is
// returned.
Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block,
                                         ArrayRef<Type> types) {
  builder.setInsertionPointToStart(&block);

  auto old_args_size = block.getNumArguments();

  block.addArguments(
      types, SmallVector<Location>(types.size(), block.getParent()->getLoc()));

  auto old_args = ArrayRef<Value>(block.getArguments().begin(),
                                  block.getArguments().begin() + old_args_size);
  auto new_args = ArrayRef<Value>(block.getArguments().begin() + old_args_size,
                                  block.getArguments().end());
  assert(!new_args.empty());

  ReplaceWithTupleResult(builder, old_args, new_args, /*flatten_tuple=*/true);
  auto new_arg = new_args[new_args.size() - 1];

  block.eraseArguments(
      llvm::to_vector(llvm::seq((unsigned)0, (unsigned)old_args_size)));

  return new_arg;
}

// Updates control flow op terminator with an extra element `token`.
void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator,
                                  Value token, bool flatten_tuple) {
  assert(flatten_tuple || terminator->getNumOperands() == 1);
  assert(flatten_tuple || terminator->getBlock()->getNumArguments() == 1);
  // `mhlo.while` cond terminator does not need to be rewritten as it always
  // returns a tensor<i1> predicate value.
  if (auto while_parent = dyn_cast_or_null<WhileOp>(terminator->getParentOp()))
    if (terminator->getParentRegion() == &while_parent.cond()) return;

  builder.setInsertionPoint(terminator);
  llvm::SmallDenseMap<Value, Value> rewritten_operands;
  auto new_results =
      GetValueWithToken(builder, llvm::to_vector(terminator->getOperands()),
                        token, rewritten_operands, flatten_tuple);
  terminator->setOperands(new_results);
}

// Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. Since the If
// op does not have any operands other than the predicate, the parent token is
// captured implicitly instead. The same implicitly captured token is also used
// inside the If op's regions.
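//
// A rough sketch (illustrative; region bodies elided):
//
//   %0 = "mhlo.if"(%pred) ({...}, {...}) : (tensor<i1>) -> tensor<f32>
//
// is replaced with an op that also returns the threaded token:
//
//   %0:2 = "mhlo.if"(%pred) ({...}, {...})
//     : (tensor<i1>) -> (tensor<f32>, !mhlo.token)
//
// and the regions are rewritten to forward a token through their terminators.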
void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if,
                       SmallVectorImpl<OpVisitorState>& ops_to_visit,
                       Value token) {
  llvm::SmallDenseMap<Value, Value> rewritten_operands;

  auto new_result_types =
      GetTypeWithToken(builder, llvm::to_vector(region_if.getResultTypes()),
                       /*flatten_tuple=*/true);

  // Create new `mhlo.if` op with an extra token result.
  auto new_if = builder.create<IfOp>(region_if.getLoc(), new_result_types,
                                     region_if.pred());

  // Move all regions from the old `mhlo.if` op to its replacement.
  new_if.true_branch().takeBody(region_if.true_branch());
  new_if.false_branch().takeBody(region_if.false_branch());

  // Forward result from old `mhlo.if` with replacement.
  SmallVector<Value> old_if_results = region_if.getResults();
  SmallVector<Value> new_if_results = new_if.getResults();

  ReplaceWithTupleResult(builder, old_if_results, new_if_results,
                         /*flatten_tuple=*/true);

  region_if.erase();

  // Next op to visit. The replacement is visited but at its first region.
  // The new regions use the same implicit token used by the If op.
  ops_to_visit.push_back({/*region_idx=*/0, token, new_if});
}

// Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a
// `mhlo.token`. The block argument is updated to have an extra `mhlo.token`
// element. If the region block is to be rewritten, the next op to visit is set
// to the first op in the block. Otherwise the terminator is updated to forward
// `token`.
void RewriteControlFlowOpRegion(
    OpBuilder& builder, Operation* region_op, unsigned region_idx,
    ArrayRef<Type> block_arg_types,
    SmallVectorImpl<OpVisitorState>& ops_to_visit,
    const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
  ops_to_visit.push_back({region_idx + 1, token, region_op});

  Region& region = region_op->getRegion(region_idx);
  assert(llvm::hasSingleElement(region));

  auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(),
                                                        block_arg_types);

  if (control_flow_blocks.contains(&region.front())) {
    ops_to_visit.push_back(
        {/*region_idx=*/llvm::None, block_token, &region.front().front()});
    return;
  }

  RewriteControlFlowTerminator(builder, region.front().getTerminator(),
                               block_token, /*flatten_tuple=*/true);
}

// For mlir::IfOp or mlir::CaseOp, replaces the uses of their regions' block
// argument (of type token) with `implicit_operand`.
void ReplaceBlockArgumentsWithImplicitOperands(mlir::Operation* op,
                                               unsigned region_idx,
                                               Value implicit_operand) {
  assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
          mlir::dyn_cast<mlir::mhlo::CaseOp>(*op)) &&
         "Unexpected mlir op in "
         "ReplaceBlockArgumentsWithImplicitOperands!");

  auto& region = op->getRegion(region_idx);
  region.getArgument(0).replaceAllUsesWith(implicit_operand);
  region.front().eraseArguments(
      llvm::to_vector(llvm::seq<unsigned>(0, region.getNumArguments())));
}

// Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op
// operands and results are rewritten. If `region_idx` is set, region
// `region_idx` is rewritten to take in and return an additional token. Returns
// true if the op or its region was rewritten.
bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if,
                       Optional<unsigned> region_idx,
                       SmallVectorImpl<OpVisitorState>& ops_to_visit,
                       const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
                       Value token) {
  builder.setInsertionPoint(region_if);

  if (!region_idx) {
    RewriteRegionIfOp(builder, region_if, ops_to_visit, token);
    return true;
  }

  if (*region_idx < region_if.getNumRegions()) {
    // For the region-blocks of If op, we create a dummy token argument. Later
    // we replace that block-argument's uses with the same (implicitly captured)
    // token 'token', used for If op, and erase the argument.
    // Note that 'RewriteControlFlowOpRegion' sets the token, used for the first
    // operation of region_idx'th region, to the dummy block-argument. As we
    // erase that argument, we also need to make sure that the token used for
    // the next operation is set to 'token'.
    RewriteControlFlowOpRegion(builder, region_if, *region_idx,
                               {token.getType()}, ops_to_visit,
                               control_flow_blocks, token);

    ReplaceBlockArgumentsWithImplicitOperands(region_if.getOperation(),
                                              *region_idx, token);

    auto next_visitor_state = ops_to_visit.back();
    next_visitor_state.token = token;
    ops_to_visit.pop_back();
    ops_to_visit.push_back(next_visitor_state);
    return true;
  }

  return false;
}

// Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. The op's
// operands and the block arguments of all of its regions are extended with an
// extra `token`.
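//
// A rough sketch (illustrative; cond/body regions elided):
//
//   %0 = "mhlo.while"(%init) ({...}, {...}) : (tensor<f32>) -> tensor<f32>
//
// becomes
//
//   %0:2 = "mhlo.while"(%init, %token) ({...}, {...})
//     : (tensor<f32>, !mhlo.token) -> (tensor<f32>, !mhlo.token)
//
// where the body region threads the token and the cond region ignores it.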
void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while,
                          SmallVectorImpl<OpVisitorState>& ops_to_visit,
                          Value token) {
  llvm::SmallDenseMap<Value, Value> rewritten_operands;

  // Rewrite region operand to have an extra operand `token`.
  auto new_val_operands =
      GetValueWithToken(builder, llvm::to_vector(region_while.getOperands()),
                        token, rewritten_operands,
                        /*flatten_tuple=*/true);

  auto new_result_types =
      GetTypeWithToken(builder, llvm::to_vector(region_while.getResultTypes()),
                       /*flatten_tuple=*/true);

  // Create new `mhlo.while` op with extra token operand and result.
  auto new_while = builder.create<WhileOp>(region_while.getLoc(),
                                           new_result_types, new_val_operands);

  // Move all regions from the old `mhlo.while` op to its replacement.
  new_while.cond().takeBody(region_while.cond());
  new_while.body().takeBody(region_while.body());

  // Forward result from old `mhlo.while` with replacement.
  SmallVector<Value> old_while_results = region_while.getResults();
  SmallVector<Value> new_while_results = new_while.getResults();

  ReplaceWithTupleResult(builder, old_while_results, new_while_results,
                         /*flatten_tuple=*/true);

  auto new_token = new_while_results[new_while_results.size() - 1];

  region_while.erase();

  // Next op to visit. The replacement is visited but at its first region. The
  // token result of the new while op is propagated.
  ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while});
}

// Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op
// operands and results are rewritten. If `region_idx` is set, region
// `region_idx` is rewritten to take in and return an additional token. Returns
// true if the op or its region was rewritten.
bool ProcessRegionWhileOp(
    OpBuilder& builder, WhileOp region_while, Optional<unsigned> region_idx,
    SmallVectorImpl<OpVisitorState>& ops_to_visit,
    const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
  builder.setInsertionPoint(region_while);

  if (!region_idx) {
    RewriteRegionWhileOp(builder, region_while, ops_to_visit, token);
    return true;
  }

  if (*region_idx < region_while.getNumRegions()) {
    SmallVector<Type> operand_types;
    for (auto operand : region_while.operand())
      operand_types.push_back(operand.getType());
    RewriteControlFlowOpRegion(builder, region_while, *region_idx,
                               operand_types, ops_to_visit, control_flow_blocks,
                               token);
    return true;
  }

  return false;
}

// Updates function type based on current function body block arguments and
// terminator operand types.
void UpdateFunctionType(OpBuilder& builder, func::FuncOp func,
                        Block& func_body) {
  auto new_argument_types = llvm::to_vector(func_body.getArgumentTypes());
  auto new_result_types =
      llvm::to_vector(func_body.getTerminator()->getOperandTypes());
  func.setType(FunctionType::get(builder.getContext(), new_argument_types,
                                 new_result_types));
}

// Replaces a function terminator `return` with another `return` that has an
// extra `mhlo.token` operand.
void RewriteFunctionTerminator(OpBuilder& builder,
                               mlir::func::ReturnOp terminator, Value token) {
  auto new_results = llvm::to_vector(terminator.getOperands());
  new_results.push_back(token);
  builder.setInsertionPoint(terminator);
  builder.create<mlir::func::ReturnOp>(terminator.getLoc(), new_results);
  terminator.erase();
}

// Rewrites a function body and the communication ops inside it. Region control
// flow is updated when necessary to propagate tokens. The function may either
// be rewritten to create a token or to take in and return a token, depending
// on its visibility and whether it has any callers.
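//
// For example (illustrative), a private function with callers
//
//   func.func private @callee(%arg0: tensor<f32>) -> tensor<f32>
//
// is rewritten to the signature
//
//   func.func private @callee(%arg0: tensor<f32>, %arg1: !mhlo.token)
//       -> (tensor<f32>, !mhlo.token)
//
// while a public function keeps its signature and instead starts with a
// `mhlo.create_token` op.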
LogicalResult RewriteFunction(
    OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func,
    const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs,
    const llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
    const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, bool is_clone) {
  MLIRContext* context = module.getContext();
  if (!llvm::hasSingleElement(func.getBody()))
    return func.emitError()
           << "'" << FuncOp::getOperationName()
           << "' ops with more than one block are not supported";

  bool rewrite_block =
      is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module));
  Block& func_body = func.front();

  builder.setInsertionPointToStart(&func_body);
  auto token_type = TokenType::get(context);
  // If a function is public, its signature should not be modified, and instead
  // a token will be created. Otherwise a token block argument is inserted.
  Value init_token =
      rewrite_block ? func_body.addArgument(token_type, func.getLoc())
                    : builder.create<CreateTokenOp>(func.getLoc(), token_type)
                          .getResult();

  // Stack to keep track of region based control flow op nesting and current
  // op to visit.
  SmallVector<OpVisitorState, 4> ops_to_visit{
      {/*region_idx=*/llvm::None, init_token, &func_body.front()}};

  while (!ops_to_visit.empty()) {
    OpVisitorState op_to_visit = ops_to_visit.pop_back_val();
    Operation* curr_op = op_to_visit.op;

    Value token = op_to_visit.token;
    // Ops may be removed, so the next op is kept track of beforehand.
    Operation* next_op = curr_op->getNextNode();

    if (auto host_compute = dyn_cast<TF::_XlaHostComputeMlirOp>(curr_op)) {
      token = RewriteHostComputeOp(builder, channel_id, host_compute, token);
    } else if (auto send_to_host = dyn_cast<TF::XlaSendToHostOp>(curr_op)) {
      token = RewriteSendToHostOp(builder, channel_id, send_to_host, token);
    } else if (auto recv_from_host = dyn_cast<TF::XlaRecvFromHostOp>(curr_op)) {
      token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token);
    } else if (auto call = dyn_cast<mlir::func::CallOp>(curr_op)) {
      // Only `mlir::func::CallOp` is supported as this requires knowing how to
      // rewrite arguments and results to a function.
      auto it = funcs.find(call.getCallee());
      if (it != funcs.end()) {
        func::FuncOp clone = it->getSecond().clone;
        Optional<StringRef> symbol_name =
            clone ? Optional<StringRef>(clone.getName()) : llvm::None;
        // If the function being called is to be cloned, update the call to
        // also point to the cloned function.
        token = RewriteCallOp(builder, call, symbol_name, token);
      }
    } else if (auto region_if = dyn_cast<IfOp>(curr_op)) {
      if (op_to_visit.region_idx || control_flow_ops.contains(region_if)) {
        auto exist_unprocessed_region =
            ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx,
                              ops_to_visit, control_flow_blocks, token);

        // Once all the IfOp regions are processed (i.e.
        // 'exist_unprocessed_region' == false), select returned token-value
        // from IfOp as the token to be used for the following op.
        if (!exist_unprocessed_region) {
          token = curr_op->getResult(curr_op->getNumResults() - 1);
        } else {
          continue;
        }
      }
    } else if (auto region_while = dyn_cast<WhileOp>(curr_op)) {
      if (op_to_visit.region_idx || control_flow_ops.contains(region_while))
        if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx,
                                 ops_to_visit, control_flow_blocks, token))
          continue;
    } else if (auto region_terminator = dyn_cast<mhlo::ReturnOp>(curr_op)) {
      bool flatten_tuple = isa<mhlo::WhileOp, mhlo::IfOp, mhlo::CaseOp>(
          region_terminator->getParentOp());
      RewriteControlFlowTerminator(builder, region_terminator, token,
                                   flatten_tuple);
      // There is no next op after the control flow op terminator, simply let
      // stack have one less element.
      continue;
    } else if (auto func_terminator = dyn_cast<mlir::func::ReturnOp>(curr_op)) {
      if (rewrite_block)
        RewriteFunctionTerminator(builder, func_terminator, token);

      // There is no next op after the function terminator, simply let stack
      // have one less element/be empty.
      continue;
    }

    // Visit next op.
    ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op});
  }

  if (rewrite_block) UpdateFunctionType(builder, func, func_body);

  return success();
}

// Checks if a function call is pointing to a function with communication ops.
bool IsFunctionCallWithCommunication(
    Operation* op,
    const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
  if (auto call = dyn_cast<mlir::func::CallOp>(op))
    return funcs_to_rewrite.count(call.getCallee());

  return false;
}

// Collects all control flow op ancestors of communication ops or function
// calls with communication ops (transitively).
void GetCommunicationControlFlowOps(
    func::FuncOp func,
    const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite,
    llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
    llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
  func.walk([&](Operation* op) {
    if (IsCommunicationOp(op) ||
        IsFunctionCallWithCommunication(op, funcs_to_rewrite))
      if (failed(GetControlFlowAncestors(op, control_flow_ops,
                                         control_flow_blocks)))
        llvm_unreachable(
            "checking original function for control flow ancestors should have "
            "errored first");
  });
}

void LegalizeTFCommunication::runOnOperation() {
  auto module = getOperation();
  llvm::SmallDenseMap<StringRef, FuncToRewrite> funcs_to_rewrite;
  if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite)))
    return signalPassFailure();

  // Module level counter to make sure Channel IDs are unique.
  int64_t channel_id = 1;
  OpBuilder builder(&getContext());
  for (const auto& func_and_name : funcs_to_rewrite) {
    const auto& func_to_rewrite = func_and_name.getSecond();
    func::FuncOp func = func_to_rewrite.original;
    if (failed(RewriteFunction(builder, channel_id, module, func,
                               funcs_to_rewrite,
                               func_to_rewrite.control_flow_ops,
                               func_to_rewrite.control_flow_blocks,
                               /*is_clone=*/false)))
      return signalPassFailure();

    func::FuncOp clone = func_and_name.getSecond().clone;
    if (!clone) continue;
    llvm::SmallPtrSet<Operation*, 4> clone_control_flow_ops;
    llvm::SmallPtrSet<Block*, 4> clone_control_flow_blocks;
    GetCommunicationControlFlowOps(clone, funcs_to_rewrite,
                                   clone_control_flow_ops,
                                   clone_control_flow_blocks);
    if (failed(RewriteFunction(builder, channel_id, module, clone,
                               funcs_to_rewrite, clone_control_flow_ops,
                               clone_control_flow_blocks,
                               /*is_clone=*/true)))
      llvm_unreachable(
          "rewriting of original function should have errored first");
  }
}

}  // namespace

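// Factory for the pass above. A minimal usage sketch (assuming a module-level
// `mlir::PassManager pm` has already been set up elsewhere):
//
//   pm.addPass(mhlo::CreateLegalizeTFCommunicationPass());
//   if (failed(pm.run(module))) { /* handle pass failure */ }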
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCommunicationPass() {
  return std::make_unique<LegalizeTFCommunication>();
}

}  // namespace mhlo
}  // namespace mlir