/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass prepares for legalization to the TFLite dialect by
// converting TensorList operations in the TensorFlow dialect into operations
// that can be legalized to the TensorFlow Lite dialect with simple
// replacements. The newly created operations are in the TensorFlow dialect if
// the operation can be represented using a TensorFlow op. Otherwise, a
// TensorFlow Lite dialect op is used.

#include <climits>
#include <cstdint>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/tensor_list.h"

#define DEBUG_TYPE "tf-tfl-legalization"

//===----------------------------------------------------------------------===//
// The actual LowerStaticTensorList Pass.
//
namespace mlir {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
    : public LowerStaticTensorListPassBase<LowerStaticTensorListPass> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerStaticTensorListPass)

  LowerStaticTensorListPass() = default;
  LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
  explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through,
                                     bool default_to_single_batch,
                                     bool enable_dynamic_update_slice) {
    this->allow_tensorlist_pass_through_ = allow_tensorlist_pass_through;
    this->default_to_single_batch_ = default_to_single_batch;
    this->enable_dynamic_update_slice_ = enable_dynamic_update_slice;
  }

  void runOnOperation() override;
};

// Creates a scalar or splat constant of i32 type with the given `shape`,
// where every element equals `val`.
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
                          ArrayRef<int64_t> shape, int32_t val) {
  RankedTensorType type =
      RankedTensorType::get(shape, rewriter->getIntegerType(32));
  DenseElementsAttr attr =
      DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
  return rewriter->create<arith::ConstantOp>(loc, type, attr);
}

// Creates a scalar or splat constant of i64 type with the given `shape`,
// where every element equals `val`.
Value CreateI64SplatConst(Location loc, PatternRewriter *rewriter,
                          ArrayRef<int64_t> shape, int64_t val) {
  RankedTensorType type =
      RankedTensorType::get(shape, rewriter->getIntegerType(64));
  DenseElementsAttr attr =
      DenseElementsAttr::get(type, rewriter->getI64IntegerAttr(val));
  return rewriter->create<arith::ConstantOp>(loc, type, attr);
}

// Creates a 1-D i32 tensor whose shape is given by the value `shape_tensor`,
// with every element set to `val`, using a tf.Fill op.
Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
                           Value shape_tensor, int32_t val) {
  Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
  return rewriter->create<TF::FillOp>(
      loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
      shape_tensor, scalar_val);
}

// Returns a new type by prepending the specified dimension to the shape of
// the given type if it is a ranked type.
Type PrependLeadingDimIfRanked(int64_t dim, Type type,
                               PatternRewriter *rewriter) {
  Type dtype = getElementTypeOrSelf(type);
  if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
    llvm::SmallVector<int64_t, 4> shape = {dim};
    shape.append(ty.getShape().begin(), ty.getShape().end());
    return RankedTensorType::get(shape, dtype);
  }
  return type;
}
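// Illustrative example (shapes hypothetical): with dim = -1 and
// type = tensor<2x3xf32>, this returns tensor<?x2x3xf32>; an unranked type is
// returned unchanged.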

// Returns the tensor type to use for a tensor list, given the element type
// and the list's variant handle type.
Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
                                PatternRewriter *rewriter) {
  // If the variant type in the output handle has item shape available, use it
  // to derive the output shape by setting an unknown leading dimension.
  // Otherwise, the result type will be unranked.
  if (handle_dtype.getSubtypes().empty()) {
    return UnrankedTensorType::get(element_type);
  }
  return PrependLeadingDimIfRanked(-1, handle_dtype.getSubtypes()[0], rewriter);
}
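// For example (illustrative): with element_type = f32 and handle subtype
// tensor<2x3xf32>, the derived type is tensor<?x2x3xf32>; without subtype
// info it is tensor<*xf32>.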

// Gets the indices of the tensorlist arguments whose size might get changed
// by the function.
llvm::SmallSet<int, 4> GetResizedTensorListIndexes(
    func::FuncOp func, const llvm::SmallSet<int, 4> &tensor_list_args) {
  // `indexes` stores the argument indices of tensorlists whose size may get
  // updated in the function.
  llvm::SmallSet<int, 4> indexes;
  for (BlockArgument &arg : func.getArguments()) {
    if (tensor_list_args.contains(arg.getArgNumber())) {
      for (const mlir::OpOperand &use : arg.getUses()) {
        mlir::Operation *op = use.getOwner();
        // Currently we only check if the tensorlist argument is consumed by
        // `TensorListPushBack` or `TensorListResize`, since those are the only
        // length-mutating ops supported in this pass.
        if (llvm::isa<TF::TensorListPushBackOp>(op) ||
            llvm::isa<TF::TensorListResizeOp>(op)) {
          indexes.insert(arg.getArgNumber());
        }
      }
    }
  }
  return indexes;
}

// Creates a slice of the tensorlist `input_list`, starting from
// [start_index, 0, ...0], with size [size, -1, ...-1].
//
// Requires that `start_index` and `size` are scalar tensors and
// `item_position_shape` is a 1-D tensor with only one element equal to the rank
// of an item in the tensorlist.
TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
                                       Value start_index, Value size,
                                       Value item_rank, Type result_type,
                                       PatternRewriter *rewriter) {
  // Create the start position of slice. This is done by concatenating
  // `start_index` and `partial_start_position` together.
  IntegerType shape_dtype = rewriter->getIntegerType(32);
  RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
  Value partial_start_position =
      CreateI32SplatTensor(loc, rewriter, item_rank, 0);
  Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
  RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
  auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
      loc, vector_type, start_index, scalar_zero);
  auto start_position = rewriter->create<TF::ConcatOp>(
      loc, position_type, scalar_zero,
      ArrayRef<Value>({expanded_start_index, partial_start_position}));

  // Create the slice size tensor. This is done by concatenating `size` and
  // `partial_size`.
  auto size_leading_dim =
      rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
  Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
  auto slice_size = rewriter->create<TF::ConcatOp>(
      loc, position_type, scalar_zero,
      ArrayRef<Value>({size_leading_dim, partial_size}));

  return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
                                       start_position, slice_size);
}
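// Illustrative expansion (values hypothetical): for a list of rank-2 items,
// the emitted ops compute begin = [start_index, 0, 0] and
// size = [size, -1, -1], then return tf.Slice(input_list, begin, size).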

template <typename OpT>
class TensorListOpConverterBase : public OpConversionPattern<OpT> {
 public:
  explicit TensorListOpConverterBase<OpT>(MLIRContext *context,
                                          bool allow_tensorlist_pass_through,
                                          bool default_to_single_batch)
      : OpConversionPattern<OpT>::OpConversionPattern(context),
        allow_tensorlist_pass_through_(allow_tensorlist_pass_through),
        default_to_single_batch_(default_to_single_batch) {}

 protected:
  // This flag controls the error-emitting behavior during rewrite:
  // 1) If it's true, then patterns will only emit errors during debug or
  //    tracing mode.
  // 2) If it's false, then patterns will emit standard errors when there is a
  //    rewrite failure.
  bool allow_tensorlist_pass_through_;

  // This flag controls whether to default the batch size to one when the
  // given batch size is unknown, in order to force the tensor list op
  // lowerings to proceed.
  bool default_to_single_batch_;
};

// Converts a tf.Const containing a variant of type TensorList to a tensor of
// primitive element types. Each individual tensor in the list is converted to
// an ElementsAttr, and then those are packed together using a tf.Pack op.
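// Illustrative example (shapes hypothetical): a variant constant holding two
// tensor<3xf32> elements becomes two tf.Const ops combined by
// "tf.Pack"(%elem0, %elem1) {axis = 0 : i64} into a tensor<2x3xf32>.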
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::ConstOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Verify that the tensor proto contains a tensor of type variant and
    // scalar shape. The variant type should hold a TensorList.
    auto proto_attr = op.value().dyn_cast<TF::TensorProtoAttr>();
    if (!proto_attr) return failure();
    tensorflow::Tensor tensor;
    if (!tensorflow::ConvertToTensor(proto_attr, &tensor).ok())
      return failure();
    if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
    if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
      return failure();

    const tensorflow::TensorList *list =
        tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
    if (!list) return failure();

    // Verify that the output type is variant and contains exactly one ranked
    // subtype.
    auto variant_ty =
        getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
    if (!variant_ty) return failure();
    ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
    if (subtypes.size() != 1) return failure();
    RankedTensorType list_element_ty =
        subtypes.front().dyn_cast<RankedTensorType>();
    if (!list_element_ty) return failure();

    // Extract tensor elements for the TensorList and construct result type
    // based on the number of elements and element shape.
    const std::vector<tensorflow::Tensor> &tensors = list->tensors();
    llvm::SmallVector<int64_t, 4> result_shape = {
        static_cast<int64_t>(tensors.size())};
    result_shape.append(list_element_ty.getShape().begin(),
                        list_element_ty.getShape().end());
    auto result_ty =
        RankedTensorType::get(result_shape, list_element_ty.getElementType());

    // If the list is empty, directly create the final result instead of
    // creating the tf.Pack op. tf.Pack op requires at least one operand.
    if (tensors.empty()) {
      tensorflow::Tensor tensor(list->element_dtype,
                                tensorflow::TensorShape(result_shape));
      auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
      if (!attr_or.ok()) return failure();
      rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
      return success();
    }

    // Extract the individual tensor list elements and combine them using the
    // tf.Pack op.
    Location loc = op.getLoc();
    llvm::SmallVector<Value, 4> values;
    values.reserve(tensors.size());
    for (const tensorflow::Tensor &tensor : tensors) {
      auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
      if (!attr_or.ok()) return failure();

      auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
      values.push_back(value);
    }
    rewriter.replaceOpWithNewOp<TF::PackOp>(
        op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
    return success();
  }
};

struct ConvertTensorListSetItem
    : public OpConversionPattern<TF::TensorListSetItemOp> {
  explicit ConvertTensorListSetItem(MLIRContext *context,
                                    bool enable_dynamic_update_slice = false)
      : OpConversionPattern<TF::TensorListSetItemOp>(context),
        enable_dynamic_update_slice(enable_dynamic_update_slice) {}

  LogicalResult matchAndRewrite(
      TF::TensorListSetItemOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    if (enable_dynamic_update_slice) {
      return matchAndRewriteImplWithDynamicUpdateSlice(op, adaptor, rewriter);
    } else {
      return matchAndRewriteImplWithSliceAndConcat(op, adaptor, rewriter);
    }
  }

  // This function rewrites the original op into a series of slice and concat
  // ops that produce the same result. It first slices the first `$index`
  // rows, then expands the dimension of `$item`, followed by another slice of
  // the remaining rows starting from `$index` + 1. Lastly, it concatenates
  // the three parts together.
  // At a high level, it's doing something like:
  // def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
  //           (Concat
  //             concat_dim = 0,
  //             (Slice $input, [0, 0, ...],
  //                    (Concat (ExpandDims $index, expand_dim = 0),
  //                            [-1, -1, ...])),
  //             (ExpandDims $item, expand_dim = 0),
  //             (Slice $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
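  // Concretely (illustrative, shapes hypothetical): setting index 2 of a
  // tensor<5x4xf32> list concatenates rows [0, 2), the item expanded to
  // tensor<1x4xf32>, and rows [3, 5) along dimension 0.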
  LogicalResult matchAndRewriteImplWithSliceAndConcat(
      TF::TensorListSetItemOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const {
    Location loc = op.getLoc();
    Value input = adaptor.getOperands()[0];
    Value index = adaptor.getOperands()[1];
    Value item = adaptor.getOperands()[2];

    IntegerType shape_dtype = rewriter.getIntegerType(32);
    auto item_rank = rewriter.create<TF::RankOp>(
        loc, RankedTensorType::get({}, shape_dtype), item);
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);

    // Calculate `index` + 1, which is used to generate the start position for
    // the second slice op.
    auto suffix_start =
        rewriter.create<TF::AddOp>(loc, index.getType(), index,
                                   CreateI32SplatConst(loc, &rewriter, {}, 1));

    auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
        loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
    // Create two slice ops.
    Type element_type = input.getType().cast<TensorType>().getElementType();
    UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
    Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
    TF::SliceOp slice1 =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/scalar_zero,
                                   /*size=*/index,
                                   /*item_rank=*/item_position_shape,
                                   /*result_type=*/unranked_tensor, &rewriter);
    TF::SliceOp slice2 =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/suffix_start,
                                   /*size=*/scalar_minus_one,
                                   /*item_rank=*/item_position_shape,
                                   /*result_type=*/unranked_tensor, &rewriter);

    // Expand the dimension of the item so that it has the same rank as the
    // input.
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        op.getLoc(), unranked_tensor, item, scalar_zero);

    // Concatenate the three parts together to generate the final result.
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, input.getType(), scalar_zero,
        ArrayRef<Value>({slice1, expanded_item, slice2}));
    return success();
  }

  // This function rewrites the original op into an XLA DynamicUpdateSlice op.
  // |item| is expanded to have the same dimension as input_handle and
  // |index| is expanded to [index, 0, 0, ...] as the indices to input_handle.
  // At a high level, it's doing something like:
  // def : Pat<(TensorListSetItem($input_handle, $index, $item)),
  //           (XlaDynamicUpdateSlice($input_handle, ExpandDims($item, 0),
  //              Concat(ExpandDims($index, 0), [0, 0, 0, ...])))>
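  // For example (illustrative, shapes hypothetical): with $index = 2 and a
  // rank-1 item, this emits
  //   tf.XlaDynamicUpdateSlice(%input, ExpandDims(%item, 0), [2, 0]).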
  LogicalResult matchAndRewriteImplWithDynamicUpdateSlice(
      TF::TensorListSetItemOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const {
    Location loc = op.getLoc();
    Value input = adaptor.getOperands()[0];
    Value index = adaptor.getOperands()[1];
    Value item = adaptor.getOperands()[2];

    IntegerType shape_dtype = rewriter.getIntegerType(32);
    auto item_rank = rewriter.create<TF::RankOp>(
        loc, RankedTensorType::get({}, shape_dtype), item);
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);

    // Concat(ExpandDims(index, 0), [0, 0, 0, ...])
    RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
    auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
        loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
    Value partial_index =
        CreateI32SplatTensor(loc, &rewriter, item_position_shape, 0);
    RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
    auto expanded_index =
        rewriter.create<TF::ExpandDimsOp>(loc, vector_type, index, scalar_zero);
    auto start_position = rewriter.create<TF::ConcatOp>(
        loc, position_type, scalar_zero,
        ArrayRef<Value>({expanded_index, partial_index}));

    // Expand the dimension of the item so that it has the same rank as the
    // input.
    // ExpandDims(item, 0)
    Type element_type = input.getType().cast<TensorType>().getElementType();
    UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        op.getLoc(), unranked_tensor, item, scalar_zero);

    // Update the element with XlaDynamicUpdateSliceOp.
    rewriter.replaceOpWithNewOp<TF::XlaDynamicUpdateSliceOp>(
        op, input.getType(), input, expanded_item, start_position);
    return success();
  }

  bool enable_dynamic_update_slice;
};

// Rewrites an op of the template type that initializes a TensorList into ops
// that generate an equivalent raw tensor. Derived classes are required to
// override the GetNumElements method.
template <typename OpT>
struct ConvertTensorListInitOp : public TensorListOpConverterBase<OpT> {
  using TensorListOpConverterBase<OpT>::TensorListOpConverterBase;
  using TensorListOpConverterBase<OpT>::allow_tensorlist_pass_through_;
  using TensorListOpConverterBase<OpT>::default_to_single_batch_;

  // Creates and returns a 1-D tensor with exactly one element equal to the
  // number of list elements to initialize the output tensor list with.
  virtual Value GetNumElements(OpT op, ValueRange operands,
                               PatternRewriter *rewriter) const = 0;

  // Rewrites the original op into `tf.fill`. The result tensor shape is
  // [num_element, element_shape]. All the values in the result tensor will be
  // initialized to 0.
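  // Illustrative example (assuming an f32 element type; shapes hypothetical):
  // TensorListReserve with element_shape = [2, 3] and num_elements = 5 lowers
  // roughly to tf.Fill(tf.Concat(0, [5], [2, 3]), 0.0) : tensor<5x2x3xf32>.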
  LogicalResult matchAndRewrite(
      OpT op, typename OpT::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Type dtype = op.element_dtype();
    if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
          dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
          dtype.isInteger(32) || dtype.isInteger(64))) {
      const char *error_info =
          "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
          "integer or 16-bit/32-bit/64-bit float type during TF Lite "
          "transformation pass";
      return allow_tensorlist_pass_through_
                 ? rewriter.notifyMatchFailure(op, error_info)
                 : op.emitOpError(error_info);
    }

    Value element_shape = adaptor.getOperands()[0];
    Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
    // If the `element_shape` is a scalar, we try to acquire its shape by
    // looking at the first `TensorListSetItemOp` writing to this tensor list.
    // Here we assume that the element_shape won't be changed before calling
    // the first `TensorListSetItemOp`.
    if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
      if (shaped_type.hasRank() && shaped_type.getRank() == 0) {
        bool element_shape_acquired = false;
        auto uses = op.getResult().getUses();
        for (auto &use : llvm::make_early_inc_range(uses)) {
          if (TF::TensorListSetItemOp set_op =
                  llvm::dyn_cast<TF::TensorListSetItemOp>(use.getOwner())) {
            element_shape = rewriter.create<TF::ShapeOp>(
                op.getLoc(), RankedTensorType::get({-1}, shape_dtype),
                set_op.item());
            element_shape_acquired = true;
          } else if (TF::WhileOp while_op =
                         llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
            // The tensorlist is passed into a while loop; check inside the
            // body function.
            auto inside_uses = while_op.body_function()
                                   .getArgument(use.getOperandNumber())
                                   .getUses();
            for (auto &inside_use : llvm::make_early_inc_range(inside_uses)) {
              if (TF::TensorListSetItemOp set_op =
                      llvm::dyn_cast<TF::TensorListSetItemOp>(
                          inside_use.getOwner())) {
                if (auto shaped_type =
                        set_op.item().getType().dyn_cast<ShapedType>()) {
                  if (shaped_type.hasStaticShape()) {
                    RankedTensorType type = RankedTensorType::get(
                        {shaped_type.getRank()}, rewriter.getIntegerType(32));
                    SmallVector<Attribute, 4> shape_attr;
                    for (int64_t dim : shaped_type.getShape()) {
                      shape_attr.push_back(rewriter.getI32IntegerAttr(dim));
                    }
                    DenseElementsAttr attr =
                        DenseElementsAttr::get(type, shape_attr);
                    element_shape = rewriter.create<arith::ConstantOp>(
                        op.getLoc(), type, attr);
                    element_shape_acquired = true;
                    break;
                  }
                }
              }
            }
          }
          if (element_shape_acquired) break;
        }
        if (!element_shape_acquired) {
          const char *error_info =
              "requires element_shape to be 1D tensor during TF Lite "
              "transformation pass";
          return allow_tensorlist_pass_through_
                     ? rewriter.notifyMatchFailure(op, error_info)
                     : op.emitOpError(error_info);
        }
      }
    }

    DenseIntElementsAttr dense_elem_attr;
    if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      // Note: It's technically unsafe to rewrite
      //   TensorListReserve(num_element, element_shape)
      // to
      //   Fill(Concat(num_element, element_shape), 0)
      // because element_shape may contain -1 to represent unknown dimension.
      //
      // In real world use cases (e.g. Keras RNN), `element_shape` is usually
      // a constant, and the first dimension of `element_shape` is usually
      // the batch dimension. Currently TFLiteConverter always rewrites an
      // unknown batch dimension to 1, therefore we also rewrite unknown
      // dimensions in `element_shape` to 1 here.
      //
      // This workaround enables converting Keras RNN without specifying the
      // batch dimension. This isn't guaranteed to work, but it doesn't break
      // any non-broken cases either (since it's already broken if
      // `element_shape` contains -1).
      // TODO(b/142096690): Support dynamic element shape and remove the
      // workaround.
      SmallVector<int32_t, 4> new_element_shape_values;

      auto int_values = dense_elem_attr.getValues<APInt>();
      for (auto it = int_values.begin(); it != int_values.end(); ++it) {
        auto dim_value = (*it).getSExtValue();
        if (it == int_values.begin() && dim_value == -1) {
          if (!default_to_single_batch_) {
            const char *error_info =
                "requires element_shape to be static during TF Lite "
                "transformation pass";
            return allow_tensorlist_pass_through_
                       ? rewriter.notifyMatchFailure(op, error_info)
                       : op.emitOpError(error_info);
          }
          dim_value = 1;
        }
        new_element_shape_values.push_back(dim_value);
      }

      auto attr = DenseIntElementsAttr::get(
          element_shape.getType().cast<ShapedType>(), new_element_shape_values);
      auto new_element_shape = rewriter.create<arith::ConstantOp>(
          op.getLoc(), element_shape.getType(), attr);
      element_shape = new_element_shape;
    }

    int64_t result_rank = -1;  // -1 means unknown result rank.
    Type element_dtype = op.element_dtype();
    Type result_type = UnrankedTensorType::get(element_dtype);
    Value leading_dim = GetNumElements(op, adaptor.getOperands(), &rewriter);
    if (auto element_type =
            op.element_type().template dyn_cast<RankedTensorType>()) {
      result_rank = element_type.getRank() + 1;
      int64_t leading_dim_v = -1;
      ElementsAttr element_attr;
      if (matchPattern(leading_dim, m_Constant(&element_attr))) {
        leading_dim_v = element_attr.getValues<APInt>()[0].getSExtValue();
      }
      SmallVector<int64_t, 4> result_shape = {leading_dim_v};
      ArrayRef<int64_t> shape = element_type.getShape();
      result_shape.append(shape.begin(), shape.end());
      result_type = RankedTensorType::get(result_shape, element_dtype);
    }

    // Create a 1-D RankedTensorType for the result's shape. The number of
    // elements in it is equal to the rank of the result, if known. Otherwise,
    // the number of elements is unknown and represented with -1. In both
    // cases, we can specify the dimension using the rank of the result.
    Type shape_type = RankedTensorType::get({result_rank}, shape_dtype);

    Location loc = op.getLoc();
    // Prepend the number of elements to the element shape to get the shape of
    // the output tensor.
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
    auto list_shape = rewriter.create<TF::ConcatOp>(
        loc, shape_type, scalar_zero,
        ArrayRef<Value>({leading_dim, element_shape}));

    // Create a zero-initialized constant tensor that has the same type
    // as specified by element_dtype.
    RankedTensorType zero_type = RankedTensorType::get({}, element_dtype);
    Attribute zero_attr = rewriter.getZeroAttr(zero_type);
    auto zero = rewriter.create<arith::ConstantOp>(loc, zero_type, zero_attr);

    rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
    return success();
  }
};

struct ConvertTensorListReserve
    : public ConvertTensorListInitOp<TF::TensorListReserveOp> {
  explicit ConvertTensorListReserve(MLIRContext *context,
                                    bool allow_tensorlist_pass_through,
                                    bool default_to_single_batch)
      : ConvertTensorListInitOp(context, allow_tensorlist_pass_through,
                                default_to_single_batch) {}

  Value GetNumElements(TF::TensorListReserveOp op, ValueRange operands,
                       PatternRewriter *rewriter) const override {
    Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
    Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
    Value num_elements = operands[1];
    IntegerAttr attr;
    if (matchPattern(num_elements, m_Constant(&attr))) {
      return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt());
    }
    if (auto const_op = num_elements.getDefiningOp<TF::ConstOp>()) {
      return CreateI32SplatConst(op->getLoc(), rewriter, {1},
                                 (*const_op.value()
                                       .cast<DenseElementsAttr>()
                                       .getValues<APInt>()
                                       .begin())
                                     .getSExtValue());
    }
    return rewriter->create<TF::ExpandDimsOp>(
        op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
        scalar_zero);
  }
};

// Note that we ignore the second operand `max_num_elements` as we don't have
// any restrictions on the number of elements we can support. So this may
// have a different behavior compared to TensorFlow in case of errors.
struct ConvertEmptyTensorList
    : public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
  explicit ConvertEmptyTensorList(MLIRContext *context,
                                  bool allow_tensorlist_pass_through,
                                  bool default_to_single_batch)
      : ConvertTensorListInitOp(context, allow_tensorlist_pass_through,
                                default_to_single_batch) {}

  Value GetNumElements(TF::EmptyTensorListOp op, ValueRange operands,
                       PatternRewriter *rewriter) const override {
    return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
  }
};

struct ConvertTensorListPushBack
    : public OpConversionPattern<TF::TensorListPushBackOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListPushBackOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input_handle = adaptor.getOperands()[0];
    Value item = adaptor.getOperands()[1];

    // Expand the shape of the item so that it has the same rank as the input
    // tensor and is compatible with the Concat op.
    Type expanded_item_type =
        PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
    Location loc = op.getLoc();
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
    auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
        loc, expanded_item_type, item, scalar_zero);

    Type elem_type = getElementTypeOrSelf(item);
    auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
                            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);

    // Concatenate the tensor stored in the input handle with the expanded
    // item to get a tensor equivalent to the TensorList generated by this op.
    rewriter.replaceOpWithNewOp<TF::ConcatOp>(
        op, result_type, scalar_zero,
        ArrayRef<Value>({input_handle, expanded_item}));
    return success();
  }
};

// Rewrites a `TensorListResize` op into a functional `If` op and several basic
// TF ops to match the op semantics of TensorFlow. Basically, it does:
// 1) If the requested size is smaller than or equal to the input tensorlist's
//    size, rewrite it to a Slice op so that only the first 'size' rows are
//    returned.
// 2) If the requested size is larger than the input tensorlist's size, create
//    an additional tensorlist with 'size - input_size' elements and append it
//    to the end of the input tensorlist.
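// Illustrative structure of the generated IR (SSA value names hypothetical):
//   %diff = "tf.Sub"(%size, %input_size)
//   %cond = "tf.Greater"(%diff, %zero)
//   %result = "tf.If"(%cond, %input, %input_shape, %diff, %size)
//       {then_branch = @cond_true, else_branch = @cond_false}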
struct ConvertTensorListResize
    : public OpConversionPattern<TF::TensorListResizeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListResizeOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input_handle = adaptor.getOperands()[0];
    Value size = adaptor.getOperands()[1];

    Location loc = op.getLoc();
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);

    // Compute the input tensorlist's length and store it in `input_size`.
    IntegerType shape_dtype = rewriter.getIntegerType(32);
    auto input_size = rewriter.create<TF::TensorListLengthOp>(
        loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));

    // Infer the result type of this op based on TF's shape inference result.
    Type elem_type = getElementTypeOrSelf(input_handle);
    auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
                            .cast<TF::VariantType>();
    Type result_type =
        GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);

    // Compute the difference of `size` and `input_size`, and store it in
    // `size_diff`, which is then consumed by `if_cond`.
    auto size_diff = rewriter.create<TF::SubOp>(
        loc, RankedTensorType::get({}, shape_dtype), size, input_size);
    auto if_cond = rewriter.create<TF::GreaterOp>(
        loc, RankedTensorType::get({}, rewriter.getI1Type()), size_diff,
        scalar_zero);

    // Build the argument/result types for the if branch functions.
    auto input_shape = rewriter.create<TF::ShapeOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_handle);

    Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
                               size_diff.getType(), size.getType()};
    Type branch_result_type[] = {result_type};
    auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
                                       branch_result_type);

    // Create functions in a higher scope before restoring the insertion point.
    // Additionally, create the SymbolTable before further modifying the
    // module.
    auto original_point = rewriter.saveInsertionPoint();
    rewriter.setInsertionPointAfter(op->getParentOfType<func::FuncOp>());
    SymbolTable manager(op->getParentOfType<ModuleOp>());

    // Constructs `then_branch`, which is executed when `if_cond` evaluates to
    // true.
    auto then_branch_op =
        rewriter.create<func::FuncOp>(loc, "cond_true", func_type);
    CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
                         &rewriter);
    then_branch_op.setVisibility(func::FuncOp::Visibility::Private);

    // Constructs `else_branch`, which is executed when `if_cond` evaluates to
    // false.
    auto else_branch_op =
        rewriter.create<func::FuncOp>(loc, "cond_false", func_type);
    CreateCondFalseBranch(loc, shape_dtype, result_type, else_branch_op,
                          &rewriter);
    else_branch_op.setVisibility(func::FuncOp::Visibility::Private);

    // Inserts the two blocks' names into the symbol table held by the module.
    // Using SymbolTable will ensure that the inserted symbol names are
    // unique.
    manager.insert(then_branch_op);
    manager.insert(else_branch_op);

    rewriter.restoreInsertionPoint(original_point);
    rewriter.replaceOpWithNewOp<TF::IfOp>(
        op, result_type, if_cond,
        /*input=*/
        ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
        /*then_branch=*/
        mlir::SymbolRefAttr::get(then_branch_op),
        /*else_branch=*/
        mlir::SymbolRefAttr::get(else_branch_op),
        /*is_stateless=*/rewriter.getBoolAttr(true));
    return success();
  }

 private:
  // When the input tensorlist's size is smaller than the requested size, the
  // then branch is executed.
  // It creates a new tensorlist of size 'size - input_size' and concats it
  // with the input tensorlist.
  void CreateCondTrueBranch(TF::TensorListResizeOp resize_op, Type shape_dtype,
                            Type result_type, func::FuncOp branch_func,
                            ConversionPatternRewriter *rewriter) const {
    auto guard = OpBuilder::InsertionGuard(*rewriter);
    auto inputs = branch_func.getFunctionType().getInputs();
    Block *block = rewriter->createBlock(
        &branch_func.getBody(), branch_func.begin(), inputs,
        SmallVector<Location>(inputs.size(), branch_func.getLoc()));

    auto input_shape = block->getArgument(1);
    auto size_diff = block->getArgument(2);
    auto input = block->getArgument(0);

    Location loc = resize_op.getLoc();
    // Get the element shape by slicing from index 1 in the input shape.
    Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
    Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
    Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
    auto elem_shape = rewriter->create<TF::SliceOp>(
        loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
        slice_size);
    auto extended_part = rewriter->create<TF::TensorListReserveOp>(
        loc, resize_op.output_handle().getType(), elem_shape, size_diff);
    // `ConcatOp` expects non-variant-typed input. Insert a
    // `TensorListStackOp` here to convert the type from variant to
    // non-variant. Note that we are using the same `result_type` for both the
    // `TensorListStackOp` and `ConcatOp`, since the first dimension of the
    // shape specified by `result_type` is -1.
    auto stacked_extended_part = rewriter->create<TF::TensorListStackOp>(
        loc, result_type, extended_part,
        /*element_shape=*/CreateI32SplatConst(loc, rewriter, {}, -1),
        /*num_elements=*/rewriter->getI32IntegerAttr(-1));
    auto concat_op = rewriter->create<TF::ConcatOp>(
        loc, result_type, scalar_zero,
        ArrayRef<Value>({input, stacked_extended_part}));
    rewriter->create<func::ReturnOp>(loc, ArrayRef<Value>({concat_op}));
  }

  void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
                             func::FuncOp branch_func,
                             ConversionPatternRewriter *rewriter) const {
    // When the input tensorlist's size is larger than or equal to the
    // requested size, the else branch is executed.
    // It slices the first 'size' rows from the input tensorlist.
    auto guard = OpBuilder::InsertionGuard(*rewriter);
    auto inputs = branch_func.getFunctionType().getInputs();
    Block *block = rewriter->createBlock(
        &branch_func.getBody(), branch_func.begin(), inputs,
        SmallVector<Location>(inputs.size(), branch_func.getLoc()));

    Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
    Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
    auto input = block->getArgument(0);
    auto size = block->getArgument(3);

    // Subtract `input_rank` by 1 to get the item's rank, which is used as
    // `partial_position_shape`.
    auto input_rank = rewriter->create<TF::RankOp>(
        loc, RankedTensorType::get({}, shape_dtype), input);
    auto partial_position_shape = rewriter->create<TF::SubOp>(
        loc, RankedTensorType::get({1}, shape_dtype), input_rank, vector_one);
    auto slice_op =
        CreateSliceOpForTensorList(loc, /*input_list=*/input,
                                   /*start_index=*/scalar_zero, /*size=*/size,
                                   /*item_rank=*/partial_position_shape,
                                   /*result_type=*/result_type, rewriter);
    rewriter->create<func::ReturnOp>(loc, ArrayRef<Value>({slice_op}));
  }
};

struct ConvertTensorListGetItem
    : public OpConversionPattern<TF::TensorListGetItemOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListGetItemOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getOperands()[0];
    Value index = adaptor.getOperands()[1];
    rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
                                              rewriter.getBoolAttr(true));
    return success();
  }
};

struct ConvertTensorListLength
    : public OpConversionPattern<TF::TensorListLengthOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListLengthOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input_handle = adaptor.getOperands()[0];

    BoolAttr true_attr = rewriter.getBoolAttr(true);
    auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
                                              /*use_32bit=*/true_attr);
    rewriter.replaceOpWithNewOp<TF::GatherOp>(
        op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
        /*validate_indices=*/true_attr);
    return success();
  }
};

struct ConvertTensorListStack
    : public OpConversionPattern<TF::TensorListStackOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::TensorListStackOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getOperands()[0];
    Value element_shape = adaptor.getOperands()[1];

    // If the `element_shape` is a known constant (which is defined when
    // calling `tensor_list_stack`) and also valid (not scalar), we rewrite
    // this op to a trivial Reshape op (that doesn't actually change the
    // input's shape) and also populate the shape info on the op result. The
    // shape of the tensorlist is inferred from `num_elements` and
    // `element_shape`.
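    // Illustrative example (shapes hypothetical): with num_elements = 5 and a
    // constant element_shape of [2, 3], the op becomes a tf.Reshape to
    // tensor<5x2x3xf32> whose shape operand is tf.Shape(%input).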
    auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
    DenseIntElementsAttr dense_elem_attr;
    if ((ranked_type && ranked_type.getRank() == 0) ||
        !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      // If no constant is spotted, just forward the operand.
      rewriter.replaceOp(op, {input});
      return success();
    }

    RankedTensorType shape_type =
        RankedTensorType::get({-1}, rewriter.getIntegerType(32));
    auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
    SmallVector<int64_t, 8> output_shape(/*Size=*/1, op.num_elements());
    for (const auto &dim : dense_elem_attr.getValues<APInt>())
      output_shape.push_back(dim.getSExtValue());
    RankedTensorType result_type =
        RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
    rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
                                               new_shape);
    return success();
  }
};

// Converts `TensorListConcatV2` into Unpack and Concat. First we unpack
// the input tensorlist along the first dimension, which results in N (where N
// is the first dim's size) tensors (each with shape [element_shape]). Then
// we concatenate all those tensors along the first dimension.
// The pattern will be rejected if either `element_shape` is not constant, or
// the first dimension of `input` is not known.
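// Illustrative example (shapes hypothetical): a tensor<3x2x4xf32> list is
// unpacked into three tensor<2x4xf32> values, which are then concatenated
// along dimension 0 into a tensor<?x4xf32> result.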
struct ConvertTensorListConcatV2
    : public TensorListOpConverterBase<TF::TensorListConcatV2Op> {
  using TensorListOpConverterBase<
      TF::TensorListConcatV2Op>::TensorListOpConverterBase;
  using TensorListOpConverterBase<
      TF::TensorListConcatV2Op>::allow_tensorlist_pass_through_;

  LogicalResult matchAndRewrite(
      TF::TensorListConcatV2Op op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getOperands()[0];
    Value element_shape = adaptor.getOperands()[1];

    // Only match when `element_shape` is a constant.
    DenseIntElementsAttr dense_elem_attr;
    if (!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      const char *error_info = "requires element_shape to be a constant";
      return allow_tensorlist_pass_through_
                 ? rewriter.notifyMatchFailure(op, error_info)
                 : op.emitOpError(error_info);
    }
    llvm::SmallVector<int64_t, 4> output_shape;
    for (const auto &dim : dense_elem_attr.getValues<APInt>()) {
      output_shape.push_back(dim.getSExtValue());
    }

    // First unpack the input tensor along the first dimension.
    Type input_element_type = getElementTypeOrSelf(input);
    int64_t num_unpacked = 0;
    if (auto type = input.getType().dyn_cast<RankedTensorType>()) {
      if (type.getDimSize(0) > 0) {
        num_unpacked = type.getDimSize(0);
      } else {
        const char *error_info =
            "requires the first dimension of input tensor to have > 0 "
            "dimension";
        return allow_tensorlist_pass_through_
                   ? rewriter.notifyMatchFailure(op, error_info)
                   : op.emitOpError(error_info);
      }
    }
    llvm::SmallVector<Type, 1> unpack_output_type;
    unpack_output_type.insert(
        unpack_output_type.begin(), num_unpacked,
        RankedTensorType::get(output_shape, input_element_type));
    auto unpack = rewriter.create<TF::UnpackOp>(loc, unpack_output_type, input,
                                                /*axis=*/0);

    // Concatenate the unpacked tensors along the first dimension.
    // Since we're concatenating along the first dimension, change its dim
    // size to -1.
    output_shape[0] = -1;
    Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
    auto concat = rewriter.create<TF::ConcatOp>(
        loc, RankedTensorType::get(output_shape, input_element_type),
        scalar_zero, unpack->getResults());
    // `lengths` is only useful for computing gradients. For now we just
    // return a placeholder tensor.
    rewriter.replaceOp(
        op, {concat.getResult(), CreateI64SplatConst(loc, &rewriter, {0}, 0)});
    return success();
  }
};

struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::IdentityOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getOperands()[0];
    rewriter.replaceOpWithNewOp<TF::IdentityOp>(
        op, input.getType(), adaptor.getOperands(), op->getAttrs());
    return success();
  }
};

struct ConvertReturn : public OpConversionPattern<func::ReturnOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      func::ReturnOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.updateRootInPlace(op,
                               [&] { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

struct ConvertYield : public OpConversionPattern<TF::YieldOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::YieldOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.updateRootInPlace(op,
                               [&] { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

// Returns an unranked tensor type with an element of the same type as `value`
// if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
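// For example (illustrative): tensor<!tf_type.variant<tensor<2xf32>>> maps to
// tensor<*xf32>, while a non-variant type such as tensor<4xi32> is returned
// unchanged.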
Type VariantToUnrankedTensorType(Type type, Value value) {
  TF::VariantType variant_ty =
      getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
  if (!variant_ty) {
    return type;
  }
  if (!variant_ty.getSubtypes().empty()) {
    // Short-circuit if the variant type has subtype info.
    return UnrankedTensorType::get(
        variant_ty.getSubtypes()[0].getElementType());
  }
  Type value_type = value.getType();
  Type element_type;
  variant_ty = value_type.dyn_cast<TF::VariantType>();
  if (variant_ty && !variant_ty.getSubtypes().empty()) {
    element_type = variant_ty.getSubtypes()[0].getElementType();
  } else {
    element_type = getElementTypeOrSelf(value_type);
  }
  return UnrankedTensorType::get(element_type);
}

// Returns true if we can deduce that the type is a tensorlist.
bool IsTensorListType(Type type, llvm::Optional<Value> value) {
  TF::VariantType variant_ty =
      getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
  if (!variant_ty) {
    return false;
  }
  // Check whether there is only one subtype contained in the variant type.
  // Note that `subtypes.size() == 1` does not always mean the type is
  // actually a tensorlist. We probably need some form of data flow analysis.
  if (variant_ty.getSubtypes().size() == 1) {
    return true;
  }
  // If subtype info is not available, check if the value is used by any of
  // the following TensorList operations.
  if (!value.has_value()) {
    return false;
  }
  for (const mlir::OpOperand &use : value.getValue().getUses()) {
    mlir::Operation *op = use.getOwner();
    if (llvm::isa<TF::TensorListGetItemOp>(op) ||
        llvm::isa<TF::TensorListLengthOp>(op) ||
        llvm::isa<TF::TensorListPushBackOp>(op) ||
        llvm::isa<TF::TensorListReserveOp>(op) ||
        llvm::isa<TF::TensorListSetItemOp>(op) ||
        llvm::isa<TF::TensorListStackOp>(op) ||
        llvm::isa<TF::TensorListResizeOp>(op)) {
      return true;
    }
  }
  return false;
}

// Returns a set of integers that correspond to the tensorlist arguments in
// the function.
llvm::SmallSet<int, 4> GetTensorListArgumentsIndex(func::FuncOp func) {
  llvm::SmallSet<int, 4> set;
  for (const auto &arg_and_idx : llvm::enumerate(func.getArguments())) {
    if (IsTensorListType(arg_and_idx.value().getType(), arg_and_idx.value())) {
      set.insert(arg_and_idx.index());
    }
  }
  return set;
}

// Returns a set of integers that correspond to the tensorlist results in the
// function.
llvm::SmallSet<int, 4> GetTensorListResultsIndex(func::FuncOp func) {
  llvm::SmallSet<int, 4> set;

  for (const auto &result_and_idx :
       llvm::enumerate(func.getFunctionType().getResults())) {
    if (IsTensorListType(result_and_idx.value(), llvm::None)) {
      set.insert(result_and_idx.index());
    }
  }
  return set;
}

// Updates the tensorlist types based on the input index. If the tensorlist's
// size isn't changed (as indicated by `resized_tensor_list_index`), then we
// will use the original operand's type; otherwise we update it with the
// unranked tensor type.
template <typename R>
void UpdateTensorListTypes(
    const llvm::SmallSet<int, 4> &tensor_list_index,
    const llvm::SmallSet<int, 4> &resized_tensor_list_index,
    ArrayRef<Type> types, R &&range, ValueRange operands,
    llvm::SmallVectorImpl<Type> *updated_types) {
  int i = 0;
  for (const auto it : llvm::zip(types, range, operands)) {
    if (tensor_list_index.count(i)) {
      // Only change the tensorlist's type to unranked tensor if it has been
      // resized.
      if (resized_tensor_list_index.count(i)) {
        updated_types->push_back(
            VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
      } else {
        updated_types->push_back(std::get<2>(it).getType());
      }
    } else {
      updated_types->push_back(std::get<0>(it));
    }
    ++i;
  }
}

// Updates the tensorlist types to unranked tensor types based on the input
// index.
template <typename R>
void ChangeVariantToUnrankedTensorType(
    const llvm::SmallSet<int, 4> &tensor_list_index, ArrayRef<Type> types,
    R &&range, llvm::SmallVectorImpl<Type> *updated_types) {
  int i = 0;
  for (const auto it : llvm::zip(types, range)) {
    if (tensor_list_index.count(i)) {
      updated_types->push_back(
          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
    } else {
      updated_types->push_back(std::get<0>(it));
    }
    ++i;
  }
}

// Updates the specified function's type and region signature.
void UpdateFunctionAndRegionType(ConversionPatternRewriter &rewriter,
                                 func::FuncOp func,
                                 llvm::ArrayRef<Type> updated_argument_types,
                                 llvm::ArrayRef<Type> updated_result_types) {
  // Change `func`'s argument type to `unranked_argument_types`. If its
  // return types contain a `DT_VARIANT`, change it to the unranked type
  // derived from the corresponding argument.
  rewriter.updateRootInPlace(func, [&] {
    func.setType(FunctionType::get(func.getContext(), updated_argument_types,
                                   updated_result_types));
  });
  Region &entry = func.getRegion();
  TypeConverter::SignatureConversion signature_conversion(
      entry.getNumArguments());
  for (const BlockArgument &arg : entry.getArguments()) {
    signature_conversion.addInputs(arg.getArgNumber(),
                                   updated_argument_types[arg.getArgNumber()]);
  }
  rewriter.applySignatureConversion(&entry, signature_conversion);
}

// Changes the function types of `cond_func` and `body_func` for the given
// While op.
LogicalResult UpdateFunctionTypesForWhileOp(
    ConversionPatternRewriter &rewriter, TF::WhileOp op, ValueRange operands,
    const llvm::SmallSet<int, 4> &tensor_list_args,
    const llvm::SmallSet<int, 4> &resized_tensor_lists) {
  int func_index = 0;
  for (func::FuncOp func : {op.cond_function(), op.body_function()}) {
    ++func_index;
    if (!func) continue;

    FunctionType func_type = func.getFunctionType();
    int num_inputs = func_type.getNumInputs();
    int num_results = func_type.getNumResults();

    // For each argument type in the function's arguments, change it to an
    // unranked tensor type if it's a variant type.
    SmallVector<Type, 8> updated_argument_types;
    updated_argument_types.reserve(num_inputs);
    UpdateTensorListTypes<mlir::OperandRange>(
        tensor_list_args, resized_tensor_lists, func_type.getInputs(),
        op.getOperands(), operands, &updated_argument_types);

    // Change all DT_VARIANT result types in the function results to unranked
    // tensor types with the element type derived from the corresponding input
    // operand. This is correct because a while body's inputs and results have
    // the same types.
    SmallVector<Type, 8> updated_result_types;
    updated_result_types.reserve(num_results);
    if (func_index == 1) {
      // The cond function's result types are left unchanged; only the body
      // function's result types are updated.
      for (Type ty : func_type.getResults()) {
        updated_result_types.push_back(ty);
      }
    } else {
      UpdateTensorListTypes<mlir::OperandRange>(
          tensor_list_args, resized_tensor_lists, func_type.getResults(),
          op.getOperands(), operands, &updated_result_types);
    }

    UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
                                updated_result_types);
  }
  return success();
}
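
// Hedged sketch of the rewrite: for a while loop carrying a tensorlist, e.g.
//   %0 = "tf.While"(%list) {cond = @cond, body = @body, ...}
// @cond keeps its result types but gets its variant argument updated, while
// @body gets both its variant argument and the matching variant result
// updated, relying on the body's input and result types matching.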

// Changes the function types of `then_function` and `else_function` for the
// given If op.
LogicalResult UpdateFunctionTypesForIfOp(
    ConversionPatternRewriter &rewriter, TF::IfOp op, ValueRange operands,
    const llvm::SmallSet<int, 4> &tensor_list_args,
    const llvm::SmallSet<int, 4> &resized_tensor_lists,
    llvm::ArrayRef<Type> updated_result_types) {
  for (func::FuncOp func : {op.else_function(), op.then_function()}) {
    if (!func) continue;

    FunctionType func_type = func.getFunctionType();
    int num_inputs = func_type.getNumInputs();

    // Update the argument types of the function. If an argument is a
    // tensorlist that is not resized inside the function, use the
    // corresponding operand's type; otherwise change its type to an unranked
    // tensor type.
    SmallVector<Type, 8> updated_argument_types;
    updated_argument_types.reserve(num_inputs);
    UpdateTensorListTypes<mlir::OperandRange>(
        tensor_list_args, resized_tensor_lists, func_type.getInputs(),
        op.getOperands().drop_front(), operands.drop_front(),
        &updated_argument_types);

    UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
                                updated_result_types);
  }
  return success();
}
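
// Hedged note: only the branch argument types are recomputed here (the
// drop_front calls skip the If op's condition operand, which has no
// counterpart in the branch signatures); the result types are supplied by
// the caller via `updated_result_types`.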

// Returns an `llvm::DenseMap` which maps from the index of a tensorlist in
// the results to the index of the same tensorlist in the arguments. For an
// `If` op's branch functions, the results and arguments are usually not
// matched 1:1. This map lets us know which tensorlist result corresponds to
// which tensorlist argument, which in turn helps us decide the types of the
// result tensorlists based on the operands of the `If` op.
llvm::DenseMap<int, int> MapTensorListResultToArgument(func::FuncOp func) {
  // `map_fn` traces upwards along the use-def chain of the SSA value. It
  // starts from the last SSA value (returned by the function) and checks its
  // defining op iteratively. If the root SSA value appears in the function's
  // argument list, it returns the index of the corresponding argument;
  // otherwise it returns -1.
  auto map_fn = [](Value value) -> int {
    Value parent = value;
    while (true) {
      if (auto identity = parent.getDefiningOp<TF::IdentityOp>()) {
        parent = identity.input();
      } else if (auto set_item =
                     parent.getDefiningOp<TF::TensorListSetItemOp>()) {
        parent = set_item.input_handle();
      } else {
        break;
      }
    }
    if (auto block_arg = parent.dyn_cast<mlir::BlockArgument>()) {
      return block_arg.getArgNumber();
    }
    // Returns -1 if we can't determine which argument this result maps to.
    return -1;
  };

  llvm::SmallVector<Value, 4> returns;
  for (auto res : func.getBody().back().getTerminator()->getOperands()) {
    returns.push_back(res);
  }
  llvm::DenseMap<int, int> result;
  for (const auto &result_and_idx : llvm::enumerate(returns)) {
    if (IsTensorListType(result_and_idx.value().getType(),
                         result_and_idx.value())) {
      int arg_idx = map_fn(result_and_idx.value());
      if (arg_idx != -1) {
        result.insert({result_and_idx.index(), arg_idx});
      }
    }
  }
  return result;
}
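
// Hedged walk-through (illustrative branch function, not real model code):
//   func.func @then(%arg0: tensor<i32>, %arg1: <tensorlist>) -> <tensorlist> {
//     %0 = "tf.TensorListSetItem"(%arg1, %idx, %item) : ...
//     %1 = "tf.Identity"(%0) : ...
//     return %1 : <tensorlist>
//   }
// `map_fn` traces %1 -> %0 -> %arg1, so the returned map is {0 -> 1}: result
// 0 is the same tensorlist as argument 1.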

// Updates the tensorlist result types for the `If` op. If a tensorlist result
// maps to a specific argument (indicated by `tensor_list_map`) and that
// tensorlist argument is not resized (i.e., not in
// `resized_tensor_list_index`), we update the tensorlist's result type to the
// corresponding operand's type. In all other cases we change the tensorlist's
// type to an unranked tensor type.
template <typename R>
void UpdateTensorListResultTypesForIf(
    const llvm::SmallSet<int, 4> &tensor_list_index,
    const llvm::SmallSet<int, 4> &resized_tensor_list_index,
    const llvm::DenseMap<int, int> &tensor_list_map, ArrayRef<Type> types,
    R &&range, ValueRange operands,
    llvm::SmallVectorImpl<Type> *updated_types) {
  int i = 0;
  for (const auto it : llvm::zip(types, range)) {
    if (!tensor_list_index.count(i)) {
      updated_types->push_back(std::get<0>(it));
      ++i;
      continue;
    }
    auto iter = tensor_list_map.find(i);
    if (iter != tensor_list_map.end()) {
      int arg_idx = iter->second;
      if (!resized_tensor_list_index.count(arg_idx)) {
        // The mapped tensorlist argument's size isn't changed, so use the
        // corresponding `operand` type.
        updated_types->push_back(operands[arg_idx].getType());
        ++i;
        continue;
      }
    }
    updated_types->push_back(
        VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
    ++i;
  }
}
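
// Hedged example: with tensor_list_index = {0}, tensor_list_map = {0 -> 1}
// and resized_tensor_list_index = {}, result 0 takes operands[1]'s type. If
// argument 1 had been resized, or no mapping had been found, result 0 would
// instead become an unranked tensor derived from its variant element type.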

struct ConvertIf : public OpConversionPattern<TF::IfOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::IfOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Find all tensorlist arguments.
    auto tensor_list_args = GetTensorListArgumentsIndex(op.else_function());
    auto tensor_list_results = GetTensorListResultsIndex(op.else_function());
    auto tensor_list_map = MapTensorListResultToArgument(op.else_function());
    llvm::SmallSet<int, 4> resized_tensor_lists =
        GetResizedTensorListIndexes(op.else_function(), tensor_list_args);

    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumResults());
    llvm::SmallVector<Type, 4> op_result_types;
    for (Type ty : op.getResultTypes()) {
      op_result_types.push_back(ty);
    }

    UpdateTensorListResultTypesForIf<mlir::ResultRange>(
        tensor_list_results, resized_tensor_lists, tensor_list_map,
        op_result_types, op->getResults(), adaptor.getOperands().drop_front(),
        &result_types);

    // Create a new If op with new operands and updated result types.
    auto converted = rewriter.create<TF::IfOp>(
        op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());
    converted->removeAttr("T");
    (void)UpdateFunctionTypesForIfOp(rewriter, converted,
                                     adaptor.getOperands(), tensor_list_args,
                                     resized_tensor_lists, result_types);
    rewriter.replaceOp(op, converted.getResults());
    return success();
  }
};
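
// Hedged before/after sketch of ConvertIf (types illustrative):
//   %0 = "tf.If"(%cond, %list) : (tensor<i1>, tensor<!tf_type.variant<...>>)
//        -> tensor<!tf_type.variant<...>>
// becomes, when the branches never resize the tensorlist,
//   %0 = "tf.If"(%cond, %t) : (tensor<i1>, tensor<2x3xf32>) -> tensor<2x3xf32>
// and otherwise the tensorlist result becomes the unranked tensor<*xf32>.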

struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::WhileOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Find all tensorlist arguments.
    auto tensor_list_args = GetTensorListArgumentsIndex(op.body_function());

    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumOperands());
    // Change all DT_VARIANT result types to unranked tensor types.
    llvm::SmallVector<Type, 4> op_result_types;
    for (Type ty : op.getResultTypes()) {
      op_result_types.push_back(ty);
    }

    llvm::SmallSet<int, 4> resized_tensor_lists =
        GetResizedTensorListIndexes(op.body_function(), tensor_list_args);
    UpdateTensorListTypes<mlir::OperandRange>(
        tensor_list_args, resized_tensor_lists, op_result_types,
        op.getOperands(), adaptor.getOperands(), &result_types);

    // Create a new While op with new operands and updated result types.
    auto converted = rewriter.create<TF::WhileOp>(
        op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());
    converted->removeAttr("T");
    (void)UpdateFunctionTypesForWhileOp(rewriter, converted,
                                        adaptor.getOperands(),
                                        tensor_list_args, resized_tensor_lists);

    rewriter.replaceOp(op, converted.getResults());
    return success();
  }
};
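
// Hedged before/after sketch of ConvertWhile (types illustrative):
//   %0 = "tf.While"(%list) : (tensor<!tf_type.variant<tensor<f32>>>)
//        -> tensor<!tf_type.variant<tensor<f32>>>
// becomes either (tensor<5xf32>) -> tensor<5xf32> when the body never
// resizes the list, or (tensor<*xf32>) -> tensor<*xf32> when it does.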

struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TF::WhileRegionOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Type, 8> result_types;
    result_types.reserve(op.getNumOperands());
    // Change all DT_VARIANT result types to unranked tensor types.
    for (auto it : llvm::zip(op.getResultTypes(), adaptor.getOperands()))
      result_types.push_back(
          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));

    // Create a new while op with new operands and updated result types.
    auto converted = rewriter.create<TF::WhileRegionOp>(
        op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());

    // Inline the regions from the old while op into the new one, and apply
    // signature conversion to the inlined regions.
    for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
      Region &old_region = *std::get<0>(it);
      Region &new_region = *std::get<1>(it);

      Block &entry = old_region.front();
      // Build the signature conversion for the region.
      TypeConverter::SignatureConversion signature_conversion(
          adaptor.getOperands().size());
      for (auto arg_and_operand :
           llvm::zip(entry.getArguments(), adaptor.getOperands())) {
        BlockArgument arg = std::get<0>(arg_and_operand);
        signature_conversion.addInputs(
            arg.getArgNumber(),
            VariantToUnrankedTensorType(arg.getType(),
                                        std::get<1>(arg_and_operand)));
      }

      rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
      rewriter.applySignatureConversion(&new_region, signature_conversion);
    }

    rewriter.replaceOp(op, converted.getResults());
    return success();
  }
};
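
// Hedged note on the pattern above: unlike ConvertWhile, this region-based
// variant does not consult resize information; every variant type in the
// results and region signatures is unconditionally rewritten to an unranked
// tensor, e.g. tensor<!tf_type.variant<tensor<f32>>> -> tensor<*xf32>.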

#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"

void LowerStaticTensorListPass::runOnOperation() {
  auto *context = &getContext();

  // TensorFlow operations that don't have operands or results of type
  // variant are legal. Here, we don't distinguish between variants encoding
  // a TensorList and some other type, as that information is not available
  // here. Partial legalization is used below so that ops with variant types
  // can still be allowed.
  auto is_legal = [](Operation *op) {
    auto is_not_variant = [](Type ty) {
      return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
    };
    return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
           llvm::all_of(op->getResultTypes(), is_not_variant);
  };
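
  // Hedged examples of this predicate (ops illustrative): a "tf.Add" whose
  // operands and results are all tensor<4xf32> is legal, while a
  // "tf.TensorListGetItem" taking a tensor<!tf_type.variant<...>> operand is
  // not, so it must be rewritten by one of the patterns registered below.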

  ConversionTarget target(*context);
  target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(is_legal);
  target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
                      TF::TensorListGetItemOp, TF::TensorListLengthOp,
                      TF::TensorListPushBackOp, TF::TensorListReserveOp,
                      TF::TensorListSetItemOp, TF::TensorListStackOp,
                      TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
  // TODO(hinsu): Use TFLite constant op for constants.
  target.addLegalOp<arith::ConstantOp>();
  target.addLegalOp<func::FuncOp>();
  target.addDynamicallyLegalOp<func::ReturnOp>(is_legal);
  target.addDynamicallyLegalOp<TF::YieldOp>(is_legal);
  target.addLegalOp<TFL::CustomOp>();
  // Register fused LSTM/RNN ops as legal.
  target.addLegalOp<TFL::LSTMOp>();
  target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
  target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
  target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();

  RewritePatternSet patterns(&getContext());
  populateWithGenerated(patterns);
  patterns.add<ConvertConst, ConvertIdentity, ConvertTensorListGetItem,
               ConvertTensorListLength, ConvertTensorListPushBack,
               ConvertTensorListStack, ConvertTensorListResize, ConvertWhile,
               ConvertWhileRegion, ConvertIf, ConvertReturn, ConvertYield>(
      context);
  patterns.add<ConvertTensorListSetItem>(context,
                                         this->enable_dynamic_update_slice_);
  patterns.add<ConvertEmptyTensorList, ConvertTensorListConcatV2,
               ConvertTensorListReserve>(context,
                                         this->allow_tensorlist_pass_through_,
                                         this->default_to_single_batch_);
  ModuleOp module = getOperation();
  if (!this->allow_tensorlist_pass_through_) {
    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
      module.emitError(
          "Lowering tensor list ops failed. Please consider using Select TF "
          "ops and disabling `_experimental_lower_tensor_list_ops` flag in "
          "the TFLite converter object. For example, "
          "converter.target_spec.supported_ops = "
          "[tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]\\n "
          "converter._experimental_lower_tensor_list_ops = False");
      signalPassFailure();
    }
  } else {
    // When `allow_tensorlist_pass_through` is set to true and legalization
    // fails, we should not leak the diagnostic info outside this pass. Hence
    // we use a `StatusScopedDiagnosticHandler` here to capture diagnostics
    // generated within this pass.
    StatusScopedDiagnosticHandler handler(context);
    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
      auto _ = handler.ConsumeStatus();
    }
  }
}

}  // namespace

/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass(
    bool allow_tensorlist_pass_through, bool default_to_single_batch,
    bool enable_dynamic_update_slice) {
  return std::make_unique<LowerStaticTensorListPass>(
      allow_tensorlist_pass_through, default_to_single_batch,
      enable_dynamic_update_slice);
}

std::unique_ptr<OperationPass<ModuleOp>>
TFL::CreateLowerStaticTensorListPass() {
  return std::make_unique<LowerStaticTensorListPass>();
}
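
// Hedged usage sketch (the pass-manager setup is assumed, not part of this
// file):
//   mlir::PassManager pm(&context);
//   pm.addPass(TFL::CreateLowerStaticTensorListPass(
//       /*allow_tensorlist_pass_through=*/false,
//       /*default_to_single_batch=*/true,
//       /*enable_dynamic_update_slice=*/false));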

}  // namespace mlir