xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering TensorFlow dialect's communication
17 // ops (TF/XLA) to the HLO dialect.
18 
19 #include <memory>
20 #include <string>
21 
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
29 #include "mlir/IR/Attributes.h"  // from @llvm-project
30 #include "mlir/IR/Builders.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
34 #include "mlir/IR/Value.h"  // from @llvm-project
35 #include "mlir/IR/Visitors.h"  // from @llvm-project
36 #include "mlir/Pass/Pass.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/xla/transforms/tf_xla_passes_detail.h"
41 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
42 #include "tensorflow/compiler/xla/client/sharding_builder.h"
43 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
44 #include "tensorflow/compiler/xla/primitive_util.h"
45 #include "tensorflow/compiler/xla/side_effect_util.h"
46 
47 namespace mlir {
48 
49 using func::FuncOp;
50 
51 namespace mhlo {
52 
53 namespace {
// Attribute name under which the serialized op sharding is stored.
constexpr char kShardingAttr[] = "mhlo.sharding";
// Attribute name under which the frontend attribute dictionary is stored.
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
// Index of the TPU core that exchanges data with the host; communication ops
// are assigned to this core.
constexpr int64_t kShardingTpuCore = 0;
58 
59 // A pass that legalizes TF/XLA communication ops, propagate their respective
60 // tokens (for ordering), and rewrite their respective functions and control
61 // flow ops when necessary.
62 // Note, this currently does not handle nested modules/functions or region based
63 // ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
64 class LegalizeTFCommunication
65     : public LegalizeTFCommunicationPassBase<LegalizeTFCommunication> {
66   void runOnOperation() override;
67 };
68 
69 // Checks if an op is a TF/XLA communication op.
IsCommunicationOp(Operation * op)70 bool IsCommunicationOp(Operation* op) {
71   return isa<TF::_XlaHostComputeMlirOp, TF::XlaSendToHostOp,
72              TF::XlaRecvFromHostOp>(op);
73 }
74 
75 // Checks if an op is a supported HLO control flow op.
IsControlFlowOp(Operation * op)76 bool IsControlFlowOp(Operation* op) { return isa<IfOp, WhileOp>(op); }
77 
78 // Collects control flow op ancestors of a given op, up until FuncOp. If any
79 // ancestor is not a control flow op or a FuncOp, or of a single block region,
80 // an error will be returned.
GetControlFlowAncestors(Operation * op,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)81 LogicalResult GetControlFlowAncestors(
82     Operation* op, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
83     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
84   Block* block = op->getBlock();
85   Operation* parent = block->getParentOp();
86   while (block && parent && !isa<func::FuncOp>(parent)) {
87     if (!IsControlFlowOp(parent))
88       return op->emitOpError()
89              << "expects ancestor(s) to be of ['" << IfOp::getOperationName()
90              << "', '" << func::FuncOp::getOperationName() << "']";
91 
92     if (!llvm::hasSingleElement(block->getParent()->getBlocks()))
93       return op->emitOpError() << "expects single block region ancestor(s)";
94 
95     control_flow_ops.insert(parent);
96     control_flow_blocks.insert(block);
97 
98     parent = block->getParentOp();
99     block = parent->getBlock();
100   }
101   return success();
102 }
103 
104 // Finds communication ops in a function. `control_flow_ops` and
105 // `control_flow_blocks` will be populated with control flow op ancestors for
106 // every communication op.
FindCommunicationOps(func::FuncOp func,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,bool & has_communication_ops)107 LogicalResult FindCommunicationOps(
108     func::FuncOp func, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
109     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
110     bool& has_communication_ops) {
111   auto result = func.walk([&](Operation* op) {
112     if (!IsCommunicationOp(op)) return WalkResult::advance();
113     has_communication_ops = true;
114     if (failed(
115             GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks)))
116       return WalkResult::interrupt();
117     return WalkResult::advance();
118   });
119   return failure(result.wasInterrupted());
120 }
121 
122 // Helper struct holding a function to be rewritten, it's control flow ops that
123 // lead to a communication op or function call with a communication op
124 // (transitively), and an optional clone of itself. If `clone` is set, function
125 // calls to `original` will be replaced with `clone`.
126 struct FuncToRewrite {
127   func::FuncOp original;
128   llvm::SmallPtrSet<Operation*, 4> control_flow_ops;
129   llvm::SmallPtrSet<Block*, 4> control_flow_blocks;
130   func::FuncOp clone;
131 };
132 
133 // Finds all functions that need to be rewritten with communication ops and
134 // and associated tokens.
GetFunctionsToRewrite(ModuleOp module,llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)135 LogicalResult GetFunctionsToRewrite(
136     ModuleOp module,
137     llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
138   // Find functions containing communication ops.
139   SmallVector<func::FuncOp, 4> funcs_to_visit;
140   for (func::FuncOp func : module.getOps<func::FuncOp>()) {
141     FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{},
142                                   /*control_flow_blocks=*/{},
143                                   /*clone=*/nullptr};
144     bool has_communication_ops = false;
145     if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops,
146                                     func_to_rewrite.control_flow_blocks,
147                                     has_communication_ops)))
148       return failure();
149 
150     if (!has_communication_ops) continue;
151     funcs_to_rewrite.insert({func.getName(), func_to_rewrite});
152     funcs_to_visit.push_back(func);
153   }
154 
155   // Find functions that call functions with communication ops, transitively.
156   while (!funcs_to_visit.empty()) {
157     SmallVector<func::FuncOp, 4> new_funcs_to_visit;
158     for (func::FuncOp& func : funcs_to_visit) {
159       auto uses = func.getSymbolUses(module);
160       if (!uses) continue;
161       for (auto& use : *uses) {
162         // Only `mlir::func::CallOp` is supported as this requires knowing how
163         // to rewrite arguments and results to a function.
164         if (!isa<mlir::func::CallOp>(use.getUser())) continue;
165         auto caller_parent_func =
166             use.getUser()->getParentOfType<func::FuncOp>();
167         if (!caller_parent_func) continue;
168 
169         FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func,
170                                       /*control_flow_ops=*/{},
171                                       /*control_flow_blocks=*/{},
172                                       /*clone=*/nullptr};
173         if (failed(GetControlFlowAncestors(
174                 use.getUser(), func_to_rewrite.control_flow_ops,
175                 func_to_rewrite.control_flow_blocks)))
176           return failure();
177 
178         auto it = funcs_to_rewrite.insert(
179             {caller_parent_func.getName(), func_to_rewrite});
180         if (it.second) {
181           new_funcs_to_visit.push_back(caller_parent_func);
182         } else {
183           it.first->getSecond().control_flow_ops.insert(
184               func_to_rewrite.control_flow_ops.begin(),
185               func_to_rewrite.control_flow_ops.end());
186           it.first->getSecond().control_flow_blocks.insert(
187               func_to_rewrite.control_flow_blocks.begin(),
188               func_to_rewrite.control_flow_blocks.end());
189         }
190       }
191     }
192 
193     funcs_to_visit.swap(new_funcs_to_visit);
194   }
195 
196   // Clone public functions that need to be rewritten. Function calls to this
197   // function will be replaced with the cloned function.
198   SymbolTable symbol_table(module);
199   for (auto& func : funcs_to_rewrite) {
200     if (func.getSecond().original.isPublic() &&
201         !func.getSecond().original.symbolKnownUseEmpty(module)) {
202       auto clone = func.getSecond().original.clone();
203       clone.setPrivate();
204       symbol_table.insert(clone);
205       func.getSecond().clone = clone;
206     }
207   }
208 
209   return success();
210 }
211 
212 // Assigns op sharding to full tensor on `kShardingTpuCore`.
SetOpSharding(Operation * op)213 void SetOpSharding(Operation* op) {
214   std::string sharding_serialized =
215       ::xla::sharding_builder::AssignDevice(kShardingTpuCore)
216           .SerializeAsString();
217   op->setAttr(kShardingAttr,
218               StringAttr::get(op->getContext(), sharding_serialized));
219 }
220 
221 // Assigns frontend attributes holding information about data type and
222 // TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is
223 // handled differently as individual names are used per data send and receive.
SetFrontendAttributes(Operation * op,int32_t index,StringRef key,Type type,bool device_to_host,StringRef host_handler_name)224 void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
225                            Type type, bool device_to_host,
226                            StringRef host_handler_name) {
227   MLIRContext* context = op->getContext();
228 
229   std::string formatted_key =
230       device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
231                      : llvm::formatv("{0}_htod_{1}", key, index).str();
232 
233   auto rendezvous_name = StringAttr::get(context, formatted_key);
234   auto rendezvous_name_attr = NamedAttribute(
235       StringAttr::get(context, xla::kXlaHostTransferRendezvousNameAttr),
236       rendezvous_name);
237 
238   auto element_type = getElementTypeOrSelf(type);
239   auto xla_element_type = ::xla::TypeToPrimitiveType(element_type);
240   const std::string& xla_element_type_str =
241       ::xla::primitive_util::LowercasePrimitiveTypeName(xla_element_type);
242   auto original_type = StringAttr::get(context, xla_element_type_str);
243   auto original_type_attr = NamedAttribute(
244       StringAttr::get(context, xla::kXlaHostTransferOriginalTypeAttr),
245       original_type);
246 
247   auto host_handler_name_value =
248       StringAttr::get(context, host_handler_name.str());
249   auto host_handler_name_attr = NamedAttribute(
250       StringAttr::get(context, xla::kXlaHostTransferHandlerNameAttr),
251       host_handler_name_value);
252 
253   auto frontend_attributes = DictionaryAttr::get(
254       context,
255       ArrayRef<NamedAttribute>{rendezvous_name_attr, original_type_attr,
256                                host_handler_name_attr});
257   op->setAttr(kFrontendAttributesAttr, frontend_attributes);
258 }
259 
260 // Creates a `mhlo.send` op for sending value `operand`.
CreateSendOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value operand,StringRef key,size_t index,Value token,StringRef host_handler_name)261 Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
262                    Value operand, StringRef key, size_t index, Value token,
263                    StringRef host_handler_name) {
264   // type 2 == DEVICE_TO_HOST
265   auto channel_handle = ChannelHandleAttr::get(builder.getContext(),
266                                                /*handle=*/channel_id++,
267                                                /*type=*/2);
268   auto send = builder.create<SendOp>(
269       loc, token.getType(), operand, token, channel_handle,
270       /*is_host_transfer=*/builder.getBoolAttr(true));
271 
272   SetFrontendAttributes(send, index, key, operand.getType(),
273                         /*device_to_host=*/true, host_handler_name);
274 
275   SetOpSharding(send);
276 
277   return send.getResult();
278 }
279 
280 // Creates a `mhlo.recv` op for receiving a value.
CreateRecvOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value result,StringRef key,size_t index,Value token,StringRef host_handler_name)281 Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
282                    Value result, StringRef key, size_t index, Value token,
283                    StringRef host_handler_name) {
284   // type 3 == HOST_TO_DEVICE
285   auto channel_handle = ChannelHandleAttr::get(builder.getContext(),
286                                                /*handle=*/channel_id++,
287                                                /*type=*/3);
288   auto result_type = result.getType();
289   SmallVector<Type, 2> recv_result_type = {result_type, token.getType()};
290   auto recv =
291       builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
292                              /*is_host_transfer=*/builder.getBoolAttr(true));
293 
294   SetFrontendAttributes(recv, index, key, result_type,
295                         /*device_to_host=*/false, host_handler_name);
296 
297   SetOpSharding(recv);
298 
299   result.replaceAllUsesWith(recv.getResult(0));
300 
301   return recv.getResult(1);
302 }
303 
304 // Creates a new token if necessary, acting as a sink to previous tokens. If
305 // there is only one token in `tokens`, the only token is returned. If `tokens`
306 // is empty, `original_token` is returned instead.
CreateSinkToken(OpBuilder & builder,Location loc,ArrayRef<Value> tokens,Value original_token)307 Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef<Value> tokens,
308                       Value original_token) {
309   if (tokens.empty()) {
310     return original_token;
311   } else if (llvm::hasSingleElement(tokens)) {
312     return tokens[0];
313   } else {
314     return builder.create<AfterAllOp>(loc, original_token.getType(), tokens)
315         .getResult();
316   }
317 }
318 
319 // Replaces `tf._XlaHostComputeMlir` with individual `mhlo.send` and `mhlo.recv`
320 // ops per operand and result. Unique Channel IDs are assigned per transfer.
321 // Sink tokens are created across all `mhlo.send` ops first and then by
322 // all `mhlo.recv` ops.
RewriteHostComputeOp(OpBuilder & builder,int64_t & channel_id,TF::_XlaHostComputeMlirOp host_compute,Value token)323 Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id,
324                            TF::_XlaHostComputeMlirOp host_compute,
325                            Value token) {
326   builder.setInsertionPoint(host_compute);
327   Location loc = host_compute.getLoc();
328 
329   SmallVector<Value, 4> send_tokens;
330   for (auto operand : llvm::enumerate(host_compute.inputs())) {
331     auto send_token = CreateSendOp(
332         builder, channel_id, loc, operand.value(), host_compute.send_key(),
333         operand.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName);
334     send_tokens.push_back(send_token);
335   }
336   token = CreateSinkToken(builder, loc, send_tokens, token);
337 
338   SmallVector<Value, 4> recv_tokens;
339   for (auto result : llvm::enumerate(host_compute.outputs())) {
340     auto recv_token = CreateRecvOp(
341         builder, channel_id, loc, result.value(), host_compute.recv_key(),
342         result.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName);
343     recv_tokens.push_back(recv_token);
344   }
345   token = CreateSinkToken(builder, loc, recv_tokens, token);
346 
347   host_compute.erase();
348   return token;
349 }
350 
351 // Replaces `tf.XlaSendToHost` with a `mhlo.send`.
RewriteSendToHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaSendToHostOp send_to_host,Value token)352 Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id,
353                           TF::XlaSendToHostOp send_to_host, Value token) {
354   builder.setInsertionPoint(send_to_host);
355   token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
356                        send_to_host.input(), send_to_host.key(),
357                        /*index=*/0, token,
358                        xla::kXlaHostTransferTfRendezvousHandlerName);
359 
360   send_to_host.erase();
361   return token;
362 }
363 
364 // Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`.
RewriteRecvFromHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaRecvFromHostOp recv_from_host,Value token)365 Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id,
366                             TF::XlaRecvFromHostOp recv_from_host, Value token) {
367   builder.setInsertionPoint(recv_from_host);
368   token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
369                        recv_from_host.output(), recv_from_host.key(),
370                        /*index=*/0, token,
371                        xla::kXlaHostTransferTfRendezvousHandlerName);
372 
373   recv_from_host.erase();
374   return token;
375 }
376 
377 // Replaces a `mlir::func::CallOp` with one that has an extra `!mhlo.token`
378 // operand and `!mhlo.token` result. If `new_symbol` is set, the new call will
379 // be updated to call the `new_symbol` instead.
RewriteCallOp(OpBuilder & builder,func::CallOp call,const Optional<StringRef> & new_symbol,Value token)380 Value RewriteCallOp(OpBuilder& builder, func::CallOp call,
381                     const Optional<StringRef>& new_symbol, Value token) {
382   builder.setInsertionPoint(call);
383   auto new_operands = llvm::to_vector(call.getArgOperands());
384   new_operands.push_back(token);
385   auto new_result_types = llvm::to_vector(call.getResultTypes());
386   new_result_types.push_back(token.getType());
387   auto new_call = builder.create<func::CallOp>(
388       call.getLoc(), new_result_types,
389       new_symbol ? *new_symbol : call.getCallee(), new_operands);
390 
391   for (auto results : llvm::zip(call.getResults(), new_call.getResults()))
392     std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
393   call.erase();
394   return new_call.getResults().back();
395 }
396 
397 // Helper struct holding state of which op to visit to next. If `op` is in a
398 // control flow op region, `region_idx` will be set with the respective region
399 // index. `token` will be current token from the last communication op/control
400 // flow op transitive communication ops.
401 struct OpVisitorState {
402   Optional<unsigned> region_idx;
403   Value token;
404   Operation* op;
405 };
406 
407 // Creates a tuple from a sequence of values.
CreateTuple(OpBuilder & builder,Location loc,ArrayRef<Value> operands)408 Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef<Value> operands) {
409   return builder.create<TupleOp>(loc, operands).getResult();
410 }
411 
412 // Extends `values` with the value `token` attached. If `flatten_tuple` is
413 // false, `values` will have a single element, say `value`. If `value` is not a
414 // tuple, a new tuple is formed with `token`. If `values` is a tuple, it is
415 // extended instead. New tuple values created are cached.
GetValueWithToken(OpBuilder & builder,ArrayRef<Value> values,Value token,llvm::SmallDenseMap<Value,Value> & rewritten_values,bool flatten_tuple)416 SmallVector<Value> GetValueWithToken(
417     OpBuilder& builder, ArrayRef<Value> values, Value token,
418     llvm::SmallDenseMap<Value, Value>& rewritten_values, bool flatten_tuple) {
419   if (flatten_tuple) {
420     auto operands = llvm::to_vector(values);
421     operands.push_back(token);
422     return operands;
423   }
424 
425   auto value = values[0];
426   // If value with token already exists, reuse it.
427   auto it = rewritten_values.find(value);
428   if (it != rewritten_values.end()) return {it->getSecond()};
429 
430   auto create_tuple = [&](ArrayRef<Value> operands) {
431     auto new_result = CreateTuple(builder, value.getLoc(), operands);
432     rewritten_values.insert({value, new_result});
433     return new_result;
434   };
435 
436   auto tuple_type = value.getType().dyn_cast<TupleType>();
437   // `value` is not a tuple, create a new tuple.
438   if (!tuple_type) return {create_tuple({value, token})};
439 
440   // Extend tuple if `value` is a tuple.
441   // If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack
442   // the tuple.
443   if (auto tuple_op = value.getDefiningOp<TupleOp>()) {
444     auto tuple_operands = llvm::to_vector(tuple_op.getOperands());
445     tuple_operands.push_back(token);
446     return {create_tuple(tuple_operands)};
447   }
448 
449   // `value` is not created via a `mhlo.tuple` directly, unpack individual
450   // elements directly with `mhlo.get_tuple_element`.
451   SmallVector<Value, 4> tuple_operands;
452   for (auto idx : llvm::seq<int32_t>(0, tuple_type.getTypes().size()))
453     tuple_operands.push_back(
454         builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
455             .getResult());
456 
457   tuple_operands.push_back(token);
458   return {create_tuple(tuple_operands)};
459 }
460 
461 // Extends the 'types' to include a `mhlo.token` type. If `flatten_tuple` is
462 // false, `types` will have a single element, say `type`. If `type` is not a
463 // tuple type, a new tuple type with `type` and `mhlo.token` type is created
464 // instead.
GetTypeWithToken(OpBuilder & builder,ArrayRef<Type> types,bool flatten_tuple)465 SmallVector<Type> GetTypeWithToken(OpBuilder& builder, ArrayRef<Type> types,
466                                    bool flatten_tuple) {
467   SmallVector<Type> new_result_types;
468   auto token_type = TokenType::get(builder.getContext());
469 
470   if (flatten_tuple) {
471     auto result_types = llvm::to_vector(types);
472     result_types.push_back(token_type);
473     return result_types;
474   }
475 
476   auto type = types[0];
477   if (auto tuple_type = type.dyn_cast<TupleType>()) {
478     auto result_types = llvm::to_vector(tuple_type.getTypes());
479     result_types.push_back(token_type);
480     return {builder.getTupleType(result_types)};
481   }
482 
483   return {builder.getTupleType({type, token_type})};
484 }
485 
486 // Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0
487 // to `end`, exclusive.
CreateSubTuple(OpBuilder & builder,Value value,size_t end)488 Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) {
489   SmallVector<Value, 4> tuple_operands;
490   for (auto idx : llvm::seq<int32_t>(0, end))
491     tuple_operands.push_back(
492         builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
493             .getResult());
494 
495   return CreateTuple(builder, value.getLoc(), tuple_operands);
496 }
497 
498 // Replaces uses of `values` with `replacements`. If `flatten_tuple` is false,
499 // `values` will have a single element, say `value`. If `value` is not a tuple
500 // type, an explicit `mhlo.get_tuple_element` is created to unpack the tuple and
501 // return the first element. Otherwise, `mhlo.get_tuple_element` users are
502 // simply updated with `replacement`, and all other users are updated with a
503 // slice of `replacement`.
ReplaceWithTupleResult(OpBuilder & builder,ArrayRef<Value> values,ArrayRef<Value> replacements,bool flatten_tuple)504 void ReplaceWithTupleResult(OpBuilder& builder, ArrayRef<Value> values,
505                             ArrayRef<Value> replacements, bool flatten_tuple) {
506   if (flatten_tuple) {
507     for (size_t result_index = 0; result_index < values.size(); result_index++)
508       values[result_index].replaceAllUsesWith(replacements[result_index]);
509     return;
510   }
511 
512   auto value = values[0];
513   auto replacement = replacements[0];
514   auto tuple_type = value.getType().dyn_cast<TupleType>();
515   if (!tuple_type) {
516     if (!value.use_empty()) {
517       auto new_element = builder.create<GetTupleElementOp>(replacement.getLoc(),
518                                                            replacement, 0);
519       value.replaceAllUsesWith(new_element.getResult());
520     }
521     return;
522   }
523 
524   Value sub_tuple;
525   for (auto& use : llvm::make_early_inc_range(value.getUses())) {
526     if (isa<GetTupleElementOp>(use.getOwner())) {
527       use.set(replacement);
528       continue;
529     }
530 
531     if (!sub_tuple)
532       sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size());
533 
534     use.set(sub_tuple);
535   }
536 }
537 
538 // Replaces control flow op block arguments with new block arguments
539 // of types `types`. The last element of the new block argument (token) is
540 // returned.
UpdateControlFlowBlockArgWithToken(OpBuilder & builder,Block & block,ArrayRef<Type> types)541 Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block,
542                                          ArrayRef<Type> types) {
543   builder.setInsertionPointToStart(&block);
544 
545   auto old_args_size = block.getNumArguments();
546 
547   block.addArguments(
548       types, SmallVector<Location>(types.size(), block.getParent()->getLoc()));
549 
550   auto old_args = ArrayRef<Value>(block.getArguments().begin(),
551                                   block.getArguments().begin() + old_args_size);
552   auto new_args = ArrayRef<Value>(block.getArguments().begin() + old_args_size,
553                                   block.getArguments().end());
554   assert(!new_args.empty());
555 
556   ReplaceWithTupleResult(builder, old_args, new_args, /*flatten_tuple=*/true);
557   auto new_arg = new_args[new_args.size() - 1];
558 
559   block.eraseArguments(
560       llvm::to_vector(llvm::seq((unsigned)0, (unsigned)old_args_size)));
561 
562   return new_arg;
563 }
564 
565 // Updates control flow op terminator with an extra element `token`.
RewriteControlFlowTerminator(OpBuilder & builder,Operation * terminator,Value token,bool flatten_tuple)566 void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator,
567                                   Value token, bool flatten_tuple) {
568   assert(flatten_tuple || terminator->getNumOperands() == 1);
569   assert(flatten_tuple || terminator->getBlock()->getNumArguments() == 1);
570   // `mhlo.while` cond terminator does not need to be rewritten as it always
571   // returns a tensor<i1> predicate value.
572   if (auto while_parent = dyn_cast_or_null<WhileOp>(terminator->getParentOp()))
573     if (terminator->getParentRegion() == &while_parent.cond()) return;
574 
575   builder.setInsertionPoint(terminator);
576   llvm::SmallDenseMap<Value, Value> rewritten_operands;
577   auto new_results =
578       GetValueWithToken(builder, llvm::to_vector(terminator->getOperands()),
579                         token, rewritten_operands, flatten_tuple);
580   terminator->setOperands(new_results);
581 }
582 
583 // Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. As If op does
584 // not have any operands other than the predicate, hence we implicitly capture
585 // the parent token. Also we use the same implicit token for use in the If op's
586 // regions.
RewriteRegionIfOp(OpBuilder & builder,IfOp region_if,SmallVectorImpl<OpVisitorState> & ops_to_visit,Value token)587 void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if,
588                        SmallVectorImpl<OpVisitorState>& ops_to_visit,
589                        Value token) {
590   llvm::SmallDenseMap<Value, Value> rewritten_operands;
591 
592   auto new_result_types =
593       GetTypeWithToken(builder, llvm::to_vector(region_if.getResultTypes()),
594                        /*flatten_tuple=*/true);
595 
596   // Create new `mhlo.if` op with extra token operands and result.
597   auto new_if = builder.create<IfOp>(region_if.getLoc(), new_result_types,
598                                      region_if.pred());
599 
600   // Move all regions from the old `mhlo.if` op to its replacement.
601   new_if.true_branch().takeBody(region_if.true_branch());
602   new_if.false_branch().takeBody(region_if.false_branch());
603 
604   // Forward result from old `mhlo.if` with replacement.
605   SmallVector<Value> old_if_results = region_if.getResults();
606   SmallVector<Value> new_if_results = new_if.getResults();
607 
608   ReplaceWithTupleResult(builder, old_if_results, new_if_results,
609                          /*flatten_tuple=*/true);
610 
611   // auto new_token = new_if_results[new_if_results.size() - 1];
612 
613   region_if.erase();
614 
615   // Next op to visit. The replacement is visited but at its first region.
616   // The new region use the same implicit token used by the If op.
617   ops_to_visit.push_back({/*region_idx=*/0, token, new_if});
618 }
619 
620 // Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a
621 // `mhlo.token`. The block argument is updated to have an extra `mhlo.token`
622 // element. If the region block is to be rewritten, the next op to visit is set
623 // to the first op in the block. Otherwise the terminator is updated to forward
624 // `token`.
RewriteControlFlowOpRegion(OpBuilder & builder,Operation * region_op,unsigned region_idx,ArrayRef<Type> block_arg_types,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)625 void RewriteControlFlowOpRegion(
626     OpBuilder& builder, Operation* region_op, unsigned region_idx,
627     ArrayRef<Type> block_arg_types,
628     SmallVectorImpl<OpVisitorState>& ops_to_visit,
629     const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
630   ops_to_visit.push_back({region_idx + 1, token, region_op});
631 
632   Region& region = region_op->getRegion(region_idx);
633   assert(llvm::hasSingleElement(region));
634 
635   auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(),
636                                                         block_arg_types);
637 
638   if (control_flow_blocks.contains(&region.front())) {
639       ops_to_visit.push_back(
640           {/*region_idx=*/llvm::None, block_token, &region.front().front()});
641     return;
642   }
643 
644   RewriteControlFlowTerminator(builder, region.front().getTerminator(),
645                                block_token, /*flatten_tuple=*/true);
646 }
647 
648 // For mlir::IfOp or mlir::CaseOp, replace the use of their region's block
649 // argument (of type token) with 'implicit_operand'.
ReplaceBlockArgumentsWithImplicitOperands(mlir::Operation * op,unsigned region_idx,Value implicit_operand)650 void ReplaceBlockArgumentsWithImplicitOperands(mlir::Operation* op,
651                                                unsigned region_idx,
652                                                Value implicit_operand) {
653   assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
654           mlir::dyn_cast<mlir::mhlo::CaseOp>(*op)) &&
655          "Unexpected mlir op in "
656          "HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands!");
657 
658   auto& region = op->getRegion(region_idx);
659   region.getArgument(0).replaceAllUsesWith(implicit_operand);
660   region.front().eraseArguments(
661       llvm::to_vector(llvm::seq<unsigned>(0, region.getNumArguments())));
662 }
663 
664 // Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op
665 // operands and results are rewritten. If `region_idx` is set, region
666 // `region_idx` is rewritten to take in and return an additional token. Returns
667 // true if the op or its region was rewritten.
ProcessRegionIfOp(OpBuilder & builder,IfOp region_if,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)668 bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if,
669                        Optional<unsigned> region_idx,
670                        SmallVectorImpl<OpVisitorState>& ops_to_visit,
671                        const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
672                        Value token) {
673   builder.setInsertionPoint(region_if);
674 
675   if (!region_idx) {
676     RewriteRegionIfOp(builder, region_if, ops_to_visit, token);
677     return true;
678   }
679 
680   if (*region_idx < region_if.getNumRegions()) {
681     // For the region-blocks of If op, we create a dummy token argument. Later
682     // we replace that block-argument's uses with the same (implicitly captured)
683     // token 'token', used for If op, and erase the argument.
684     // Note that 'RewriteControlFlowOpRegion' sets the token, used for the first
685     // operation of region_idx'th region, to the dummy block-argument. As we
686     // erase that argument, we also need to make sure that the token used for
687     // the next operation is set to 'token'.
688     RewriteControlFlowOpRegion(builder, region_if, *region_idx,
689                                {token.getType()}, ops_to_visit,
690                                control_flow_blocks, token);
691 
692     ReplaceBlockArgumentsWithImplicitOperands(region_if.getOperation(),
693                                               *region_idx, token);
694 
695     auto next_visitor_state = ops_to_visit.back();
696     next_visitor_state.token = token;
697     ops_to_visit.pop_back();
698     ops_to_visit.push_back(next_visitor_state);
699     return true;
700   }
701 
702   return false;
703 }
704 
705 // Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. Operands to
706 // the op for all of its regions are extended to have an extra operand `token`.
RewriteRegionWhileOp(OpBuilder & builder,WhileOp region_while,SmallVectorImpl<OpVisitorState> & ops_to_visit,Value token)707 void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while,
708                           SmallVectorImpl<OpVisitorState>& ops_to_visit,
709                           Value token) {
710   llvm::SmallDenseMap<Value, Value> rewritten_operands;
711 
712   // Rewrite region operand to have an extra operand `token`.
713   auto new_val_operands =
714       GetValueWithToken(builder, llvm::to_vector(region_while.getOperands()),
715                         token, rewritten_operands,
716                         /*flatten_tuple=*/true);
717 
718   auto new_result_types =
719       GetTypeWithToken(builder, llvm::to_vector(region_while.getResultTypes()),
720                        /*flatten_tuple*/ true);
721 
722   // Create new `mhlo.while` op with extra token operand and result.
723   auto new_while = builder.create<WhileOp>(region_while.getLoc(),
724                                            new_result_types, new_val_operands);
725 
726   // Move all regions from the old `mhlo.while` op to its replacement.
727   new_while.cond().takeBody(region_while.cond());
728   new_while.body().takeBody(region_while.body());
729 
730   // Forward result from old `mhlo.while` with replacement.
731   SmallVector<Value> old_while_results = region_while.getResults();
732   SmallVector<Value> new_while_results = new_while.getResults();
733 
734   ReplaceWithTupleResult(builder, old_while_results, new_while_results,
735                          /*flatten_tuple*/ true);
736 
737   auto new_token = new_while_results[new_while_results.size() - 1];
738 
739   region_while.erase();
740 
741   // Next op to visit. The replacement is visited but at its first region. The
742   // token result of the new region if is propagated.
743   ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while});
744 }
745 
746 // Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op
747 // operands and results are rewritten. If `region_idx` is set, region
748 // `region_idx` is rewritten to take in and return an additional token. Returns
749 // true if the op or its region was rewritten.
ProcessRegionWhileOp(OpBuilder & builder,WhileOp region_while,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)750 bool ProcessRegionWhileOp(
751     OpBuilder& builder, WhileOp region_while, Optional<unsigned> region_idx,
752     SmallVectorImpl<OpVisitorState>& ops_to_visit,
753     const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
754   builder.setInsertionPoint(region_while);
755 
756   if (!region_idx) {
757     RewriteRegionWhileOp(builder, region_while, ops_to_visit, token);
758     return true;
759   }
760 
761   if (*region_idx < region_while.getNumRegions()) {
762     SmallVector<Type> operand_types;
763     for (auto operand : region_while.operand())
764       operand_types.push_back(operand.getType());
765     RewriteControlFlowOpRegion(builder, region_while, *region_idx,
766                                operand_types, ops_to_visit, control_flow_blocks,
767                                token);
768     return true;
769   }
770 
771   return false;
772 }
773 
774 // Updates function type based on current function body block arguments and
775 // terminator operand types.
UpdateFunctionType(OpBuilder & builder,func::FuncOp func,Block & func_body)776 void UpdateFunctionType(OpBuilder& builder, func::FuncOp func,
777                         Block& func_body) {
778   auto new_argument_types = llvm::to_vector(func_body.getArgumentTypes());
779   auto new_result_types =
780       llvm::to_vector(func_body.getTerminator()->getOperandTypes());
781   func.setType(FunctionType::get(builder.getContext(), new_argument_types,
782                                  new_result_types));
783 }
784 
785 // Replaces a function terminator `return` with another `return` that has an
786 // extra `mhlo.token` operand.
RewriteFunctionTerminator(OpBuilder & builder,mlir::func::ReturnOp terminator,Value token)787 void RewriteFunctionTerminator(OpBuilder& builder,
788                                mlir::func::ReturnOp terminator, Value token) {
789   auto new_results = llvm::to_vector(terminator.getOperands());
790   new_results.push_back(token);
791   builder.setInsertionPoint(terminator);
792   builder.create<mlir::func::ReturnOp>(terminator.getLoc(), new_results);
793   terminator.erase();
794 }
795 
796 // Rewrites a function body and communication ops inside. Region control flow
797 // are updated when necessary, to propagate tokens. The function may either be
798 // rewritten to create a token or take in and return a token, depending on its
799 // visibility and if there are any callers.
RewriteFunction(OpBuilder & builder,int64_t & channel_id,ModuleOp module,FuncOp func,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs,const llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,bool is_clone)800 LogicalResult RewriteFunction(
801     OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func,
802     const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs,
803     const llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
804     const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, bool is_clone) {
805   MLIRContext* context = module.getContext();
806   if (!llvm::hasSingleElement(func.getBody()))
807     return func.emitError()
808            << "'" << FuncOp::getOperationName()
809            << "' ops with more than one block are not supported";
810 
811   bool rewrite_block =
812       is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module));
813   Block& func_body = func.front();
814 
815   builder.setInsertionPointToStart(&func_body);
816   auto token_type = TokenType::get(context);
817   // If a function is public, it's signature should not be modified, and instead
818   // a token will be created. Otherwise a token block argument is inserted.
819   Value init_token =
820       rewrite_block ? func_body.addArgument(token_type, func.getLoc())
821                     : builder.create<CreateTokenOp>(func.getLoc(), token_type)
822                           .getResult();
823 
824   // Stack to keep track of region based control flow op nesting and current
825   // op to visit.
826   SmallVector<OpVisitorState, 4> ops_to_visit{
827       {/*region_idx=*/llvm::None, init_token, &func_body.front()}};
828 
829   while (!ops_to_visit.empty()) {
830     OpVisitorState op_to_visit = ops_to_visit.pop_back_val();
831     Operation* curr_op = op_to_visit.op;
832 
833     Value token = op_to_visit.token;
834     // Ops may be removed, so the next op is kept track of beforehand.
835     Operation* next_op = curr_op->getNextNode();
836 
837     if (auto host_compute = dyn_cast<TF::_XlaHostComputeMlirOp>(curr_op)) {
838       token = RewriteHostComputeOp(builder, channel_id, host_compute, token);
839     } else if (auto send_to_host = dyn_cast<TF::XlaSendToHostOp>(curr_op)) {
840       token = RewriteSendToHostOp(builder, channel_id, send_to_host, token);
841     } else if (auto recv_from_host = dyn_cast<TF::XlaRecvFromHostOp>(curr_op)) {
842       token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token);
843     } else if (auto call = dyn_cast<mlir::func::CallOp>(curr_op)) {
844       // Only `mlir::func::CallOp` is supported as this requires knowing how to
845       // rewrite arguments and results to a function.
846       auto it = funcs.find(call.getCallee());
847       if (it != funcs.end()) {
848         func::FuncOp clone = it->getSecond().clone;
849         Optional<StringRef> symbol_name =
850             clone ? Optional<StringRef>(clone.getName()) : llvm::None;
851         // If the function being called is to be cloned, update the call to also
852         // point to the cloned function.
853         token = RewriteCallOp(builder, call, symbol_name, token);
854       }
855     } else if (auto region_if = dyn_cast<IfOp>(curr_op)) {
856       if (op_to_visit.region_idx || control_flow_ops.contains(region_if)) {
857         auto exist_unprocessed_region =
858             ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx,
859                               ops_to_visit, control_flow_blocks, token);
860 
861         // Once all the IfOp regions are processed (i.e.
862         // 'exist_unprocessed_region' == false), select returned token-value
863         // from IfOp as the token to be used for the following op.
864         if (!exist_unprocessed_region) {
865           token = curr_op->getResult(curr_op->getNumResults() - 1);
866         } else {
867           continue;
868         }
869       }
870     } else if (auto region_while = dyn_cast<WhileOp>(curr_op)) {
871       if (op_to_visit.region_idx || control_flow_ops.contains(region_while))
872         if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx,
873                                  ops_to_visit, control_flow_blocks, token))
874           continue;
875     } else if (auto region_terminator = dyn_cast<mhlo::ReturnOp>(curr_op)) {
876       bool flatten_tuple = isa<mhlo::WhileOp, mhlo::IfOp, mhlo::CaseOp>(
877           region_terminator->getParentOp());
878       RewriteControlFlowTerminator(builder, region_terminator, token,
879                                    flatten_tuple);
880       // There is no next op after the control flow op terminator, simply let
881       // stack have one less element.
882       continue;
883     } else if (auto func_terminator = dyn_cast<mlir::func::ReturnOp>(curr_op)) {
884       if (rewrite_block)
885         RewriteFunctionTerminator(builder, func_terminator, token);
886 
887       // There is no next op after the function terminator, simply let stack
888       // have one less element/be empty.
889       continue;
890     }
891 
892     // Visit next op.
893     ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op});
894   }
895 
896   if (rewrite_block) UpdateFunctionType(builder, func, func_body);
897 
898   return success();
899 }
900 
901 // Checks if a function call is pointing to a function with communication ops.
IsFunctionCallWithCommunication(Operation * op,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)902 bool IsFunctionCallWithCommunication(
903     Operation* op,
904     const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
905   if (auto call = dyn_cast<mlir::func::CallOp>(op))
906     return funcs_to_rewrite.count(call.getCallee());
907 
908   return false;
909 }
910 
911 // Collects all control flow op ancestors of communication ops or function calls
912 // with communication ops (transitively).
GetCommunicationControlFlowOps(func::FuncOp func,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)913 void GetCommunicationControlFlowOps(
914     func::FuncOp func,
915     const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite,
916     llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
917     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
918   func.walk([&](Operation* op) {
919     if (IsCommunicationOp(op) ||
920         IsFunctionCallWithCommunication(op, funcs_to_rewrite))
921       if (failed(GetControlFlowAncestors(op, control_flow_ops,
922                                          control_flow_blocks)))
923         llvm_unreachable(
924             "checking original function for control flow ancestors should have "
925             "errored first");
926   });
927 }
928 
runOnOperation()929 void LegalizeTFCommunication::runOnOperation() {
930   auto module = getOperation();
931   llvm::SmallDenseMap<StringRef, FuncToRewrite> funcs_to_rewrite;
932   if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite)))
933     return signalPassFailure();
934 
935   // Module level counter to make sure Channel IDs are unique.
936   int64_t channel_id = 1;
937   OpBuilder builder(&getContext());
938   for (const auto& func_and_name : funcs_to_rewrite) {
939     const auto& func_to_rewrite = func_and_name.getSecond();
940     func::FuncOp func = func_to_rewrite.original;
941     if (failed(RewriteFunction(builder, channel_id, module, func,
942                                funcs_to_rewrite,
943                                func_to_rewrite.control_flow_ops,
944                                func_to_rewrite.control_flow_blocks,
945                                /*is_clone=*/false)))
946       return signalPassFailure();
947 
948     func::FuncOp clone = func_and_name.getSecond().clone;
949     if (!clone) continue;
950     llvm::SmallPtrSet<Operation*, 4> clone_control_flow_ops;
951     llvm::SmallPtrSet<Block*, 4> clone_control_flow_blocks;
952     GetCommunicationControlFlowOps(clone, funcs_to_rewrite,
953                                    clone_control_flow_ops,
954                                    clone_control_flow_blocks);
955     if (failed(RewriteFunction(builder, channel_id, module, clone,
956                                funcs_to_rewrite, clone_control_flow_ops,
957                                clone_control_flow_blocks,
958                                /*is_clone=*/true)))
959       llvm_unreachable(
960           "rewriting of original function should have errored first");
961   }
962 }
963 
964 }  // namespace
965 
CreateLegalizeTFCommunicationPass()966 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCommunicationPass() {
967   return std::make_unique<LegalizeTFCommunication>();
968 }
969 
970 }  // namespace mhlo
971 }  // namespace mlir
972