/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {
namespace {

namespace cutil = TF::collection_ops_util;

using std::string;

// A pass that converts tensor array operations to tensor operations and
// read/assign ops on local variables. A later resource lifting pass can further
// remove the local variables.
//
// This pass requires that the full shape of the tensor array can be inferred:
// 1) the size needs to be a constant, 2) the full element shape is either
// specified or can be inferred from a later write, and 3) all elements have
// the same shape.
//
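// Schematic example of the rewrite (illustrative pseudo-IR, not exact):
//
//   %ta:2 = "tf.TensorArrayV3"(%size)                   // handle, flow
//   %flow = "tf.TensorArrayWriteV3"(%ta#0, %i, %v, %ta#1)
//   %r    = "tf.TensorArrayReadV3"(%ta#0, %i, %flow)
//
// roughly becomes
//
//   %var  = "tf.MlirLocalVarOp"()     // holds a buffer of shape [size, elem...]
//   ... the initialized buffer is written into %var ...
//   ... each write becomes a read-modify-write of the buffer, each read a
//       slice of it, and flow values are replaced by constants ...
//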
struct TensorArrayOpsDecompositionPass
    : public TF::TensorArrayOpsDecompositionPassBase<
          TensorArrayOpsDecompositionPass> {
  void runOnOperation() override;
};

// Infers the element type and count for a TensorArraySplitV3Op. Requires
// constant lengths and static shape on the input value.
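// For example (illustrative): a `value` of type tensor<6x4xf32> split with
// constant lengths [2, 2, 2] yields *count = 3 and *elem_type =
// tensor<2x4xf32>.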
LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
                                          RankedTensorType* elem_type,
                                          int64_t* count) {
  auto lengths_const =
      llvm::dyn_cast_or_null<TF::ConstOp>(split.lengths().getDefiningOp());
  if (!lengths_const) return split.emitOpError("non-constant split lengths");
  *count = lengths_const.value().getNumElements();
  if (*count <= 0) return split.emitOpError("non-positive split count");
  auto buffer_type = split.value().getType().dyn_cast<RankedTensorType>();
  if (!buffer_type || !buffer_type.hasStaticShape() ||
      buffer_type.getRank() < 1) {
    return split.emitOpError("unknown or invalid split tensor shape");
  }
  int64_t length = buffer_type.getDimSize(0) / *count;
  for (const auto& len : lengths_const.value().getValues<APInt>()) {
    if (length == len.getSExtValue()) continue;
    return split.emitOpError("different split lengths are not supported");
  }
  llvm::SmallVector<int64_t, 8> elem_shape;
  elem_shape.push_back(length);
  for (int64_t dim : buffer_type.getShape().drop_front()) {
    elem_shape.push_back(dim);
  }
  *elem_type = RankedTensorType::get(elem_shape, buffer_type.getElementType());
  return success();
}

// Tries to infer the tensor array element shape.
llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
    TF::TensorArrayV3Op ta, ModuleOp module) {
  auto element_shape = ta.element_shapeAttr().cast<mlir::TF::ShapeAttr>();
  if (element_shape.hasStaticShape()) {
    auto shape = element_shape.getShape();
    // Convert int64 to int64_t.
    llvm::SmallVector<int64_t, 8> dims(shape.begin(), shape.end());
    return dims;
  }

  bool has_failure = false;
  auto elem_type = cutil::GetElementTypeFromAccess(
      ta.handle(), module, [&](Operation* user) -> llvm::Optional<Type> {
        if (has_failure) return llvm::None;
        if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(user)) {
          return write.value().getType();
        } else if (auto split =
                       llvm::dyn_cast<TF::TensorArraySplitV3Op>(user)) {
          if (!split.lengths().getDefiningOp() ||
              !llvm::isa<TF::ConstOp>(split.lengths().getDefiningOp())) {
            return llvm::None;
          }
          RankedTensorType t;
          int64_t count;
          if (failed(GetSplitElementTypeAndCount(split, &t, &count))) {
            has_failure = true;
            return llvm::None;
          }
          return t;
        } else if (auto scatter =
                       llvm::dyn_cast<TF::TensorArrayScatterV3Op>(user)) {
          // TensorArrayScatter writes a vector of tensors to the TensorArray.
          // We can deduce the element shape by dropping the 0th dim of the
          // TensorArrayScatter `value`.
          auto t = scatter.value().getType().dyn_cast<RankedTensorType>();
          if (!t || t.getShape().empty()) return llvm::None;
          return RankedTensorType::get(t.getShape().drop_front(),
                                       t.getElementType());
        } else if (auto gather =
                       llvm::dyn_cast<TF::TensorArrayGatherV3Op>(user)) {
          // Try to infer from the result type of the gather.
          auto t = gather.value().getType().dyn_cast<RankedTensorType>();
          if (t && !t.getShape().empty())
            return RankedTensorType::get(t.getShape().drop_front(),
                                         t.getElementType());
          // Try to infer from the `element_shape` attribute of the gather.
          auto element_shape = gather.element_shapeAttr()
                                   .dyn_cast_or_null<mlir::TF::ShapeAttr>();
          if (element_shape && element_shape.hasStaticShape()) {
            return RankedTensorType::get(element_shape.getShape(),
                                         gather.dtype());
          }
        }
        return llvm::None;
      });
  if (!elem_type) return llvm::None;
  return llvm::to_vector<8>(elem_type->getShape());
}

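// Replaces any remaining uses of `old_val` (typically the enclosing
// function's terminator operand) with a tensor.cast of `new_val` back to
// `old_val`'s original type.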
void ReplaceAllUsesWithCast(Value old_val, Value new_val) {
  if (old_val.use_empty()) return;
  auto cast_op =
      OpBuilder(old_val.getDefiningOp())
          .create<tensor::CastOp>(old_val.getLoc(), old_val.getType(), new_val);
  old_val.replaceAllUsesWith(cast_op);
}

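// Replaces all uses of `old_val` with `new_val`. If the types differ, the use
// in the enclosing function's terminator is left untouched so that it can be
// patched up separately with a cast.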
void ReplaceAllUsesExceptTerminator(Value old_val, Value new_val) {
  if (old_val.getType() == new_val.getType()) {
    old_val.replaceAllUsesWith(new_val);
    return;
  }
  Operation* old_op = old_val.getDefiningOp();
  Operation* terminator_op =
      old_op->getParentOfType<func::FuncOp>().front().getTerminator();
  llvm::SmallPtrSet<mlir::Operation*, 1> exceptions = {terminator_op};
  old_val.replaceAllUsesExcept(new_val, exceptions);
}

struct TensorArrayStats {
  // Whether a write op should accumulate with the old value. Set to true if
  // this is a gradient.
  bool accumulate_on_write;
  // Maps from a gradient source string to the local variable holding that
  // gradient.
  llvm::StringMap<Value> grads;
};

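// Decomposes a TensorArrayV3 op: creates a local variable holding an
// initialized buffer of shape [size] + element_shape, replaces the handle
// with the variable, and replaces the flow output with a scalar constant.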
LogicalResult HandleTensorArrayV3Op(
    TF::TensorArrayV3Op ta, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  auto elem_shape = GetTensorArrayElementShape(ta, module);
  if (!elem_shape) return ta.emitOpError("unknown element shape");
  if (ta.dynamic_size()) {
    return ta.emitOpError("dynamic tensor array size is unsupported");
  }
  Value buffer;
  OpBuilder builder(ta);
  if (failed(cutil::CreateInitBufferValue(*elem_shape, ta.size(), ta,
                                          ta.dtype(), builder, &buffer))) {
    return failure();
  }
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              ta.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc());
  ta.handle().replaceAllUsesWith(local_var);
  // The flow output is just a way for the front end to enforce ordering among
  // tensor array ops, but in the MLIR TF dialect they have sequential ordering.
  // Just create a constant to replace its uses.
  tensorflow::Tensor scalar_tensor(tensorflow::DT_FLOAT, {});
  scalar_tensor.scalar<float>()() = 0.0f;
  auto flow = builder.create<TF::ConstOp>(
      ta.getLoc(),
      tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
  ta.flow().replaceAllUsesWith(flow);
  ta.erase();
  (*stats)[local_var].accumulate_on_write = false;
  return success();
}

LogicalResult HandleTensorArrayReadV3Op(
    TF::TensorArrayReadV3Op read,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = read.handle();
  if (stats.count(local_var) == 0) {
    return read.emitOpError("unknown tensor array");
  }
  OpBuilder builder(read);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, read.getLoc());
  auto index_reshape =
      cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc());
  auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc());
  ReplaceAllUsesExceptTerminator(read.value(), elem);
  ReplaceAllUsesWithCast(read.value(), elem);
  read.erase();
  // The clear_after_read attribute does not mean setting the tensor to 0 after
  // read; instead it disallows a second read before the next write. We follow
  // the old bridge's implementation and do nothing here.
  return success();
}

LogicalResult HandleTensorArrayWriteV3Op(
    TF::TensorArrayWriteV3Op write,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = write.handle();
  auto stat_it = stats.find(local_var);
  if (stat_it == stats.end()) return write.emitOpError("unknown tensor array");
  OpBuilder builder(write);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, write.getLoc());
  auto index_reshape =
      cutil::ReshapeScalarToSizeType(builder, write.index(), write.getLoc());
  auto elem = write.value();
  if (stat_it->getSecond().accumulate_on_write) {
    // Get the old slice, and accumulate with it. We set keep_slice_shape
    // (keeping the leading size-1 dimension) because it avoids reshaping back
    // and forth.
    auto original_elem =
        cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
                          /*keep_slice_shape=*/true);
    // Add a size-1 leading dimension to elem.
    auto slice_type = original_elem.getType().cast<RankedTensorType>();
    elem = builder.create<TF::ReshapeOp>(
        write.getLoc(), ArrayRef<Type>{slice_type},
        ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
                                                write.getLoc())});
    elem =
        cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc());
  }
  buffer =
      cutil::SetElement(index_reshape, buffer, elem, builder, write.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, write.getLoc());
  write.flow_out().replaceAllUsesWith(write.flow_in());
  write.erase();
  return success();
}

LogicalResult HandleTensorArrayConcatV3Op(
    TF::TensorArrayConcatV3Op concat,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = concat.handle();
  if (stats.count(local_var) == 0) {
    return concat.emitOpError("unknown tensor array");
  }
  OpBuilder builder(concat);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, concat.getLoc());
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  if (buffer_type.getShape().size() <= 1) {
    return concat.emitOpError("cannot concat on scalar-element tensor array");
  }
  // Merge the first two dimensions.
  auto shape = llvm::to_vector<8>(buffer_type.getShape().drop_front());
  shape[0] *= buffer_type.getDimSize(0);
  buffer = builder.create<TF::ReshapeOp>(
      concat.getLoc(),
      ArrayRef<Type>{
          RankedTensorType::get(shape, buffer_type.getElementType())},
      ArrayRef<Value>{buffer,
                      cutil::GetR1Const(shape, builder, concat.getLoc())});
  ReplaceAllUsesExceptTerminator(concat.value(), buffer);
  ReplaceAllUsesWithCast(concat.value(), buffer);

  // Create the lengths as a list where every entry is the same value (the
  // element size).
  tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64,
                                    {buffer_type.getDimSize(0)});
  for (int64_t i = 0; i < buffer_type.getDimSize(0); ++i) {
    lengths_tensor.vec<int64_t>()(i) = buffer_type.getDimSize(1);
  }
  concat.lengths().replaceAllUsesWith(builder.create<TF::ConstOp>(
      concat.getLoc(),
      tensorflow::ConvertTensor(lengths_tensor, &builder).ValueOrDie()));
  concat.erase();
  return success();
}

LogicalResult HandleTensorArraySplitV3Op(
    TF::TensorArraySplitV3Op split,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = split.handle();
  if (stats.count(local_var) == 0) {
    return split.emitOpError("unknown tensor array");
  }
  OpBuilder builder(split);
  int64_t count;
  RankedTensorType elem_type;
  if (failed(GetSplitElementTypeAndCount(split, &elem_type, &count))) {
    return failure();
  }
  llvm::SmallVector<int64_t, 8> buffer_shape;
  buffer_shape.push_back(count);
  for (int64_t dim : elem_type.getShape()) buffer_shape.push_back(dim);
  // Reshape the input to match the buffer of the tensor array.
  auto buffer = builder
                    .create<TF::ReshapeOp>(
                        split.getLoc(),
                        ArrayRef<Type>{RankedTensorType::get(
                            buffer_shape, elem_type.getElementType())},
                        ArrayRef<Value>{split.value(),
                                        cutil::GetR1Const(buffer_shape, builder,
                                                          split.getLoc())})
                    .output();
  // Accumulate with the old buffer.
  auto old_buffer =
      cutil::ReadLocalVariable(local_var, builder, split.getLoc());
  buffer =
      cutil::AccumulateBuffers(old_buffer, buffer, builder, split.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, split.getLoc());
  split.flow_out().replaceAllUsesWith(split.flow_in());
  split.erase();
  return success();
}

LogicalResult HandleTensorArraySizeV3Op(
    TF::TensorArraySizeV3Op size,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = size.handle();
  if (stats.count(local_var) == 0) {
    return size.emitOpError("unknown tensor array");
  }
  auto buffer_type = getElementTypeOrSelf(local_var.getType())
                         .cast<TF::ResourceType>()
                         .getSubtypes()[0]
                         .cast<RankedTensorType>();
  OpBuilder builder(size);
  auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder,
                                         size.getLoc());
  size.size().replaceAllUsesWith(result);
  size.erase();
  return success();
}

LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
                                              Operation* op, Value* var) {
  OpBuilder builder(op);
  *var = builder.create<TF::MlirLocalVarOp>(
      op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{});
  Value buffer;
  auto buffer_type = getElementTypeOrSelf(local_var_type)
                         .cast<TF::ResourceType>()
                         .getSubtypes()[0]
                         .cast<RankedTensorType>();
  if (failed(cutil::CreateInitBufferValue(
          buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
          buffer_type.getElementType(), builder, &buffer))) {
    return failure();
  }
  cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
  return success();
}

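// Decomposes a TensorArrayGradV3 op: if a gradient variable was already
// created for the same source it is reused; otherwise a new local variable is
// created and initialized, and writes to it are marked to accumulate.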
LogicalResult HandleTensorArrayGradV3Op(
    TF::TensorArrayGradV3Op grad,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  auto local_var = grad.handle();
  OpBuilder builder(grad);
  Value grad_var;
  auto sit = stats->find(local_var);
  if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
  auto emplace_res =
      sit->getSecond().grads.try_emplace(grad.source().str(), Value());
  if (!emplace_res.second) {
    // If the source has been assigned a grad, use it.
    grad_var = emplace_res.first->second;
  } else {
    if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
                                               &grad_var))) {
      return failure();
    }
    emplace_res.first->second = grad_var;
    // A write to a grad accumulates with previous writes.
    (*stats)[grad_var].accumulate_on_write = true;
  }
  grad.flow_out().replaceAllUsesWith(grad.flow_in());
  grad.grad_handle().replaceAllUsesWith(grad_var);
  grad.erase();
  return success();
}

LogicalResult HandleTensorArrayGatherV3Op(
    TF::TensorArrayGatherV3Op gather,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = gather.handle();
  if (stats.count(local_var) == 0) {
    return gather.emitOpError("unknown tensor array");
  }
  OpBuilder builder(gather);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc());
  auto result =
      cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc());
  ReplaceAllUsesExceptTerminator(gather.value(), result);
  ReplaceAllUsesWithCast(gather.value(), result);
  gather.erase();
  return success();
}

LogicalResult HandleTensorArrayScatterV3Op(
    TF::TensorArrayScatterV3Op scatter,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = scatter.handle();
  if (stats.count(local_var) == 0) {
    return scatter.emitOpError("unknown tensor array");
  }
  OpBuilder builder(scatter);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, scatter.getLoc());
  buffer = cutil::ScatterAccumulateElements(scatter.indices(), scatter.value(),
                                            buffer, builder, scatter.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, scatter.getLoc());
  scatter.flow_out().replaceAllUsesWith(scatter.flow_in());
  scatter.erase();
  return success();
}

// Updates func's type according to its current arguments and return values.
void UpdateFuncType(func::FuncOp func) {
  llvm::SmallVector<Type, 8> arg_types;
  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
  func.setType(
      FunctionType::get(func.getContext(), arg_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Finds the accessed gradient sources for each tensor array argument.
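// For example (illustrative), if argument #2 of a function is passed to
// TensorArrayGradV3 ops with sources "a" and "b", the returned map contains
// the entry 2 -> ["a", "b"].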
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
    ArrayRef<func::FuncOp> funcs, ModuleOp module) {
  llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
  llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
  auto insert = [&](Value v, const string& source, const Block& func_block) {
    auto arg = v.dyn_cast<BlockArgument>();
    if (!arg || arg.getOwner() != &func_block) return;
    auto insert_res = result_sets[arg.getArgNumber()].insert(source);
    if (!insert_res.second) return;
    result[arg.getArgNumber()].push_back(source);
  };
  for (func::FuncOp func : funcs) {
    const Block& func_block = func.front();
    // Walk all operations and nested regions to find accessed gradient sources
    // for function arguments.
    func.walk([&](Operation* op) {
      if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
        op->replaceAllUsesWith(op->getOperands());
        return;
      }
      if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(op)) {
        insert(grad.handle(), grad.source().str(), func_block);
      } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
        for (const auto& entry : AccessedGradients(
                 {while_op.body_function(), while_op.cond_function()}, module))
          for (const string& source : entry.getSecond())
            insert(while_op.getOperand(entry.getFirst()), source, func_block);
      } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
        for (const auto& entry : AccessedGradients(
                 {if_op.then_function(), if_op.else_function()}, module))
          for (const string& source : entry.getSecond())
            insert(if_op.getOperand(entry.getFirst() + 1), source, func_block);
      } else if (auto call = llvm::dyn_cast<CallOpInterface>(op)) {
        auto callee = dyn_cast<func::FuncOp>(call.resolveCallable());
        for (const auto& entry : AccessedGradients({callee}, module))
          for (const string& source : entry.getSecond())
            insert(call.getArgOperands()[entry.getFirst()], source, func_block);
      }
    });
  }
  return result;
}


// Contains cached information about a decomposed callee function for
// (stateful) partitioned call ops.
struct PartitionedCallTensorArrayOpsInfo {
  // Whether decomposition changed the callee's signature (i.e. the caller
  // needs to be recreated).
  bool signature_change;
  // The decomposed callee function.
  func::FuncOp decomposed_callee;
  // For each tensor array argument index, the accessed gradient sources.
  llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
      arg_grads;
  // Pairs of (result index, argument index) where the callee forwards a
  // resource argument as a result.
  llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
};

// Updates a called function's input signature by adjusting resource types, and
// adding required gradient arguments.
void ChangeFunctionInputSignature(
    func::FuncOp func,
    const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
    llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
    llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  int64_t original_args = func.getNumArguments();
  for (int64_t argnum = 0; argnum < original_args; ++argnum) {
    auto arg = func.getArgument(argnum);
    Type t = ta_arg_buffer_type(argnum);
    if (!t) continue;
    arg.setType(t);
    auto grad_it = grads.find(argnum);
    if (grad_it == grads.end()) continue;
    llvm::StringMap<Value> grads_map;
    for (const string& source : grad_it->getSecond()) {
      auto g = func.front().addArgument(t, func.getLoc());
      (*stats)[g].accumulate_on_write = true;
      grads_map[source] = g;
    }
    auto& stat = (*stats)[arg];
    stat.accumulate_on_write = ta_accumulate_on_write(argnum);
    stat.grads = std::move(grads_map);
  }
  UpdateFuncType(func);
}

LogicalResult DecomposeTensorArrayOps(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*);

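// Decomposes tensor array ops inside a while loop: rewrites the cond and body
// signatures to use buffer-typed resources (plus any accessed gradient
// variables), decomposes their bodies, and recreates the while op with the
// extended operand list. Tensor array results must alias the corresponding
// inputs.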
LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
                            llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
                            llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
                                decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  auto cond = while_op.cond_function();
  auto grads = AccessedGradients({body, cond}, module);
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(while_op.getOperand(index));
    if (it == stats->end()) return nullptr;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(while_op.getOperand(index));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
  ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &body_stats);
  llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
  ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &cond_stats);
  if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
                                     decomposed_partitioned_call_callees)) ||
      failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (body_stats.empty() && cond_stats.empty()) return success();
  auto old_body_ret = body.front().getTerminator();
  auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!ta_arg_buffer_type(i)) continue;
    auto retval = old_body_ret->getOperand(i);
    auto arg = retval.dyn_cast<BlockArgument>();
    if (!arg) {
      return while_op.emitOpError(
          "output tensor array does not alias input in a while loop");
    }
    for (const string& source : grads[i]) {
      new_retvals.push_back(body_stats[arg].grads[source]);
    }
  }
  OpBuilder(old_body_ret)
      .create<func::ReturnOp>(old_body_ret->getLoc(), new_retvals);
  old_body_ret->erase();
  UpdateFuncType(body);
  // Recreate the while op.
  auto operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
    auto grad_it = grads.find(i);
    auto& stat = (*stats)[operands[i]];
    if (grad_it == grads.end()) continue;
    for (const string& source : grad_it->getSecond()) {
      auto it = stat.grads.find(source);
      if (it != stat.grads.end()) {
        operands.push_back(it->second);
      } else {
        Value grad_var;
        if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
                                                   while_op, &grad_var))) {
          return failure();
        }
        stat.grads[source] = grad_var;
        operands.push_back(grad_var);
        (*stats)[grad_var].accumulate_on_write = true;
      }
    }
  }
  OpBuilder builder(while_op);
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getFunctionType().getInputs(), operands,
      while_op->getAttrs());
  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
    if (ta_arg_buffer_type(i)) {
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    } else {
      while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
    }
  }
  while_op.erase();
  return success();
}

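// Decomposes tensor array ops inside an if op: rewrites both branch
// signatures (branch argument i corresponds to if operand i + 1), decomposes
// the branches, and recreates the if op. Resource results must be forwarded
// from the same input argument in both branches.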
LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
                         llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
                         llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
                             decomposed_partitioned_call_callees) {
  auto then_branch = if_op.then_function();
  auto else_branch = if_op.else_function();
  auto grads = AccessedGradients({then_branch, else_branch}, module);
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(if_op.getOperand(index + 1));
    if (it == stats->end()) return nullptr;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(if_op.getOperand(index + 1));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
  ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &then_stats);
  llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
  ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &else_stats);
  if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
                                     decomposed_partitioned_call_callees)) ||
      failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (then_stats.empty() && else_stats.empty()) return success();
  // Recreate the if op.
  auto operands = llvm::to_vector<8>(if_op.getOperands());
  for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
    auto grad_it = grads.find(i);
    auto& stat = (*stats)[operands[i + 1]];
    if (grad_it == grads.end()) continue;
    for (const string& source : grad_it->getSecond()) {
      auto it = stat.grads.find(source);
      if (it != stat.grads.end()) {
        operands.push_back(it->second);
      } else {
        Value grad_var;
        if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
                                                   if_op, &grad_var))) {
          return failure();
        }
        stat.grads[source] = grad_var;
        operands.push_back(grad_var);
        (*stats)[grad_var].accumulate_on_write = true;
      }
    }
  }
  OpBuilder builder(if_op);
  auto new_if = builder.create<TF::IfOp>(
      if_op.getLoc(), then_branch.getFunctionType().getResults(), operands,
      if_op->getAttrs());
  auto ret_forwards_input = [](func::FuncOp f, int64_t ret_ind) -> int64_t {
    auto retval = f.front().getTerminator()->getOperand(ret_ind);
    auto arg = retval.dyn_cast<BlockArgument>();
    if (!arg) return -1;
    return arg.getArgNumber();
  };
  for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(if_op.getResult(i).getType())
             .isa<TF::ResourceType>()) {
      if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
      continue;
    }
    int64_t then_forward_input = ret_forwards_input(then_branch, i);
    int64_t else_forward_input = ret_forwards_input(else_branch, i);
    if (then_forward_input != else_forward_input || then_forward_input < 0) {
      return if_op.emitOpError(
          "branches do not forward the same input resource");
    }
    if_op.getResult(i).replaceAllUsesWith(
        if_op.getOperand(then_forward_input + 1));
  }
  if_op.erase();
  return success();
}

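// Decomposes tensor array ops in the callee of a (stateful) partitioned call.
// Rewritten callees are cached in `decomposed_partitioned_call_callees`; when
// the callee's signature changes, the caller is recreated to pass the gradient
// variables and to forward resource results from the matching operands.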
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, func::FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallTensorArrayOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreates the call op with info.
  auto recreate_caller = [&]() -> LogicalResult {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (const auto& entry : info.arg_grads) {
      auto it = stats->find(call.getOperand(entry.first));
      if (it == stats->end()) return call.emitOpError("unknown tensor array");
      for (const string& source : entry.second) {
        auto grad_it = it->getSecond().grads.find(source);
        if (grad_it != it->getSecond().grads.end()) {
          new_operands.push_back(grad_it->second);
        } else {
          Value grad_var;
          if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
                                                     call, &grad_var))) {
            return failure();
          }
          it->getSecond().grads[source] = grad_var;
          new_operands.push_back(grad_var);
        }
      }
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getFunctionType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", SymbolRefAttr::get(
                 builder.getContext(),
                 const_cast<func::FuncOp&>(info.decomposed_callee).getName()));
    for (const auto& entry : info.ret_forward_input) {
      call.getResult(entry.first)
          .replaceAllUsesWith(call.getOperand(entry.second));
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  // Rewrite the callee.
  info.signature_change = false;
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(call.getOperand(index));
    if (it == stats->end()) return nullptr;
    info.signature_change = true;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(call.getOperand(index));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  func::FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto grads = AccessedGradients({lowered_callee}, module);
  for (int64_t i = 0; i < lowered_callee.getNumArguments(); ++i) {
    auto it = grads.find(i);
    if (it == grads.end()) continue;
    info.arg_grads.emplace_back(i, it->getSecond());
  }
  llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
  ChangeFunctionInputSignature(lowered_callee, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &callee_stats);
  if (failed(DecomposeTensorArrayOps(&lowered_callee.front(), module,
                                     &callee_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  for (int64_t i = 0; i < call.getNumResults(); ++i) {
    auto ret = lowered_callee.front().getTerminator()->getOperand(i);
    if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
    auto arg = ret.dyn_cast<BlockArgument>();
    if (!arg) continue;
    info.ret_forward_input.emplace_back(i, arg.getArgNumber());
  }

  info.decomposed_callee = lowered_callee;
  if (lowered_callee != callee) {
    if (!info.signature_change) {
      // Signature is not modified. We do not need to keep two copies.
      lowered_callee.setName(
          StringAttr::get(callee->getContext(), callee.getName()));
      callee.erase();
    } else {
      // Add the clone with a new name.
      lowered_callee.setName(StringAttr::get(
          callee->getContext(),
          llvm::formatv("{0}_tensorarray_decomposed", callee.getName()).str()));
    }
    SymbolTable(module).insert(lowered_callee);
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

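// Handles region-based control flow ops (CaseRegion, IfRegion, WhileRegion).
// Resource-typed operands and results are expected to have already been
// canonicalized away, so this only recurses into the regions.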
LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }

  for (Region& region : op.getRegions()) {
    if (failed(DecomposeTensorArrayOps(&region.front(), module, stats,
                                       decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

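// Decomposes all tensor array ops in `block`, recursing into functional and
// region-based control flow as well as (stateful) partitioned calls.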
LogicalResult DecomposeTensorArrayOps(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
      if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
        return failure();
      }
    } else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
      if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
    } else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
      if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
    } else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
      if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
    } else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
      if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
    } else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
      if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
    } else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
      if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
    } else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
      if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
    } else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
      if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, stats,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, stats,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::CaseRegionOp>(op) ||
               llvm::isa<TF::IfRegionOp>(op) ||
               llvm::isa<TF::WhileRegionOp>(op)) {
      if (failed(HandleRegionControlFlowOps(
              op, module, stats, decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      auto callee = pcall.func();
      if (!callee)
        return pcall.emitOpError(
            "TensorArray decomposition does not support call with nested "
            "references.");

      if (failed(
              HandlePartitionedCallOp(pcall, callee, module, stats,
                                      decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(
              HandlePartitionedCallOp(spcall, spcall.func(), module, stats,
                                      decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

void TensorArrayOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<func::FuncOp>("main");
  if (!main) return;
  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
  llvm::StringMap<PartitionedCallTensorArrayOpsInfo>
      decomposed_partitioned_call_callees;
  if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
                                     &decomposed_partitioned_call_callees))) {
    signalPassFailure();
  }
}

}  // namespace

namespace TF {

std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorArrayOpsDecompositionPass() {
  return std::make_unique<TensorArrayOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir