/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

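// A pass that converts stack operations to tensor operations and read/assign
// ops on local variables, which can then be handled by later passes on
// resource variables.
//
// Each "tf.StackV2" op is decomposed into two zero-initialized local
// variables holding the buffer and the current size. A "tf.StackPushV2"
// writes its element into the buffer at the current size and increments the
// size; a "tf.StackPopV2" decrements the size and reads the element at the
// decremented index. tf.While/tf.If/partitioned-call ops whose operands carry
// a stack resource have their callee signatures rewritten to also thread the
// size variable through.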
struct StackOpsDecompositionPass
    : public TF::StackOpsDecompositionPassBase<StackOpsDecompositionPass> {
  void runOnOperation() final;
};

// Returns the type of the local variable for the stack size.
Type GetSizeVarType(OpBuilder builder) {
  auto size_type = cutil::GetSizeType(builder);
  return RankedTensorType::get(
      {}, TF::ResourceType::get(ArrayRef<TensorType>{size_type},
                                builder.getContext()));
}

// Returns the aliasing argument number of a function return value if it simply
// forwards the argument. Otherwise, returns -1.
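// Illustrative: for a function whose terminator is `return %arg1, %x`,
// FindAliasedInput(func, 0) returns 1 and FindAliasedInput(func, 1)
// returns -1.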
int64_t FindAliasedInput(func::FuncOp func, int64_t return_index) {
  Value return_val = func.front().getTerminator()->getOperand(return_index);
  auto maybe_arg = return_val.dyn_cast<BlockArgument>();
  if (!maybe_arg) return -1;
  return maybe_arg.getArgNumber();
}

// Changes the function signature that has stacks in the arguments. A stack
// argument will be turned into a variable type if arg_to_stack_type returns
// such a type, and a new argument will be added to the end of the argument
// list for the size variable.
//
// If stack_var_to_size_var is not nullptr, it will be used to store the
// mapping from the stack-variable argument to the size-variable argument.
//
// If handle_new_size_vars is provided, it will be invoked on the list of new
// size variables before finally changing the function type.
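//
// Illustrative (schematic types, assuming the size is kept as a
// tensor<1xi32> per cutil::GetSizeType): a signature
//   (resource<tensor<f32>>) -> (...)
// whose argument is a stack becomes
//   (resource<tensor<10xf32>>, resource<tensor<1xi32>>) -> (...)
// where the rewritten argument holds the buffer and the appended argument
// holds the size.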
void ModifyFunctionSignature(
    func::FuncOp func, llvm::SmallDenseMap<Value, Value>* stack_var_to_size_var,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_stack_type,
    llvm::function_ref<void(ArrayRef<BlockArgument>)> handle_new_size_vars =
        nullptr) {
  auto new_input_types = llvm::to_vector<8>(func.getFunctionType().getInputs());
  auto size_var_type = GetSizeVarType(OpBuilder(func));
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto stack_type = arg_to_stack_type(i);
    if (!stack_type.has_value()) continue;
    func.getArgument(i).setType(*stack_type);
    new_input_types[i] = *stack_type;
    auto size_arg = func.front().addArgument(size_var_type, func.getLoc());
    new_input_types.push_back(size_arg.getType());
    if (stack_var_to_size_var) {
      (*stack_var_to_size_var)[func.getArgument(i)] = size_arg;
    }
  }
  if (handle_new_size_vars) {
    handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
  }
  func.setType(
      FunctionType::get(func.getContext(), new_input_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallStackOpsInfo {
  bool signature_change;
  func::FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> stack_var_arg_to_size_arg;
};

LogicalResult DecomposeStackOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, Value>*,
    llvm::StringMap<PartitionedCallStackOpsInfo>*);

// Handles stack usage by a tf.While. It converts the body and cond function
// signatures and performs stack ops decomposition on them.
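//
// Illustrative (schematic): a while op that carries a stack resource
//   %w = "tf.While"(%stack) {body = @body, cond = @cond}
// is recreated as
//   %w:2 = "tf.While"(%stack, %size) {body = @body, cond = @cond}
// where %size is the stack's size variable; @body and @cond are rewritten to
// take the extra argument, and @body additionally forwards it in its return.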
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, Value> body_map;
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(while_op.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto add_size_vars_to_return = [&](ArrayRef<BlockArgument> new_args) {
    if (new_args.empty()) return;
    auto body_ret = body.front().getTerminator();
    auto new_body_returns = llvm::to_vector<8>(body_ret->getOperands());
    for (auto arg : new_args) new_body_returns.push_back(arg);
    OpBuilder(body_ret).create<func::ReturnOp>(body_ret->getLoc(),
                                               new_body_returns);
    body_ret->erase();
  };
  // Handle body.
  ModifyFunctionSignature(body, &body_map, find_arg_stack_type,
                          add_size_vars_to_return);
  const bool signature_change = !body_map.empty();
  if (failed(DecomposeStackOpsInternal(&body.front(), module, &body_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  // Cond should not change stacks in the arguments, so use an empty map.
  auto cond = while_op.cond_function();
  ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
  llvm::SmallDenseMap<Value, Value> empty_map;
  if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  OpBuilder builder(while_op);
  assert(while_op.getNumOperands() == while_op.getNumResults());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = data_var_to_size_var.find(while_op.getOperand(i));
    if (it == data_var_to_size_var.end()) continue;
    new_while_operands.push_back(it->getSecond());
  }
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getFunctionType().getInputs(), new_while_operands,
      while_op->getAttrs());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
             .isa<TF::ResourceType>()) {
      continue;
    }
    int64_t aliased_input = FindAliasedInput(body, i);
    if (aliased_input == i) {
      // Replace aliased stack output uses with input.
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    }
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

// Handles stack usage by a tf.If. It converts the branch function signatures
// and performs stack ops decomposition on them.
LogicalResult HandleIfOp(
    TF::IfOp if_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto then_func = if_op.then_function();
  auto else_func = if_op.else_function();
  llvm::SmallDenseMap<Value, Value> then_map;
  llvm::SmallDenseMap<Value, Value> else_map;

  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(if_op.getOperand(index + 1));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(then_func, &then_map, find_arg_stack_type);
  ModifyFunctionSignature(else_func, &else_map, find_arg_stack_type);
  const bool signature_change = !then_map.empty() || !else_map.empty();
  if (failed(DecomposeStackOpsInternal(&then_func.front(), module, &then_map,
                                       decomposed_partitioned_call_callees)) ||
      failed(DecomposeStackOpsInternal(&else_func.front(), module, &else_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
  for (auto operand : if_op.getOperands()) {
    auto it = data_var_to_size_var.find(operand);
    if (it == data_var_to_size_var.end()) continue;
    new_if_operands.push_back(it->getSecond());
  }
  auto new_if = OpBuilder(if_op).create<TF::IfOp>(
      if_op.getLoc(), then_func.getFunctionType().getResults(), new_if_operands,
      if_op->getAttrs());
  for (auto result : if_op.getResults()) {
    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    int64_t then_aliased_input =
        FindAliasedInput(then_func, result.getResultNumber());
    int64_t else_aliased_input =
        FindAliasedInput(else_func, result.getResultNumber());
    if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) {
      // Replace aliased stack output uses with input.
      result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1));
    }
  }
  if_op.replaceAllUsesWith(new_if);
  if_op.erase();
  return success();
}

// Handles stack usage by a tf.StatefulPartitionedCall or a tf.PartitionedCall.
// It first checks if the callee was previously handled, and tries to reuse
// that result if so. Otherwise, it clones and converts the callee function,
// and performs stack ops decomposition on it.
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, func::FuncOp callee, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallStackOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreate the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.stack_var_arg_to_size_arg.find(i);
      if (arg_it == info.stack_var_arg_to_size_arg.end()) continue;
      auto it = data_var_to_size_var.find(call.getOperand(i));
      if (it == data_var_to_size_var.end()) {
        call.emitOpError("unknown stack");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond());
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getFunctionType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", SymbolRefAttr::get(
                 builder.getContext(),
                 const_cast<func::FuncOp&>(info.decomposed_callee).getName()));
    for (int64_t i = 0; i < call.getNumResults(); ++i) {
      auto result = call.getResult(i);
      if (!getElementTypeOrSelf(result.getType())
               .template isa<TF::ResourceType>()) {
        continue;
      }
      int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
      if (aliased_input >= 0) {
        // Replace aliased stack output uses with input.
        result.replaceAllUsesWith(call.getOperand(aliased_input));
      }
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  llvm::SmallDenseMap<Value, Value> callee_map;
  func::FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(call.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(lowered_callee, &callee_map, find_arg_stack_type);
  info.signature_change = !callee_map.empty();
  if (!info.signature_change) {
    // Signature is not modified. We do not need the clone.
    if (lowered_callee != callee) {
      lowered_callee.erase();
    }
  } else {
    info.decomposed_callee = lowered_callee;
    for (auto& entry : callee_map) {
      info.stack_var_arg_to_size_arg
          [entry.getFirst().cast<BlockArgument>().getArgNumber()] =
          entry.getSecond().cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(StringAttr::get(
          callee->getContext(),
          llvm::formatv("{0}_stack_decomposed", callee.getName()).str()));
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

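// Converts a tf.StackV2 op into a zero-initialized buffer local variable and
// a zero-initialized size local variable, then replaces all uses of the stack
// handle with the buffer variable. The element shape is inferred from a
// tf.StackPushV2 op that uses the handle; decomposition fails if no such push
// exists or if the initial buffer cannot be created from max_size.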
LogicalResult HandleStackV2Op(
    TF::StackV2Op stack, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  // Create a buffer variable and a size variable to replace the stack.
  auto elem_type = cutil::GetElementTypeFromAccess(
      stack.handle(), module, [](Operation* user) -> llvm::Optional<Type> {
        auto push = llvm::dyn_cast<TF::StackPushV2Op>(user);
        if (!push) return llvm::None;
        return push.elem().getType();
      });
  if (!elem_type.has_value()) {
    return stack.emitOpError("cannot infer element shape of stack");
  }
  OpBuilder builder(stack);
  Value buffer;
  if (failed(cutil::CreateInitBufferValue(
          elem_type->getShape(), stack.max_size(), stack,
          elem_type->getElementType(), builder, &buffer))) {
    return failure();
  }
  auto size_var_type = GetSizeVarType(builder);
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              stack.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  auto local_size_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{});
  // Zero-initialize the local vars.
  cutil::WriteLocalVariable(local_size_var,
                            cutil::GetR1Const({0LL}, builder, stack.getLoc()),
                            builder, stack.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, stack.getLoc());
  stack.handle().replaceAllUsesWith(local_var);
  (*data_var_to_size_var)[local_var] = local_size_var;
  stack.erase();
  return success();
}

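// Converts a tf.StackPushV2 op into a read of the current buffer and size, an
// element write at the size index, a write of the updated buffer, and a size
// increment. The push op's result is replaced by the pushed element.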
LogicalResult HandleStackPushV2Op(
    TF::StackPushV2Op push,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(push.handle());
  if (it == data_var_to_size_var->end()) {
    return push.emitOpError("unknown stack");
  }
  // The push op's output simply forwards the input element.
  push.replaceAllUsesWith(push.elem());
  OpBuilder builder(push);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(push.handle(), builder, push.getLoc());
  auto index =
      cutil::ReadLocalVariable(it->getSecond(), builder, push.getLoc());
  stack_val =
      cutil::SetElement(index, stack_val, push.elem(), builder, push.getLoc());
  // Assign the new buffer and size.
  cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
  index = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{index.getType()},
      ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())});
  cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
  push.erase();
  return success();
}

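// Converts a tf.StackPopV2 op into a size decrement followed by a read of the
// element at the decremented index; the pop's result is replaced by that
// element. The buffer contents are left unchanged.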
LogicalResult HandleStackPopV2Op(
    TF::StackPopV2Op pop,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(pop.handle());
  if (it == data_var_to_size_var->end()) {
    return pop.emitOpError("unknown stack");
  }
  OpBuilder builder(pop);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(pop.handle(), builder, pop.getLoc());
  auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())});
  auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc());
  pop.replaceAllUsesWith(pop_val);
  // Update the size.
  cutil::WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc());
  pop.erase();
  return success();
}

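// Handles region-based control flow ops (tf.WhileRegion, tf.IfRegion,
// tf.CaseRegion): verifies that no resource-typed operands or results remain,
// then recursively decomposes stack ops inside each region.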
LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (Region& region : op.getRegions()) {
    if (failed(DecomposeStackOpsInternal(&region.front(), module,
                                         data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

// Decomposes stack ops in a block, and recursively decomposes any called
// functions.
// data_var_to_size_var: a mapping from stacks' buffer local variables to size
// local variables.
// decomposed_partitioned_call_callees: cache for partitioned call ops' callee
// function handling.
LogicalResult DecomposeStackOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      // Removes identity nodes in the block. The device computation does not
      // need such nodes to carry information.
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto stack = llvm::dyn_cast<TF::StackV2Op>(&op)) {
      if (failed(HandleStackV2Op(stack, module, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(&op)) {
      if (failed(HandleStackPushV2Op(push, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::StackPopV2Op>(&op)) {
      if (failed(HandleStackPopV2Op(pop, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::StackCloseV2Op>(&op)) {
      data_var_to_size_var->erase(close.handle());
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, *data_var_to_size_var,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::WhileRegionOp, TF::IfRegionOp, TF::CaseRegionOp>(
                   op)) {
      if (failed(
              HandleRegionControlFlowOps(op, module, data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func()) {
        return pcall.emitOpError(
            "stack decomposition does not support call with nested references");
      }
      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

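// Entry point for the decomposition: starts from the given block with fresh
// (empty) caches for the stack-to-size mapping and the handled callees.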
LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, Value> data_var_to_size_var;
  llvm::StringMap<PartitionedCallStackOpsInfo>
      decomposed_partitioned_call_callees;
  return DecomposeStackOpsInternal(block, module, &data_var_to_size_var,
                                   &decomposed_partitioned_call_callees);
}

void StackOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<func::FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeStackOps(&main.front(), module))) {
    signalPassFailure();
  }
}

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass() {
  return std::make_unique<StackOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir