/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {
namespace {

namespace cutil = TF::collection_ops_util;

using std::string;

// A pass that converts tensor array operations to tensor operations and
// read/assign ops on local variables. A later resource lifting pass can
// further remove the local variables.
//
// This pass requires that the full shape of the tensor array can be inferred:
// 1) the size is a constant, 2) the full element shape is either specified or
// can be inferred from a later write, and 3) all elements have the same shape.
//
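// For example (illustrative, attribute and type details elided): a
// tf.TensorArrayV3 op of size 8 with element shape [3] and f32 elements is
// replaced by a local variable initialized with a tensor<8x3xf32> buffer;
// reads become slices of that buffer, writes update the buffer and store it
// back into the variable, and all flow values are replaced by constants.
//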
struct TensorArrayOpsDecompositionPass
    : public TF::TensorArrayOpsDecompositionPassBase<
          TensorArrayOpsDecompositionPass> {
  void runOnOperation() override;
};

// Infers the element type and count for a TensorArraySplitV3Op. Requires
// constant lengths and static shape on the input value.
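// For example (illustrative): splitting a value of shape [6, 3] with lengths
// [2, 2, 2] yields *count == 3 and an element type of tensor<2x3xf32>
// (assuming f32 elements).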
LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
                                          RankedTensorType* elem_type,
                                          int64_t* count) {
  auto lengths_const =
      llvm::dyn_cast_or_null<TF::ConstOp>(split.lengths().getDefiningOp());
  if (!lengths_const) return split.emitOpError("non-constant split lengths");
  *count = lengths_const.value().getNumElements();
  if (*count <= 0) return split.emitOpError("non-positive split count");
  auto buffer_type = split.value().getType().dyn_cast<RankedTensorType>();
  if (!buffer_type || !buffer_type.hasStaticShape() ||
      buffer_type.getRank() < 1) {
    return split.emitOpError("unknown or invalid split tensor shape");
  }
  int64_t length = buffer_type.getDimSize(0) / *count;
  for (const auto& len : lengths_const.value().getValues<APInt>()) {
    if (length == len.getSExtValue()) continue;
    return split.emitOpError("different split lengths are not supported");
  }
  llvm::SmallVector<int64_t, 8> elem_shape;
  elem_shape.push_back(length);
  for (int64_t dim : buffer_type.getShape().drop_front()) {
    elem_shape.push_back(dim);
  }
  *elem_type = RankedTensorType::get(elem_shape, buffer_type.getElementType());
  return success();
}

// Tries to infer the tensor array element shape.
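// For example (illustrative): a TensorArrayWriteV3 whose `value` has type
// tensor<3x7xf32> implies an element shape of [3, 7]; a TensorArrayScatterV3
// whose `value` has type tensor<5x3x7xf32> implies the same shape after
// dropping the leading dimension.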
llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
    TF::TensorArrayV3Op ta, ModuleOp module) {
  auto element_shape = ta.element_shapeAttr().cast<mlir::TF::ShapeAttr>();
  if (element_shape.hasStaticShape()) {
    auto shape = element_shape.getShape();
    // Convert int64 to int64_t.
    llvm::SmallVector<int64_t, 8> dims(shape.begin(), shape.end());
    return dims;
  }

  bool has_failure = false;
  auto elem_type = cutil::GetElementTypeFromAccess(
      ta.handle(), module, [&](Operation* user) -> llvm::Optional<Type> {
        if (has_failure) return llvm::None;
        if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(user)) {
          return write.value().getType();
        } else if (auto split =
                       llvm::dyn_cast<TF::TensorArraySplitV3Op>(user)) {
          if (!split.lengths().getDefiningOp() ||
              !llvm::isa<TF::ConstOp>(split.lengths().getDefiningOp())) {
            return llvm::None;
          }
          RankedTensorType t;
          int64_t count;
          if (failed(GetSplitElementTypeAndCount(split, &t, &count))) {
            has_failure = true;
            return llvm::None;
          }
          return t;
        } else if (auto scatter =
                       llvm::dyn_cast<TF::TensorArrayScatterV3Op>(user)) {
          // TensorArrayScatter writes a vector of tensors to the TensorArray.
          // We can deduce the element shape of the TensorArray by dropping the
          // 0th dimension of the TensorArrayScatter `value`.
          auto t = scatter.value().getType().dyn_cast<RankedTensorType>();
          if (!t || t.getShape().empty()) return llvm::None;
          return RankedTensorType::get(t.getShape().drop_front(),
                                       t.getElementType());
        } else if (auto gather =
                       llvm::dyn_cast<TF::TensorArrayGatherV3Op>(user)) {
          // Try to infer from result type of gather.
          auto t = gather.value().getType().dyn_cast<RankedTensorType>();
          if (t && !t.getShape().empty())
            return RankedTensorType::get(t.getShape().drop_front(),
                                         t.getElementType());
          // Try to infer from `element_shape` attribute of gather.
          auto element_shape = gather.element_shapeAttr()
                                   .dyn_cast_or_null<mlir::TF::ShapeAttr>();
          if (element_shape && element_shape.hasStaticShape()) {
            return RankedTensorType::get(element_shape.getShape(),
                                         gather.dtype());
          }
        }
        return llvm::None;
      });
  if (!elem_type) return llvm::None;
  return llvm::to_vector<8>(elem_type->getShape());
}

void ReplaceAllUsesWithCast(Value old_val, Value new_val) {
  if (old_val.use_empty()) return;
  auto cast_op =
      OpBuilder(old_val.getDefiningOp())
          .create<tensor::CastOp>(old_val.getLoc(), old_val.getType(), new_val);
  old_val.replaceAllUsesWith(cast_op);
}

void ReplaceAllUsesExceptTerminator(Value old_val, Value new_val) {
  if (old_val.getType() == new_val.getType()) {
    old_val.replaceAllUsesWith(new_val);
    return;
  }
  Operation* old_op = old_val.getDefiningOp();
  Operation* terminator_op =
      old_op->getParentOfType<func::FuncOp>().front().getTerminator();
  llvm::SmallPtrSet<mlir::Operation*, 1> exceptions = {terminator_op};
  old_val.replaceAllUsesExcept(new_val, exceptions);
}

struct TensorArrayStats {
  // Whether a write op should accumulate with the old value. Set to true if
  // this is a gradient.
  bool accumulate_on_write;
  // Maps from a gradient source string to the local variable for the gradient.
  llvm::StringMap<Value> grads;
};

LogicalResult HandleTensorArrayV3Op(
    TF::TensorArrayV3Op ta, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  auto elem_shape = GetTensorArrayElementShape(ta, module);
  if (!elem_shape) return ta.emitOpError("unknown element shape");
  if (ta.dynamic_size()) {
    return ta.emitOpError("dynamic tensor array size is unsupported");
  }
  Value buffer;
  OpBuilder builder(ta);
  if (failed(cutil::CreateInitBufferValue(*elem_shape, ta.size(), ta,
                                          ta.dtype(), builder, &buffer))) {
    return failure();
  }
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              ta.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc());
  ta.handle().replaceAllUsesWith(local_var);
  // The flow output is just a way for the front end to enforce ordering among
  // tensor array ops, but ops in the MLIR TF dialect already have sequential
  // ordering, so just create a constant to replace its uses.
  tensorflow::Tensor scalar_tensor(tensorflow::DT_FLOAT, {});
  scalar_tensor.scalar<float>()() = 0.0f;
  auto flow = builder.create<TF::ConstOp>(
      ta.getLoc(),
      tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
  ta.flow().replaceAllUsesWith(flow);
  ta.erase();
  (*stats)[local_var].accumulate_on_write = false;
  return success();
}

LogicalResult HandleTensorArrayReadV3Op(
    TF::TensorArrayReadV3Op read,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = read.handle();
  if (stats.count(local_var) == 0) {
    return read.emitOpError("unknown tensor array");
  }
  OpBuilder builder(read);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, read.getLoc());
  auto index_reshape =
      cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc());
  auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc());
  ReplaceAllUsesExceptTerminator(read.value(), elem);
  ReplaceAllUsesWithCast(read.value(), elem);
  read.erase();
  // The clear_after_read attribute does not mean setting the tensor to 0 after
  // a read; it only disallows a second read before the next write. Following
  // the old bridge's implementation, we do nothing here.
  return success();
}

LogicalResult HandleTensorArrayWriteV3Op(
    TF::TensorArrayWriteV3Op write,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = write.handle();
  auto stat_it = stats.find(local_var);
  if (stat_it == stats.end()) return write.emitOpError("unknown tensor array");
  OpBuilder builder(write);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, write.getLoc());
  auto index_reshape =
      cutil::ReshapeScalarToSizeType(builder, write.index(), write.getLoc());
  auto elem = write.value();
  if (stat_it->getSecond().accumulate_on_write) {
    // Get the old slice and accumulate with it. We set keep_slice_shape
    // (keeping the leading size-1 dimension) because it avoids reshaping back
    // and forth.
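    // For example (illustrative): with an element shape of [3], the stored
    // slice has shape [1, 3]; the incoming value of shape [3] is reshaped to
    // [1, 3] before the two are accumulated.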
    auto original_elem =
        cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
                          /*keep_slice_shape=*/true);
    // Add a size-1 leading dimension to elem.
    auto slice_type = original_elem.getType().cast<RankedTensorType>();
    elem = builder.create<TF::ReshapeOp>(
        write.getLoc(), ArrayRef<Type>{slice_type},
        ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
                                                write.getLoc())});
    elem =
        cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc());
  }
  buffer =
      cutil::SetElement(index_reshape, buffer, elem, builder, write.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, write.getLoc());
  write.flow_out().replaceAllUsesWith(write.flow_in());
  write.erase();
  return success();
}

LogicalResult HandleTensorArrayConcatV3Op(
    TF::TensorArrayConcatV3Op concat,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = concat.handle();
  if (stats.count(local_var) == 0) {
    return concat.emitOpError("unknown tensor array");
  }
  OpBuilder builder(concat);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, concat.getLoc());
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  if (buffer_type.getShape().size() <= 1) {
    return concat.emitOpError("cannot concat on scalar-element tensor array");
  }
  // Merge the first two dimensions.
  auto shape = llvm::to_vector<8>(buffer_type.getShape().drop_front());
  shape[0] *= buffer_type.getDimSize(0);
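  // For example (illustrative): a buffer of shape [4, 2, 3] is reshaped to
  // [8, 3] below, and the lengths output becomes a constant [2, 2, 2, 2].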
  buffer = builder.create<TF::ReshapeOp>(
      concat.getLoc(),
      ArrayRef<Type>{
          RankedTensorType::get(shape, buffer_type.getElementType())},
      ArrayRef<Value>{buffer,
                      cutil::GetR1Const(shape, builder, concat.getLoc())});
  ReplaceAllUsesExceptTerminator(concat.value(), buffer);
  ReplaceAllUsesWithCast(concat.value(), buffer);

  // Create the lengths as a list of the same value (the element's leading
  // dimension size).
  tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64,
                                    {buffer_type.getDimSize(0)});
  for (int64_t i = 0; i < buffer_type.getDimSize(0); ++i) {
    lengths_tensor.vec<int64_t>()(i) = buffer_type.getDimSize(1);
  }
  concat.lengths().replaceAllUsesWith(builder.create<TF::ConstOp>(
      concat.getLoc(),
      tensorflow::ConvertTensor(lengths_tensor, &builder).ValueOrDie()));
  concat.erase();
  return success();
}

LogicalResult HandleTensorArraySplitV3Op(
    TF::TensorArraySplitV3Op split,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = split.handle();
  if (stats.count(local_var) == 0) {
    return split.emitOpError("unknown tensor array");
  }
  OpBuilder builder(split);
  int64_t count;
  RankedTensorType elem_type;
  if (failed(GetSplitElementTypeAndCount(split, &elem_type, &count))) {
    return failure();
  }
  llvm::SmallVector<int64_t, 8> buffer_shape;
  buffer_shape.push_back(count);
  for (int64_t dim : elem_type.getShape()) buffer_shape.push_back(dim);
  // Reshape the input to match the buffer of the tensor array.
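  // For example (illustrative): a split value of shape [6, 3] with count == 3
  // is reshaped to a buffer of shape [3, 2, 3].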
  auto buffer = builder
                    .create<TF::ReshapeOp>(
                        split.getLoc(),
                        ArrayRef<Type>{RankedTensorType::get(
                            buffer_shape, elem_type.getElementType())},
                        ArrayRef<Value>{split.value(),
                                        cutil::GetR1Const(buffer_shape, builder,
                                                          split.getLoc())})
                    .output();
  // Accumulate with the old buffer.
  auto old_buffer =
      cutil::ReadLocalVariable(local_var, builder, split.getLoc());
  buffer =
      cutil::AccumulateBuffers(old_buffer, buffer, builder, split.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, split.getLoc());
  split.flow_out().replaceAllUsesWith(split.flow_in());
  split.erase();
  return success();
}

LogicalResult HandleTensorArraySizeV3Op(
    TF::TensorArraySizeV3Op size,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = size.handle();
  if (stats.count(local_var) == 0) {
    return size.emitOpError("unknown tensor array");
  }
  auto buffer_type = getElementTypeOrSelf(local_var.getType())
                         .cast<TF::ResourceType>()
                         .getSubtypes()[0]
                         .cast<RankedTensorType>();
  OpBuilder builder(size);
  auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder,
                                         size.getLoc());
  size.size().replaceAllUsesWith(result);
  size.erase();
  return success();
}

LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
                                              Operation* op, Value* var) {
  OpBuilder builder(op);
  *var = builder.create<TF::MlirLocalVarOp>(
      op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{});
  Value buffer;
  auto buffer_type = getElementTypeOrSelf(local_var_type)
                         .cast<TF::ResourceType>()
                         .getSubtypes()[0]
                         .cast<RankedTensorType>();
  if (failed(cutil::CreateInitBufferValue(
          buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
          buffer_type.getElementType(), builder, &buffer))) {
    return failure();
  }
  cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
  return success();
}

LogicalResult HandleTensorArrayGradV3Op(
    TF::TensorArrayGradV3Op grad,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  auto local_var = grad.handle();
  OpBuilder builder(grad);
  Value grad_var;
  auto sit = stats->find(local_var);
  if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
  auto emplace_res =
      sit->getSecond().grads.try_emplace(grad.source().str(), Value());
  if (!emplace_res.second) {
    // If the source has been assigned a grad, use it.
    grad_var = emplace_res.first->second;
  } else {
    if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
                                               &grad_var))) {
      return failure();
    }
    emplace_res.first->second = grad_var;
    // A write to a grad accumulates with previous writes.
    (*stats)[grad_var].accumulate_on_write = true;
  }
  grad.flow_out().replaceAllUsesWith(grad.flow_in());
  grad.grad_handle().replaceAllUsesWith(grad_var);
  grad.erase();
  return success();
}

LogicalResult HandleTensorArrayGatherV3Op(
    TF::TensorArrayGatherV3Op gather,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = gather.handle();
  if (stats.count(local_var) == 0) {
    return gather.emitOpError("unknown tensor array");
  }
  OpBuilder builder(gather);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc());
  auto result =
      cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc());
  ReplaceAllUsesExceptTerminator(gather.value(), result);
  ReplaceAllUsesWithCast(gather.value(), result);
  gather.erase();
  return success();
}

LogicalResult HandleTensorArrayScatterV3Op(
    TF::TensorArrayScatterV3Op scatter,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = scatter.handle();
  if (stats.count(local_var) == 0) {
    return scatter.emitOpError("unknown tensor array");
  }
  OpBuilder builder(scatter);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, scatter.getLoc());
  buffer = cutil::ScatterAccumulateElements(scatter.indices(), scatter.value(),
                                            buffer, builder, scatter.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, scatter.getLoc());
  scatter.flow_out().replaceAllUsesWith(scatter.flow_in());
  scatter.erase();
  return success();
}

// Updates func's type according to its current arguments and return values.
void UpdateFuncType(func::FuncOp func) {
  llvm::SmallVector<Type, 8> arg_types;
  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
  func.setType(
      FunctionType::get(func.getContext(), arg_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Finds the accessed gradient sources for each tensor array argument.
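// For example (illustrative): if argument #2 of one of `funcs` is a tensor
// array resource that reaches TensorArrayGradV3 ops with sources "a" and "b"
// (directly or through while/if/call ops), the result maps 2 -> ["a", "b"].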
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
    ArrayRef<func::FuncOp> funcs, ModuleOp module) {
  llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
  llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
  auto insert = [&](Value v, const string& source, const Block& func_block) {
    auto arg = v.dyn_cast<BlockArgument>();
    if (!arg || arg.getOwner() != &func_block) return;
    auto insert_res = result_sets[arg.getArgNumber()].insert(source);
    if (!insert_res.second) return;
    result[arg.getArgNumber()].push_back(source);
  };
  for (func::FuncOp func : funcs) {
    const Block& func_block = func.front();
    // Walk all operations and nested regions to find accessed gradient sources
    // for function arguments.
    func.walk([&](Operation* op) {
      if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
        op->replaceAllUsesWith(op->getOperands());
        return;
      }
      if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(op)) {
        insert(grad.handle(), grad.source().str(), func_block);
      } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
        for (const auto& entry : AccessedGradients(
                 {while_op.body_function(), while_op.cond_function()}, module))
          for (const string& source : entry.getSecond())
            insert(while_op.getOperand(entry.getFirst()), source, func_block);
      } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
        for (const auto& entry : AccessedGradients(
                 {if_op.then_function(), if_op.else_function()}, module))
          for (const string& source : entry.getSecond())
            insert(if_op.getOperand(entry.getFirst() + 1), source, func_block);
      } else if (auto call = llvm::dyn_cast<CallOpInterface>(op)) {
        auto callee = dyn_cast<func::FuncOp>(call.resolveCallable());
        for (const auto& entry : AccessedGradients({callee}, module))
          for (const string& source : entry.getSecond())
            insert(call.getArgOperands()[entry.getFirst()], source, func_block);
      }
    });
  }
  return result;
}

// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallTensorArrayOpsInfo {
  // Whether the decomposed callee has a different signature than the original
  // callee, in which case callers need to be rewritten.
  bool signature_change;
  // The callee after tensor array decomposition (possibly a private clone).
  func::FuncOp decomposed_callee;
  // For each tensor array argument, the accessed gradient sources.
  llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
      arg_grads;
  // Pairs of (result index, operand index) where the callee forwards an input
  // resource to a result.
  llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
};

// Updates a called function's input signature by adjusting resource types and
// adding required gradient arguments.
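// For example (illustrative): if argument #0 is a tensor array resource with
// gradient sources ["a", "b"], its type is rewritten to the buffer-carrying
// resource type and two new arguments of that same type are appended, one per
// gradient source.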
void ChangeFunctionInputSignature(
    func::FuncOp func,
    const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
    llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
    llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  int64_t original_args = func.getNumArguments();
  for (int64_t argnum = 0; argnum < original_args; ++argnum) {
    auto arg = func.getArgument(argnum);
    Type t = ta_arg_buffer_type(argnum);
    if (!t) continue;
    arg.setType(t);
    auto grad_it = grads.find(argnum);
    if (grad_it == grads.end()) continue;
    llvm::StringMap<Value> grads_map;
    for (const string& source : grad_it->getSecond()) {
      auto g = func.front().addArgument(t, func.getLoc());
      (*stats)[g].accumulate_on_write = true;
      grads_map[source] = g;
    }
    auto& stat = (*stats)[arg];
    stat.accumulate_on_write = ta_accumulate_on_write(argnum);
    stat.grads = std::move(grads_map);
  }
  UpdateFuncType(func);
}

LogicalResult DecomposeTensorArrayOps(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*);

LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
                            llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
                            llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
                                decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  auto cond = while_op.cond_function();
  auto grads = AccessedGradients({body, cond}, module);
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(while_op.getOperand(index));
    if (it == stats->end()) return nullptr;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(while_op.getOperand(index));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
  ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &body_stats);
  llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
  ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &cond_stats);
  if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
                                     decomposed_partitioned_call_callees)) ||
      failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (body_stats.empty() && cond_stats.empty()) return success();
  auto old_body_ret = body.front().getTerminator();
  auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!ta_arg_buffer_type(i)) continue;
    auto retval = old_body_ret->getOperand(i);
    auto arg = retval.dyn_cast<BlockArgument>();
    if (!arg) {
      return while_op.emitOpError(
          "output tensor array does not alias input in a while loop");
    }
    for (const string& source : grads[i]) {
      new_retvals.push_back(body_stats[arg].grads[source]);
    }
  }
  OpBuilder(old_body_ret)
      .create<func::ReturnOp>(old_body_ret->getLoc(), new_retvals);
  old_body_ret->erase();
  UpdateFuncType(body);
  // Recreate the while op.
  auto operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
    auto grad_it = grads.find(i);
    auto& stat = (*stats)[operands[i]];
    if (grad_it == grads.end()) continue;
    for (const string& source : grad_it->getSecond()) {
      auto it = stat.grads.find(source);
      if (it != stat.grads.end()) {
        operands.push_back(it->second);
      } else {
        Value grad_var;
        if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
                                                   while_op, &grad_var))) {
          return failure();
        }
        stat.grads[source] = grad_var;
        operands.push_back(grad_var);
        (*stats)[grad_var].accumulate_on_write = true;
      }
    }
  }
  OpBuilder builder(while_op);
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getFunctionType().getInputs(), operands,
      while_op->getAttrs());
  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
    if (ta_arg_buffer_type(i)) {
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    } else {
      while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
    }
  }
  while_op.erase();
  return success();
}

LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
                         llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
                         llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
                             decomposed_partitioned_call_callees) {
  auto then_branch = if_op.then_function();
  auto else_branch = if_op.else_function();
  auto grads = AccessedGradients({then_branch, else_branch}, module);
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(if_op.getOperand(index + 1));
    if (it == stats->end()) return nullptr;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(if_op.getOperand(index + 1));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
  ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &then_stats);
  llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
  ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &else_stats);
  if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
                                     decomposed_partitioned_call_callees)) ||
      failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (then_stats.empty() && else_stats.empty()) return success();
  // Recreate the if op.
  auto operands = llvm::to_vector<8>(if_op.getOperands());
  for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
    auto grad_it = grads.find(i);
    auto& stat = (*stats)[operands[i + 1]];
    if (grad_it == grads.end()) continue;
    for (const string& source : grad_it->getSecond()) {
      auto it = stat.grads.find(source);
      if (it != stat.grads.end()) {
        operands.push_back(it->second);
      } else {
        Value grad_var;
        if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
                                                   if_op, &grad_var))) {
          return failure();
        }
        stat.grads[source] = grad_var;
        operands.push_back(grad_var);
        (*stats)[grad_var].accumulate_on_write = true;
      }
    }
  }
  OpBuilder builder(if_op);
  auto new_if = builder.create<TF::IfOp>(
      if_op.getLoc(), then_branch.getFunctionType().getResults(), operands,
      if_op->getAttrs());
  auto ret_forwards_input = [](func::FuncOp f, int64_t ret_ind) -> int64_t {
    auto retval = f.front().getTerminator()->getOperand(ret_ind);
    auto arg = retval.dyn_cast<BlockArgument>();
    if (!arg) return -1;
    return arg.getArgNumber();
  };
  for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(if_op.getResult(i).getType())
             .isa<TF::ResourceType>()) {
      if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
      continue;
    }
    int64_t then_forward_input = ret_forwards_input(then_branch, i);
    int64_t else_forward_input = ret_forwards_input(else_branch, i);
    if (then_forward_input != else_forward_input || then_forward_input < 0) {
      return if_op.emitOpError(
          "branches do not forward the same input resource");
    }
    if_op.getResult(i).replaceAllUsesWith(
        if_op.getOperand(then_forward_input + 1));
  }
  if_op.erase();
  return success();
}

template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, func::FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallTensorArrayOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreates the call op using the cached decomposition info.
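  // For example (illustrative): if info.arg_grads records that operand #1
  // accesses gradient source "a", the matching gradient variable is appended
  // as an extra operand (created and initialized if it does not exist yet);
  // if info.ret_forward_input contains (0, 2), all uses of result #0 of the
  // old call are replaced with operand #2.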
  auto recreate_caller = [&]() -> LogicalResult {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (const auto& entry : info.arg_grads) {
      auto it = stats->find(call.getOperand(entry.first));
      if (it == stats->end()) return call.emitOpError("unknown tensor array");
      for (const string& source : entry.second) {
        auto grad_it = it->getSecond().grads.find(source);
        if (grad_it != it->getSecond().grads.end()) {
          new_operands.push_back(grad_it->second);
        } else {
          Value grad_var;
          if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
                                                     call, &grad_var))) {
            return failure();
          }
          it->getSecond().grads[source] = grad_var;
          new_operands.push_back(grad_var);
        }
      }
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getFunctionType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", SymbolRefAttr::get(
                 builder.getContext(),
                 const_cast<func::FuncOp&>(info.decomposed_callee).getName()));
    for (const auto& entry : info.ret_forward_input) {
      call.getResult(entry.first)
          .replaceAllUsesWith(call.getOperand(entry.second));
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  // Rewrite the callee.
  info.signature_change = false;
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(call.getOperand(index));
    if (it == stats->end()) return nullptr;
    info.signature_change = true;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(call.getOperand(index));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  func::FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto grads = AccessedGradients({lowered_callee}, module);
  for (int64_t i = 0; i < lowered_callee.getNumArguments(); ++i) {
    auto it = grads.find(i);
    if (it == grads.end()) continue;
    info.arg_grads.emplace_back(i, it->getSecond());
  }
  llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
  ChangeFunctionInputSignature(lowered_callee, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &callee_stats);
  if (failed(DecomposeTensorArrayOps(&lowered_callee.front(), module,
                                     &callee_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  for (int64_t i = 0; i < call.getNumResults(); ++i) {
    auto ret = lowered_callee.front().getTerminator()->getOperand(i);
    if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
    auto arg = ret.dyn_cast<BlockArgument>();
    if (!arg) continue;
    info.ret_forward_input.emplace_back(i, arg.getArgNumber());
  }

  info.decomposed_callee = lowered_callee;
  if (lowered_callee != callee) {
    if (!info.signature_change) {
      // Signature is not modified. We do not need to keep two copies.
      lowered_callee.setName(
          StringAttr::get(callee->getContext(), callee.getName()));
      callee.erase();
    } else {
      // Add the clone with a new name.
      lowered_callee.setName(StringAttr::get(
          callee->getContext(),
          llvm::formatv("{0}_tensorarray_decomposed", callee.getName()).str()));
    }
    SymbolTable(module).insert(lowered_callee);
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }

  for (Region& region : op.getRegions()) {
    if (failed(DecomposeTensorArrayOps(&region.front(), module, stats,
                                       decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

LogicalResult DecomposeTensorArrayOps(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
      if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
        return failure();
      }
    } else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
      if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
    } else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
      if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
    } else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
      if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
    } else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
      if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
    } else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
      if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
    } else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
      if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
    } else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
      if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
    } else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
      if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, stats,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, stats,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::CaseRegionOp>(op) ||
               llvm::isa<TF::IfRegionOp>(op) ||
               llvm::isa<TF::WhileRegionOp>(op)) {
      if (failed(HandleRegionControlFlowOps(
              op, module, stats, decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      auto callee = pcall.func();
      if (!callee)
        return pcall.emitOpError(
            "TensorArray decomposition does not support call with nested "
            "references.");

      if (failed(
              HandlePartitionedCallOp(pcall, callee, module, stats,
                                      decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(
              HandlePartitionedCallOp(spcall, spcall.func(), module, stats,
                                      decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

void TensorArrayOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<func::FuncOp>("main");
  if (!main) return;
  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
  llvm::StringMap<PartitionedCallTensorArrayOpsInfo>
      decomposed_partitioned_call_callees;
  if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
                                     &decomposed_partitioned_call_callees))) {
    signalPassFailure();
  }
}

}  // namespace

namespace TF {

std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorArrayOpsDecompositionPass() {
  return std::make_unique<TensorArrayOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir
961