/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"

namespace mlir {
namespace TF {
namespace {

// Returns true if the given op is a TF/XLA communication op in the old bridge.
bool IsCommunicationOp(Operation* op) {
  return isa<TF::XlaHostComputeOp, TF::XlaSendToHostOp, TF::XlaRecvFromHostOp>(
      op);
}

// Returns true if the given op is one of the ops supported to have a
// communication subcomputation in the TF/XLA bridge.
bool SupportsCommunicationComputation(Operation* op) {
  return isa<TF::IfRegionOp, TF::WhileRegionOp, TF::CaseRegionOp,
             TF::StatefulPartitionedCallOp, TF::PartitionedCallOp,
             TF::LegacyCallOp>(op);
}

class PrepareTpuComputationForTfExportPass
    : public PrepareTpuComputationForTfExportPassBase<
          PrepareTpuComputationForTfExportPass> {
  void runOnOperation() override;
};

class RewriteXlaHostComputeMlir
    : public OpRewritePattern<TF::_XlaHostComputeMlirOp> {
 public:
  using OpRewritePattern<TF::_XlaHostComputeMlirOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op,
                                PatternRewriter& rewriter) const override {
    llvm::SmallVector<Attribute> shape_attrs;
    shape_attrs.reserve(op.getNumResults());
    for (Type ty : op.getResultTypes()) {
      shape_attrs.push_back(
          TF::ShapeAttr::get(rewriter.getContext(), ty.cast<ShapedType>()));
    }

    // Clone the `host_func` in the `host_mlir_module` attribute if it exists
    // and use it for the `shape_inference_graph` attribute on XlaHostCompute.
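    // For illustration only (a rough sketch, not the exact generated IR), the
    // clone is rewritten so the host computation is keyed off a placeholder
    // program key:
    //
    //   %key = "tf._TPUCompileMlirPlaceholderProgramKey"()
    //   %args = "tf._XlaRecvAtHost"(%key) {key = <send_key>}  // replaces args
    //   ... original host computation on %args ...
    //   "tf._XlaSendFromHost"(<results>, %key) {key = <recv_key>}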
    func::FuncOp cloned_func;
    SymbolTable manager(op->getParentOfType<ModuleOp>());
    StringRef host_module = op.host_mlir_module();
    if (!host_module.empty()) {
      mlir::OwningOpRef<mlir::ModuleOp> module_for_func;

      func::FuncOp func = op.GetHostFunc(&module_for_func);

      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfter(op->getParentOfType<func::FuncOp>());
      cloned_func = llvm::dyn_cast_or_null<func::FuncOp>(
          rewriter.clone(*func.getOperation()));
      manager.insert(cloned_func);
      rewriter.setInsertionPointToStart(&cloned_func.getBody().front());
      auto result_type =
          RankedTensorType::get({3}, rewriter.getType<TF::StringType>());
      auto dynamic_key =
          rewriter.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
              func.getLoc(), /*program=*/result_type, llvm::ArrayRef<Value>{});

      auto recv_at_host = rewriter.create<TF::_XlaRecvAtHostOp>(
          func.getLoc(), op.getOperandTypes(), /*dynamic_key=*/dynamic_key,
          op.send_keyAttr(),
          /*device_ordinal=*/rewriter.getI64IntegerAttr(0));
      for (auto result :
           llvm::zip(cloned_func.getArguments(), recv_at_host->getResults())) {
        std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }

      rewriter.setInsertionPoint(cloned_func.getBody().front().getTerminator());
      rewriter.create<TF::_XlaSendFromHostOp>(
          func.getLoc(),
          cloned_func.getBody().front().getTerminator()->getOperands(),
          /*dynamic_key=*/dynamic_key, op.recv_keyAttr(),
          /*device_ordinal=*/rewriter.getI64IntegerAttr(0));
    }

    constexpr int64_t kDefaultCostEstimate = 1000000;
    rewriter.replaceOpWithNewOp<TF::XlaHostComputeOp>(
        op, op.getResultTypes(), op.inputs(),
        /*ancestors=*/rewriter.getArrayAttr({}),
        rewriter.getArrayAttr(shape_attrs),
        /*shape_inference_graph=*/
        cloned_func ? SymbolRefAttr::get(cloned_func) : SymbolRefAttr(),
        /*key=*/rewriter.getStringAttr(""), op.send_keyAttr(),
        op.recv_keyAttr(),
        /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate),
        /*tpu_core=*/rewriter.getI64IntegerAttr(0));
    return success();
  }
};
128
UpdateArgAttributes(mlir::func::FuncOp func)129 void UpdateArgAttributes(mlir::func::FuncOp func) {
130 OpBuilder builder(func.getBody());
131 for (int i = 0; i < func.getNumArguments(); ++i) {
132 constexpr char kShardingAttr[] = "mhlo.sharding";
133 if (auto sharding =
134 func.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) {
135 if (!sharding.getValue().empty()) {
136 BlockArgument arg = func.getArgument(i);
137 // TODO(hinsu): Instead of setting both 'sharding' and '_XlaSharding'
138 // attributes, only set the 'sharding' attribute. Both attributes are
139 // currently required as the XlaSharding xla op kernel doesn't use the
140 // 'sharding' attribute.
141 auto updated_arg = builder.create<TF::XlaShardingOp>(
142 func.getLoc(), arg.getType(), arg, sharding, sharding);
143 func.getArgument(i).replaceAllUsesExcept(
144 updated_arg, llvm::SmallPtrSet<Operation*, 1>({updated_arg}));
145 }
146
147 func.removeArgAttr(i, builder.getStringAttr(kShardingAttr));
148 }
149 }
150 }
151
LogicalResult RewriteCommunicationOps(ModuleOp module) {
  MLIRContext* ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  patterns.add<RewriteXlaHostComputeMlir>(ctx);
  if (failed(mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
    return module.emitError("failed to apply tf export preparation patterns");
  }

  // TODO(hinsu): Investigate if the semantics of keys for these communication
  // ops between the old bridge and new bridge can be reconciled.
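  // For example (keys here are illustrative): an XlaSendToHost op with key
  // "foo" is rewritten to use key "foo_dtoh_0" (device-to-host transfer), and
  // an XlaRecvFromHost op with key "bar" to use key "bar_htod_0"
  // (host-to-device transfer).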
  module.walk([&](Operation* op) {
    if (isa<TF::XlaSendToHostOp>(op)) {
      StringRef old_key = op->getAttrOfType<StringAttr>("key").getValue();
      auto new_key = StringAttr::get(ctx, old_key.str() + "_dtoh_0");
      op->setAttr("key", new_key);
    } else if (isa<TF::XlaRecvFromHostOp>(op)) {
      StringRef old_key = op->getAttrOfType<StringAttr>("key").getValue();
      auto new_key = StringAttr::get(ctx, old_key.str() + "_htod_0");
      op->setAttr("key", new_key);
    }
  });
  return success();
}

// Sets the token input node names attribute and the corresponding original
// node name attribute on TF/XLA communication related ops. These attributes
// are used to order operations on the device. The first op in a region gets
// the special argument token and each subsequent op gets the node name of the
// previous communication op.
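//
// For example (an illustrative sketch), a region with two such ops op1 and
// op2 ends up annotated as:
//   op1: kXlaTokenInputNodesAttrName = [kXlaTokenArgNodeName],
//        kXlaOriginalOutsideCompilationNodeName = <unique name of op1>
//   op2: kXlaTokenInputNodesAttrName = [<unique name of op1>],
//        kXlaOriginalOutsideCompilationNodeName = <unique name of op2>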
LogicalResult SetTokenInputAttrs(ModuleOp module) {
  // Collect all the ops that need to have token input name attributes. These
  // ops are the communication ops and all of their parent ops, via nesting or
  // function calls. For example, IfRegion op and PartitionedCall op.
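  // For instance, an XlaSendToHost op nested in the then branch of an
  // IfRegion op marks both the send op and the enclosing IfRegion op and, if
  // that IfRegion lives in a called function, the callers of that function as
  // well.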
  std::vector<Operation*> worklist;
  absl::flat_hash_set<Operation*> ops_with_tokens;
  module.walk([&](Operation* op) {
    if (IsCommunicationOp(op)) {
      ops_with_tokens.insert(op);
      worklist.push_back(op);
    }
  });

  SymbolTableCollection table;
  SymbolUserMap symbol_map(table, module);

  // Regions that contain ops requiring token input attributes.
  absl::flat_hash_set<Region*> regions_with_token;
  while (!worklist.empty()) {
    Operation* op = worklist.back();
    worklist.pop_back();

    Region* region = op->getParentRegion();
    regions_with_token.insert(region);

    // If the parent is not a FuncOp, then add the parent op containing a
    // region to the worklist.
    Operation* parent = region->getParentOp();
    if (!isa<func::FuncOp>(parent)) {
      if (ops_with_tokens.insert(parent).second) {
        worklist.push_back(parent);
      }
      continue;
    }

    // For functions, get all the users and add them to the worklist.
    for (auto& user : symbol_map.getUsers(parent)) {
      if (ops_with_tokens.insert(user).second) {
        worklist.push_back(user);
      }
    }
  }

  // Use a name mapper to uniquely name all the ops in the module, as export to
  // a TensorFlow graph may change node names. The op names here don't need to
  // match the actual names in the graph, as this only sets the original node
  // name attribute for all the relevant nodes.
  tensorflow::OpOrArgLocNameMapper name_mapper;
  MLIRContext* ctx = module.getContext();
  for (Region* region : regions_with_token) {
    // Initialize the token with the special argument token. This gets mapped
    // to the input token in the parent op or to a new token for the entry
    // computation.
    auto token = StringAttr::get(ctx, tensorflow::kXlaTokenArgNodeName);
    for (Operation& op : region->getOps()) {
      // Only communication related ops that need to have a token should have
      // the extra attribute.
      if (!ops_with_tokens.contains(&op)) continue;

      if (!IsCommunicationOp(&op) && !SupportsCommunicationComputation(&op)) {
        return op.emitOpError(
            "does not support subcomputations with tf/xla communication ops");
      }

      op.setAttr(tensorflow::kXlaTokenInputNodesAttrName,
                 ArrayAttr::get(ctx, {token}));

      auto node_name = StringAttr::get(ctx, name_mapper.GetUniqueName(&op));
      op.setAttr(tensorflow::kXlaOriginalOutsideCompilationNodeName, node_name);
      token = node_name;
    }
  }
  return success();
}

void PrepareTpuComputationForTfExportPass::runOnOperation() {
  ModuleOp module = getOperation();

  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
    UpdateArgAttributes(func);
  }

  // First rewrite communication ops used in the new bridge to match old bridge
  // semantics and then set token input node name attributes on the supported
  // ops.
  if (failed(RewriteCommunicationOps(module)) ||
      failed(SetTokenInputAttrs(module))) {
    signalPassFailure();
    return;
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreatePrepareTpuComputationForTfExportPass() {
  return std::make_unique<PrepareTpuComputationForTfExportPass>();
}

}  // namespace TF
}  // namespace mlir