/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

struct StackOpsDecompositionPass
    : public TF::StackOpsDecompositionPassBase<StackOpsDecompositionPass> {
  void runOnOperation() final;
};

// Returns the type of the local variable for the stack size.
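// For illustration: assuming cutil::GetSizeType yields a small integer tensor
// type such as tensor<1xi32>, the returned type would be roughly
//
//   tensor<!tf_type.resource<tensor<1xi32>>>
//
// i.e. a scalar resource handle whose subtype is the size tensor.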
Type GetSizeVarType(OpBuilder builder) {
  auto size_type = cutil::GetSizeType(builder);
  return RankedTensorType::get(
      {}, TF::ResourceType::get(ArrayRef<TensorType>{size_type},
                                builder.getContext()));
}

// Returns the aliasing argument number of a function return value if it simply
// forwards the argument. Otherwise, returns -1.
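// For example (hypothetical IR), given a function whose terminator is
//
//   return %arg1, %sum : ...
//
// this returns 1 for return_index 0 (the value forwards argument 1) and -1
// for return_index 1, since %sum is not a block argument.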
int64_t FindAliasedInput(func::FuncOp func, int64_t return_index) {
  Value return_val = func.front().getTerminator()->getOperand(return_index);
  auto maybe_arg = return_val.dyn_cast<BlockArgument>();
  if (!maybe_arg) return -1;
  return maybe_arg.getArgNumber();
}

// Changes the function signature that has stacks in the arguments. A stack
// argument will be turned into a variable type if arg_to_stack_type returns
// such a type, and a new argument will be added to the end of the argument
// list for the size variable.
//
// If stack_var_to_size_var is not nullptr, it will be used to store the
// mapping from the stack-variable argument to the size-variable argument.
//
// If handle_new_size_vars is provided, it will be invoked on the list of new
// size variables before finally changing the function type.
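//
// Illustrative sketch with hypothetical types: a function
//
//   func @f(%arg0: tensor<!tf_type.resource>) -> ...
//
// whose argument 0 maps to a buffer-variable type would become roughly
//
//   func @f(%arg0: tensor<!tf_type.resource<tensor<10xf32>>>,
//           %arg1: tensor<!tf_type.resource<tensor<1xi32>>>) -> ...
//
// with %arg1 being the newly appended size variable for %arg0.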
void ModifyFunctionSignature(
    func::FuncOp func, llvm::SmallDenseMap<Value, Value>* stack_var_to_size_var,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_stack_type,
    llvm::function_ref<void(ArrayRef<BlockArgument>)> handle_new_size_vars =
        nullptr) {
  auto new_input_types = llvm::to_vector<8>(func.getFunctionType().getInputs());
  auto size_var_type = GetSizeVarType(OpBuilder(func));
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto stack_type = arg_to_stack_type(i);
    if (!stack_type.has_value()) continue;
    func.getArgument(i).setType(*stack_type);
    new_input_types[i] = *stack_type;
    auto size_arg = func.front().addArgument(size_var_type, func.getLoc());
    new_input_types.push_back(size_arg.getType());
    if (stack_var_to_size_var) {
      (*stack_var_to_size_var)[func.getArgument(i)] = size_arg;
    }
  }
  if (handle_new_size_vars) {
    handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
  }
  func.setType(
      FunctionType::get(func.getContext(), new_input_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
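// For example, if the callee's stack-variable argument 2 had its size
// variable appended as argument 5, stack_var_arg_to_size_arg would contain
// the entry {2, 5} (hypothetical indices).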
struct PartitionedCallStackOpsInfo {
  bool signature_change;
  func::FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> stack_var_arg_to_size_arg;
};

LogicalResult DecomposeStackOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, Value>*,
    llvm::StringMap<PartitionedCallStackOpsInfo>*);

// Handles stack usage by a tf.While. It will convert the body and conditional
// function signatures, and perform stack ops decomposition on them.
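//
// Rough sketch with hypothetical names: a while op over a decomposed stack
//
//   %w = "tf.While"(%buffer_var) {body = @body, cond = @cond, ...}
//
// gains the size variable as a trailing operand,
//
//   %w:2 = "tf.While"(%buffer_var, %size_var) {body = @body, cond = @cond, ...}
//
// with @body rewritten to also take and return the size variable, and @cond
// rewritten to take it without changing its result.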
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, Value> body_map;
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(while_op.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto add_size_vars_to_return = [&](ArrayRef<BlockArgument> new_args) {
    if (new_args.empty()) return;
    auto body_ret = body.front().getTerminator();
    auto new_body_returns = llvm::to_vector<8>(body_ret->getOperands());
    for (auto arg : new_args) new_body_returns.push_back(arg);
    OpBuilder(body_ret).create<func::ReturnOp>(body_ret->getLoc(),
                                               new_body_returns);
    body_ret->erase();
  };
  // Handle body.
  ModifyFunctionSignature(body, &body_map, find_arg_stack_type,
                          add_size_vars_to_return);
  const bool signature_change = !body_map.empty();
  if (failed(DecomposeStackOpsInternal(&body.front(), module, &body_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  // Cond should not change stacks in the arguments, so use an empty map.
  auto cond = while_op.cond_function();
  ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
  llvm::SmallDenseMap<Value, Value> empty_map;
  if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  OpBuilder builder(while_op);
  assert(while_op.getNumOperands() == while_op.getNumResults());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = data_var_to_size_var.find(while_op.getOperand(i));
    if (it == data_var_to_size_var.end()) continue;
    new_while_operands.push_back(it->getSecond());
  }
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getFunctionType().getInputs(), new_while_operands,
      while_op->getAttrs());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
             .isa<TF::ResourceType>()) {
      continue;
    }
    int64_t aliased_input = FindAliasedInput(body, i);
    if (aliased_input == i) {
      // Replace aliased stack output uses with input.
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    }
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

// Handles stack usage by a tf.If. It will convert the branch function
// signatures, and perform stack ops decomposition on them.
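//
// Note that operand 0 of tf.If is the condition, so branch argument i
// corresponds to if_op operand i + 1. Rough sketch with hypothetical names:
//
//   %r = "tf.If"(%cond, %buffer_var, %size_var)
//       {then_branch = @then, else_branch = @else, ...}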
LogicalResult HandleIfOp(
    TF::IfOp if_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto then_func = if_op.then_function();
  auto else_func = if_op.else_function();
  llvm::SmallDenseMap<Value, Value> then_map;
  llvm::SmallDenseMap<Value, Value> else_map;

  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(if_op.getOperand(index + 1));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(then_func, &then_map, find_arg_stack_type);
  ModifyFunctionSignature(else_func, &else_map, find_arg_stack_type);
  const bool signature_change = !then_map.empty() || !else_map.empty();
  if (failed(DecomposeStackOpsInternal(&then_func.front(), module, &then_map,
                                       decomposed_partitioned_call_callees)) ||
      failed(DecomposeStackOpsInternal(&else_func.front(), module, &else_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
  for (auto operand : if_op.getOperands()) {
    auto it = data_var_to_size_var.find(operand);
    if (it == data_var_to_size_var.end()) continue;
    new_if_operands.push_back(it->getSecond());
  }
  auto new_if = OpBuilder(if_op).create<TF::IfOp>(
      if_op.getLoc(), then_func.getFunctionType().getResults(), new_if_operands,
      if_op->getAttrs());
  for (auto result : if_op.getResults()) {
    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    int64_t then_aliased_input =
        FindAliasedInput(then_func, result.getResultNumber());
    int64_t else_aliased_input =
        FindAliasedInput(else_func, result.getResultNumber());
    if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) {
      // Replace aliased stack output uses with input.
      result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1));
    }
  }
  if_op.replaceAllUsesWith(new_if);
  if_op.erase();
  return success();
}

// Handles stack usage by a tf.StatefulPartitionedCall or a tf.PartitionedCall.
// It will first check if the callee was previously handled, and try to reuse
// that result if so. Otherwise, it will clone and convert the callee function,
// and perform stack ops decomposition on it.
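//
// Rough sketch with hypothetical names: a call whose operand 0 is a
// decomposed stack,
//
//   "tf.PartitionedCall"(%buffer_var) {f = @callee, ...}
//
// is rewritten to also pass the size variable, targeting the (possibly
// cloned) decomposed callee:
//
//   "tf.PartitionedCall"(%buffer_var, %size_var)
//       {f = @callee_stack_decomposed, ...}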
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, func::FuncOp callee, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallStackOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreate the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.stack_var_arg_to_size_arg.find(i);
      if (arg_it == info.stack_var_arg_to_size_arg.end()) continue;
      auto it = data_var_to_size_var.find(call.getOperand(i));
      if (it == data_var_to_size_var.end()) {
        call.emitOpError("unknown stack");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond());
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getFunctionType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", SymbolRefAttr::get(
                 builder.getContext(),
                 const_cast<func::FuncOp&>(info.decomposed_callee).getName()));
    for (int64_t i = 0; i < call.getNumResults(); ++i) {
      auto result = call.getResult(i);
      if (!getElementTypeOrSelf(result.getType())
               .template isa<TF::ResourceType>()) {
        continue;
      }
      int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
      if (aliased_input >= 0) {
        // Replace aliased stack output uses with input.
        result.replaceAllUsesWith(call.getOperand(aliased_input));
      }
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  llvm::SmallDenseMap<Value, Value> callee_map;
  func::FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(call.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(lowered_callee, &callee_map, find_arg_stack_type);
  info.signature_change = !callee_map.empty();
  if (!info.signature_change) {
    // Signature is not modified. We do not need the clone.
    if (lowered_callee != callee) {
      lowered_callee.erase();
    }
  } else {
    info.decomposed_callee = lowered_callee;
    for (auto& entry : callee_map) {
      info.stack_var_arg_to_size_arg
          [entry.getFirst().cast<BlockArgument>().getArgNumber()] =
              entry.getSecond().cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(StringAttr::get(
          callee->getContext(),
          llvm::formatv("{0}_stack_decomposed", callee.getName()).str()));
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (info.signature_change) return recreate_caller();
  return success();
}
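
// Handles a tf.StackV2 op by replacing it with a buffer local variable and a
// size local variable, both zero-initialized. Rough sketch with hypothetical
// types: a stack of at most 10 f32 elements
//
//   %stack = "tf.StackV2"(%max_size) {...}
//
// becomes a local variable of type tensor<!tf_type.resource<tensor<10xf32>>>
// plus a size variable, and all uses of %stack are redirected to the buffer
// variable.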
LogicalResult HandleStackV2Op(
    TF::StackV2Op stack, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  // Create a buffer variable and a size variable to replace the stack.
  auto elem_type = cutil::GetElementTypeFromAccess(
      stack.handle(), module, [](Operation* user) -> llvm::Optional<Type> {
        auto push = llvm::dyn_cast<TF::StackPushV2Op>(user);
        if (!push) return llvm::None;
        return push.elem().getType();
      });
  if (!elem_type.has_value()) {
    return stack.emitOpError("cannot infer element shape of stack");
  }
  OpBuilder builder(stack);
  Value buffer;
  if (failed(cutil::CreateInitBufferValue(
          elem_type->getShape(), stack.max_size(), stack,
          elem_type->getElementType(), builder, &buffer))) {
    return failure();
  }
  auto size_var_type = GetSizeVarType(builder);
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              stack.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  auto local_size_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{});
  // Zero-initialize the local vars.
  cutil::WriteLocalVariable(local_size_var,
                            cutil::GetR1Const({0LL}, builder, stack.getLoc()),
                            builder, stack.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, stack.getLoc());
  stack.handle().replaceAllUsesWith(local_var);
  (*data_var_to_size_var)[local_var] = local_size_var;
  stack.erase();
  return success();
}
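
// Handles a tf.StackPushV2 op by writing the pushed element into the buffer
// at the current size and incrementing the size variable; in pseudocode:
//
//   buffer[size] = elem
//   size = size + 1
//
// The push op's result simply forwards the pushed element.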
LogicalResult HandleStackPushV2Op(
    TF::StackPushV2Op push,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(push.handle());
  if (it == data_var_to_size_var->end()) {
    return push.emitOpError("unknown stack");
  }
  // The push output simply forwards the input element.
  push.replaceAllUsesWith(push.elem());
  OpBuilder builder(push);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(push.handle(), builder, push.getLoc());
  auto index =
      cutil::ReadLocalVariable(it->getSecond(), builder, push.getLoc());
  stack_val =
      cutil::SetElement(index, stack_val, push.elem(), builder, push.getLoc());
  // Assign the new buffer and size.
  cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
  index = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{index.getType()},
      ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())});
  cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
  push.erase();
  return success();
}
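
// Handles a tf.StackPopV2 op by decrementing the size variable and reading
// the buffer at the new size; in pseudocode:
//
//   size = size - 1
//   result = buffer[size]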
LogicalResult HandleStackPopV2Op(
    TF::StackPopV2Op pop,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(pop.handle());
  if (it == data_var_to_size_var->end()) {
    return pop.emitOpError("unknown stack");
  }
  OpBuilder builder(pop);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(pop.handle(), builder, pop.getLoc());
  auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())});
  auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc());
  pop.replaceAllUsesWith(pop_val);
  // Update the size.
  cutil::WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc());
  pop.erase();
  return success();
}
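
// Checks that region-based control flow ops (tf.WhileRegion, tf.IfRegion,
// tf.CaseRegion) have no resource-typed operands or results, which are
// expected to have been canonicalized away, and decomposes stack ops inside
// their regions.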
LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (Region& region : op.getRegions()) {
    if (failed(DecomposeStackOpsInternal(&region.front(), module,
                                         data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

// Decomposes stack ops on a region and recursively decomposes called functions.
// data_var_to_size_var: a mapping from stacks' buffer local variables to size
// local variables.
// decomposed_partitioned_call_callees: cache for partitioned call ops' callee
// function handling.
LogicalResult DecomposeStackOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      // Removes identity nodes in the block. The device computation does not
      // need such nodes to carry information.
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto stack = llvm::dyn_cast<TF::StackV2Op>(&op)) {
      if (failed(HandleStackV2Op(stack, module, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(&op)) {
      if (failed(HandleStackPushV2Op(push, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::StackPopV2Op>(&op)) {
      if (failed(HandleStackPopV2Op(pop, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::StackCloseV2Op>(&op)) {
      data_var_to_size_var->erase(close.handle());
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, *data_var_to_size_var,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::WhileRegionOp>(op) ||
               llvm::isa<TF::IfRegionOp>(op) ||
               llvm::isa<TF::CaseRegionOp>(op)) {
      if (failed(
              HandleRegionControlFlowOps(op, module, data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func()) {
        return pcall.emitOpError(
            "stack decomposition does not support call with nested references");
      }
      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, Value> data_var_to_size_var;
  llvm::StringMap<PartitionedCallStackOpsInfo>
      decomposed_partitioned_call_callees;
  return DecomposeStackOpsInternal(block, module, &data_var_to_size_var,
                                   &decomposed_partitioned_call_callees);
}

void StackOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<func::FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeStackOps(&main.front(), module))) {
    signalPassFailure();
  }
}

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass() {
  return std::make_unique<StackOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir