xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass prepares for legalization to the TFLite dialect by
17 // converting Tensorlist operations in TensorFlow dialect into operations that
18 // can be legalized to TensorFlow Lite dialect with simple replacements.  The
19 // newly created operations are in the TensorFlow dialect if the operation can
20 // be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op
21 // is used.
22 
23 #include <climits>
24 #include <cstdint>
25 #include <utility>
26 
27 #include "absl/container/inlined_vector.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/None.h"
31 #include "llvm/ADT/Optional.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallSet.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringSwitch.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Debug.h"
38 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"  // from @llvm-project
39 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
40 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
41 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
42 #include "mlir/IR/Attributes.h"  // from @llvm-project
43 #include "mlir/IR/Block.h"  // from @llvm-project
44 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
45 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
46 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
47 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
48 #include "mlir/IR/Matchers.h"  // from @llvm-project
49 #include "mlir/IR/Operation.h"  // from @llvm-project
50 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
51 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
52 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
53 #include "mlir/IR/TypeRange.h"  // from @llvm-project
54 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
55 #include "mlir/IR/Types.h"  // from @llvm-project
56 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
57 #include "mlir/IR/Value.h"  // from @llvm-project
58 #include "mlir/Pass/Pass.h"  // from @llvm-project
59 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
60 #include "mlir/Support/LLVM.h"  // from @llvm-project
61 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
62 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
63 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
64 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
65 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
66 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
69 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
70 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
71 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
72 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
73 #include "tensorflow/core/framework/tensor.h"
74 #include "tensorflow/core/framework/types.pb.h"
75 #include "tensorflow/core/kernels/tensor_list.h"
76 
77 #define DEBUG_TYPE "tf-tfl-legalization"
78 
79 //===----------------------------------------------------------------------===//
80 // The actual LowerStaticTensorList Pass.
81 //
82 namespace mlir {
83 namespace {
84 #define GEN_PASS_CLASSES
85 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
86 
87 /// Lower TensorList ops in functions for subsequent legalization.
88 struct LowerStaticTensorListPass
89     : public LowerStaticTensorListPassBase<LowerStaticTensorListPass> {
90   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerStaticTensorListPass)
91 
92   LowerStaticTensorListPass() = default;
LowerStaticTensorListPassmlir::__anonaaa99d6c0111::LowerStaticTensorListPass93   LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
LowerStaticTensorListPassmlir::__anonaaa99d6c0111::LowerStaticTensorListPass94   explicit LowerStaticTensorListPass(bool allow_tensorlist_pass_through,
95                                      bool default_to_single_batch,
96                                      bool enable_dynamic_update_slice) {
97     this->allow_tensorlist_pass_through_ = allow_tensorlist_pass_through;
98     this->default_to_single_batch_ = default_to_single_batch;
99     this->enable_dynamic_update_slice_ = enable_dynamic_update_slice;
100   }
101 
102   void runOnOperation() override;
103 };
104 
CreateI32SplatConst(Location loc,PatternRewriter * rewriter,ArrayRef<int64_t> shape,int32_t val)105 Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
106                           ArrayRef<int64_t> shape, int32_t val) {
107   RankedTensorType type =
108       RankedTensorType::get(shape, rewriter->getIntegerType(32));
109   DenseElementsAttr attr =
110       DenseElementsAttr::get(type, rewriter->getI32IntegerAttr(val));
111   return rewriter->create<arith::ConstantOp>(loc, type, attr);
112 }
113 
CreateI64SplatConst(Location loc,PatternRewriter * rewriter,ArrayRef<int64_t> shape,int64_t val)114 Value CreateI64SplatConst(Location loc, PatternRewriter *rewriter,
115                           ArrayRef<int64_t> shape, int64_t val) {
116   RankedTensorType type =
117       RankedTensorType::get(shape, rewriter->getIntegerType(64));
118   DenseElementsAttr attr =
119       DenseElementsAttr::get(type, rewriter->getI64IntegerAttr(val));
120   return rewriter->create<arith::ConstantOp>(loc, type, attr);
121 }
122 
CreateI32SplatTensor(Location loc,PatternRewriter * rewriter,Value shape_tensor,int32_t val)123 Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
124                            Value shape_tensor, int32_t val) {
125   Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
126   return rewriter->create<TF::FillOp>(
127       loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
128       shape_tensor, scalar_val);
129 }
130 
131 // Returns a new type by prepending the specified dimension to the shape of
132 // the given type if it is a ranked type.
PrependLeadingDimIfRanked(int64_t dim,Type type,PatternRewriter * rewriter)133 Type PrependLeadingDimIfRanked(int64_t dim, Type type,
134                                PatternRewriter *rewriter) {
135   Type dtype = getElementTypeOrSelf(type);
136   if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
137     llvm::SmallVector<int64_t, 4> shape = {dim};
138     shape.append(ty.getShape().begin(), ty.getShape().end());
139     return RankedTensorType::get(shape, dtype);
140   }
141   return type;
142 }
143 
GetTensorTypeForTensorList(Type element_type,TF::VariantType handle_dtype,PatternRewriter * rewriter)144 Type GetTensorTypeForTensorList(Type element_type, TF::VariantType handle_dtype,
145                                 PatternRewriter *rewriter) {
146   // If the variant type in the output handle has item shape available, use it
147   // to derive the output shape by setting unknown leading dimension.
148   // Otherwise, result type will be of unranked type.
149   if (handle_dtype.getSubtypes().empty()) {
150     return UnrankedTensorType::get(element_type);
151   }
152   return PrependLeadingDimIfRanked(-1, handle_dtype.getSubtypes()[0], rewriter);
153 }
154 
155 // Gets the index of tensorlist arguments which size might get changed by the
156 // function.
GetResizedTensorListIndexes(func::FuncOp func,const llvm::SmallSet<int,4> & tensor_list_args)157 llvm::SmallSet<int, 4> GetResizedTensorListIndexes(
158     func::FuncOp func, const llvm::SmallSet<int, 4> &tensor_list_args) {
159   // `indexes` stores the argument index of tensorlists which size may get
160   // updated in the function.
161   llvm::SmallSet<int, 4> indexes;
162   for (BlockArgument &arg : func.getArguments()) {
163     if (tensor_list_args.contains(arg.getArgNumber())) {
164       for (const mlir::OpOperand &use : arg.getUses()) {
165         mlir::Operation *op = use.getOwner();
166         // Currently we only check if the tensorlist argument is consumed by
167         // `TensorListPushBack` or `TensorListResize`, since those are the only
168         // length-mutating ops supported in this pass.
169         if (llvm::isa<TF::TensorListPushBackOp>(op) ||
170             llvm::isa<TF::TensorListResizeOp>(op)) {
171           indexes.insert(arg.getArgNumber());
172         }
173       }
174     }
175   }
176   return indexes;
177 }
178 
179 // Creates a slice of the tensorlist `input_list`, starting from
180 // [start_index, 0, ...0], with size [size, -1, ...-1].
181 //
182 // Requires that `start_index` and `size` are scalar tensors and
183 // `item_position_shape` is a 1-D tensor with only one element equal to the rank
184 // of an item in the tensorlist.
CreateSliceOpForTensorList(Location loc,Value input_list,Value start_index,Value size,Value item_rank,Type result_type,PatternRewriter * rewriter)185 TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
186                                        Value start_index, Value size,
187                                        Value item_rank, Type result_type,
188                                        PatternRewriter *rewriter) {
189   // Create the start position of slice. This is done by concatenating
190   // `start_index` and `partial_start_position` together.
191   IntegerType shape_dtype = rewriter->getIntegerType(32);
192   RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
193   Value partial_start_position =
194       CreateI32SplatTensor(loc, rewriter, item_rank, 0);
195   Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
196   RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
197   auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
198       loc, vector_type, start_index, scalar_zero);
199   auto start_position = rewriter->create<TF::ConcatOp>(
200       loc, position_type, scalar_zero,
201       ArrayRef<Value>({expanded_start_index, partial_start_position}));
202 
203   // Create the slice size tensor. This is done by concatenating `size` and
204   // `partial_size`.
205   auto size_leading_dim =
206       rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
207   Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
208   auto slice_size = rewriter->create<TF::ConcatOp>(
209       loc, position_type, scalar_zero,
210       ArrayRef<Value>({size_leading_dim, partial_size}));
211 
212   return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
213                                        start_position, slice_size);
214 }
215 
216 template <typename OpT>
217 class TensorListOpConverterBase : public OpConversionPattern<OpT> {
218  public:
TensorListOpConverterBase(MLIRContext * context,bool allow_tensorlist_pass_through,bool default_to_single_batch)219   explicit TensorListOpConverterBase<OpT>(MLIRContext *context,
220                                           bool allow_tensorlist_pass_through,
221                                           bool default_to_single_batch)
222       : OpConversionPattern<OpT>::OpConversionPattern(context),
223         allow_tensorlist_pass_through_(allow_tensorlist_pass_through),
224         default_to_single_batch_(default_to_single_batch) {}
225 
226  protected:
227   // This flag will control the behavior of error emitting during rewrite:
228   // 1) If it's true, then patterns will only emit errors during debug or
229   // tracing mode. 2) If it's false, then patterns will emit standard errors
230   // when there is a rewrite failure.
231   bool allow_tensorlist_pass_through_;
232 
233   // This flag will control the behavior of setting the batch size one when the
234   // given batch size is None in order to force to proceed the tensor list op
235   // lowerings.
236   bool default_to_single_batch_;
237 };
238 
239 // Converts tf.Const containing variant of type TensorList to a tensor of
240 // primitive element types. Each of the individual tensor in the list is
241 // converted to an ElementsAttr and then those are packed together using
242 // tf.Pack op.
243 struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
244   using OpConversionPattern::OpConversionPattern;
245 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertConst246   LogicalResult matchAndRewrite(
247       TF::ConstOp op, OpAdaptor adaptor,
248       ConversionPatternRewriter &rewriter) const override {
249     // Verify that the tensor proto contains tensor of type variant and scalar
250     // shape. The variant type should hold a TensorList.
251     auto proto_attr = op.value().dyn_cast<TF::TensorProtoAttr>();
252     if (!proto_attr) return failure();
253     tensorflow::Tensor tensor;
254     if (!tensorflow::ConvertToTensor(proto_attr, &tensor).ok())
255       return failure();
256     if (tensor.dtype() != tensorflow::DT_VARIANT) return failure();
257     if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
258       return failure();
259 
260     const tensorflow::TensorList *list =
261         tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
262     if (!list) return failure();
263 
264     // Verify output type is variant and contains exactly one ranked subtypes.
265     auto variant_ty =
266         getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
267     if (!variant_ty) return failure();
268     ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
269     if (subtypes.size() != 1) return failure();
270     RankedTensorType list_element_ty =
271         subtypes.front().dyn_cast<RankedTensorType>();
272     if (!list_element_ty) return failure();
273 
274     // Extract tensor elements for the TensorList and construct result type
275     // based on the number of elements and element shape.
276     const std::vector<tensorflow::Tensor> &tensors = list->tensors();
277     llvm::SmallVector<int64_t, 4> result_shape = {
278         static_cast<int64_t>(tensors.size())};
279     result_shape.append(list_element_ty.getShape().begin(),
280                         list_element_ty.getShape().end());
281     auto result_ty =
282         RankedTensorType::get(result_shape, list_element_ty.getElementType());
283 
284     // If the list is empty, directly create the final result instead of
285     // creating the tf.Pack op. tf.Pack op requires at least one operand.
286     if (tensors.empty()) {
287       tensorflow::Tensor tensor(list->element_dtype,
288                                 tensorflow::TensorShape(result_shape));
289       auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
290       if (!attr_or.ok()) return failure();
291       rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
292       return success();
293     }
294 
295     // Extract individual tensor list element and combine them using the tf.Pack
296     // op.
297     Location loc = op.getLoc();
298     llvm::SmallVector<Value, 4> values;
299     values.reserve(tensors.size());
300     for (const tensorflow::Tensor &tensor : tensors) {
301       auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
302       if (!attr_or.ok()) return failure();
303 
304       auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
305       values.push_back(value);
306     }
307     rewriter.replaceOpWithNewOp<TF::PackOp>(
308         op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
309     return success();
310   }
311 };
312 
313 struct ConvertTensorListSetItem
314     : public OpConversionPattern<TF::TensorListSetItemOp> {
ConvertTensorListSetItemmlir::__anonaaa99d6c0111::ConvertTensorListSetItem315   explicit ConvertTensorListSetItem(MLIRContext *context,
316                                     bool enable_dynamic_update_slice = false)
317       : OpConversionPattern<TF::TensorListSetItemOp>(context),
318         enable_dynamic_update_slice(enable_dynamic_update_slice) {}
319 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertTensorListSetItem320   LogicalResult matchAndRewrite(
321       TF::TensorListSetItemOp op, OpAdaptor adaptor,
322       ConversionPatternRewriter &rewriter) const override {
323     if (enable_dynamic_update_slice) {
324       return matchAndRewriteImplWithDynamicUpdateSlice(op, adaptor, rewriter);
325     } else {
326       return matchAndRewriteImplWithSliceAndConcat(op, adaptor, rewriter);
327     }
328   }
329 
330   // This function rewrites the original op into a series of slice and concat op
331   // to produce the same result. It first slices the first `$index` rows. Then
332   // expands the dimension of the `$item`, followed by another slice of the
333   // remaining rows starting from `$index` + 1. Lastly it concatenates the
334   // three parts together.
335   // On a high level, it's doing something like:
336   // def : Pat<(TF_TensorListSetItemOp $input, $index, $item),
337   //      (Concat
338   //        concat_dim = 0,
339   //        (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
340   //        0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
341   //        $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
matchAndRewriteImplWithSliceAndConcatmlir::__anonaaa99d6c0111::ConvertTensorListSetItem342   LogicalResult matchAndRewriteImplWithSliceAndConcat(
343       TF::TensorListSetItemOp op, OpAdaptor adaptor,
344       ConversionPatternRewriter &rewriter) const {
345     Location loc = op.getLoc();
346     Value input = adaptor.getOperands()[0];
347     Value index = adaptor.getOperands()[1];
348     Value item = adaptor.getOperands()[2];
349 
350     IntegerType shape_dtype = rewriter.getIntegerType(32);
351     auto item_rank = rewriter.create<TF::RankOp>(
352         loc, RankedTensorType::get({}, shape_dtype), item);
353     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
354 
355     // Calculate `index` + 1, which is used to generate the start position for
356     // the second slice op.
357     auto suffix_start =
358         rewriter.create<TF::AddOp>(loc, index.getType(), index,
359                                    CreateI32SplatConst(loc, &rewriter, {}, 1));
360 
361     auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
362         loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
363     // Create two slice ops.
364     Type element_type = input.getType().cast<TensorType>().getElementType();
365     UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
366     Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
367     TF::SliceOp slice1 =
368         CreateSliceOpForTensorList(loc, /*input_list=*/input,
369                                    /*start_index=*/scalar_zero,
370                                    /*size=*/index,
371                                    /*item_rank=*/item_position_shape,
372                                    /*result_type=*/unranked_tensor, &rewriter);
373     TF::SliceOp slice2 =
374         CreateSliceOpForTensorList(loc, /*input_list=*/input,
375                                    /*start_index=*/suffix_start,
376                                    /*size=*/scalar_minus_one,
377                                    /*item_rank=*/item_position_shape,
378                                    /*result_type=*/unranked_tensor, &rewriter);
379 
380     // Expand the dimension of item so that it will have the same rank with
381     // input.
382     auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
383         op.getLoc(), unranked_tensor, item, scalar_zero);
384 
385     // Concatenate three parts together to generate the final result.
386     rewriter.replaceOpWithNewOp<TF::ConcatOp>(
387         op, input.getType(), scalar_zero,
388         ArrayRef<Value>({slice1, expanded_item, slice2}));
389     return success();
390   }
391 
392   // This function rewrites the original op into a XLA DynamicUpdateSlice op.
393   // |item| is expanded to have the same dimension as input_handle and
394   // |index| is expanded to [index, 0, 0, ...] as the indices to input_handle.
395   // On a high level, it's doing something like:
396   // def : Pat<(TensorListSetItem($input_handle, $index, $item)),
397   //           (XlaDynamicUpdateSlice($input_handle, ExpandDims($item, 0),
398   //              Concat(ExpandDims($index, 0), [0, 0, 0, ...])))>
matchAndRewriteImplWithDynamicUpdateSlicemlir::__anonaaa99d6c0111::ConvertTensorListSetItem399   LogicalResult matchAndRewriteImplWithDynamicUpdateSlice(
400       TF::TensorListSetItemOp op, OpAdaptor adaptor,
401       ConversionPatternRewriter &rewriter) const {
402     Location loc = op.getLoc();
403     Value input = adaptor.getOperands()[0];
404     Value index = adaptor.getOperands()[1];
405     Value item = adaptor.getOperands()[2];
406 
407     IntegerType shape_dtype = rewriter.getIntegerType(32);
408     auto item_rank = rewriter.create<TF::RankOp>(
409         loc, RankedTensorType::get({}, shape_dtype), item);
410     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
411 
412     // Concat(ExpandDims(index, 0), [0, 0, 0, ...])
413     RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
414     auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
415         loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
416     Value partial_index =
417         CreateI32SplatTensor(loc, &rewriter, item_position_shape, 0);
418     RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
419     auto expanded_index =
420         rewriter.create<TF::ExpandDimsOp>(loc, vector_type, index, scalar_zero);
421     auto start_position = rewriter.create<TF::ConcatOp>(
422         loc, position_type, scalar_zero,
423         ArrayRef<Value>({expanded_index, partial_index}));
424 
425     // Expand the dimension of item so that it will have the same rank with
426     // input.
427     // ExpandDims(item, 0)
428     Type element_type = input.getType().cast<TensorType>().getElementType();
429     UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
430     auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
431         op.getLoc(), unranked_tensor, item, scalar_zero);
432 
433     // Update the element with XlaDynamicUpdateSliceOp.
434     rewriter.replaceOpWithNewOp<TF::XlaDynamicUpdateSliceOp>(
435         op, input.getType(), input, expanded_item, start_position);
436     return success();
437   }
438 
439   bool enable_dynamic_update_slice;
440 };
441 
442 // Rewrites op of the template type initializing a TensorList with a list of ops
443 // to generate an equivalent raw tensor. Derived classes are required to
444 // override GetNumElements method.
445 template <typename OpT>
446 struct ConvertTensorListInitOp : public TensorListOpConverterBase<OpT> {
447   using TensorListOpConverterBase<OpT>::TensorListOpConverterBase;
448   using TensorListOpConverterBase<OpT>::allow_tensorlist_pass_through_;
449   using TensorListOpConverterBase<OpT>::default_to_single_batch_;
450 
451   // Create and return a 1-d tensor with exactly one element equal to the number
452   // of list elements to initialize the output tensor list with.
453   virtual Value GetNumElements(OpT op, ValueRange operands,
454                                PatternRewriter *rewriter) const = 0;
455 
456   // Rewrites the original op into `tf.fill`. The result tensor shape is
457   // [num_element, element_shape]. All the values in the result tensor will be
458   // initialized to 0.
matchAndRewritemlir::__anonaaa99d6c0111::ConvertTensorListInitOp459   LogicalResult matchAndRewrite(
460       OpT op, typename OpT::Adaptor adaptor,
461       ConversionPatternRewriter &rewriter) const override {
462     Type dtype = op.element_dtype();
463     if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
464           dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
465           dtype.isInteger(32) || dtype.isInteger(64))) {
466       const char *error_info =
467           "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
468           "integer or 16-bit/32-bit/64-bit float type during TF Lite "
469           "transformation pass";
470       return allow_tensorlist_pass_through_
471                  ? rewriter.notifyMatchFailure(op, error_info)
472                  : op.emitOpError(error_info);
473     }
474 
475     Value element_shape = adaptor.getOperands()[0];
476     Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
477     // If the `element_shape` is a scalar, we try to acquire its shape by
478     // looking at the first `TensorListSetItemOp` writing to this tensor list.
479     // Here we assume that the element_shape won't be changed before calling
480     // the first `TensorListSetItemOp`.
481     if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
482       if (shaped_type.hasRank() && shaped_type.getRank() == 0) {
483         bool element_shape_acquired = false;
484         auto uses = op.getResult().getUses();
485         for (auto &use : llvm::make_early_inc_range(uses)) {
486           if (TF::TensorListSetItemOp set_op =
487                   llvm::dyn_cast<TF::TensorListSetItemOp>(use.getOwner())) {
488             element_shape = rewriter.create<TF::ShapeOp>(
489                 op.getLoc(), RankedTensorType::get({-1}, shape_dtype),
490                 set_op.item());
491             element_shape_acquired = true;
492           } else if (TF::WhileOp while_op =
493                          llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
494             // Tensorlist is passed into a while loop, check inside the body
495             // function.
496             auto inside_uses = while_op.body_function()
497                                    .getArgument(use.getOperandNumber())
498                                    .getUses();
499             for (auto &inside_use : llvm::make_early_inc_range(inside_uses)) {
500               if (TF::TensorListSetItemOp set_op =
501                       llvm::dyn_cast<TF::TensorListSetItemOp>(
502                           inside_use.getOwner())) {
503                 if (auto shaped_type =
504                         set_op.item().getType().dyn_cast<ShapedType>()) {
505                   if (shaped_type.hasStaticShape()) {
506                     RankedTensorType type = RankedTensorType::get(
507                         {shaped_type.getRank()}, rewriter.getIntegerType(32));
508                     SmallVector<Attribute, 4> shape_attr;
509                     for (int64_t dim : shaped_type.getShape()) {
510                       shape_attr.push_back(rewriter.getI32IntegerAttr(dim));
511                     }
512                     DenseElementsAttr attr =
513                         DenseElementsAttr::get(type, shape_attr);
514                     element_shape = rewriter.create<arith::ConstantOp>(
515                         op.getLoc(), type, attr);
516                     element_shape_acquired = true;
517                     break;
518                   }
519                 }
520               }
521             }
522           }
523           if (element_shape_acquired) break;
524         }
525         if (!element_shape_acquired) {
526           const char *error_info =
527               "requires element_shape to be 1D tensor during TF Lite "
528               "transformation pass";
529           return allow_tensorlist_pass_through_
530                      ? rewriter.notifyMatchFailure(op, error_info)
531                      : op.emitOpError(error_info);
532         }
533       }
534     }
535 
536     DenseIntElementsAttr dense_elem_attr;
537     if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
538       // Note: It's technically unsafe to rewrite
539       //     TensorListReserve(num_element, element_shape)
540       // to
541       //     Fill(Concat(num_element, element_shape), 0)
542       // because element_shape may contain -1 to represent unknown dimension.
543       //
544       // In real world use cases (e.g. Keras RNN), `element_shape` is usually
545       // a constant, and the first dimension of `element_shape` is usually
546       // batch dimension. Currently TFLiteConverter always rewrite unknown
547       // batch dimension to 1, therefore we also rewrite unknown dimension in
548       // `element_shape` to 1 here.
549       //
550       // This workaround enables converting Keras RNN without specifying batch
551       // dimension. This isn't guaranteed to work, but it doesn't break any
552       // non-broken cases either (since it's already broken if `element_shape`
553       // contains -1).
554       // TODO(b/142096690): Support dynamic element shape and remove the
555       // workaround.
556       SmallVector<int32_t, 4> new_element_shape_values;
557 
558       auto int_values = dense_elem_attr.getValues<APInt>();
559       for (auto it = int_values.begin(); it != int_values.end(); ++it) {
560         auto dim_value = (*it).getSExtValue();
561         if (it == int_values.begin() && dim_value == -1) {
562           if (!default_to_single_batch_) {
563             const char *error_info =
564                 "requires element_shape to be static during TF Lite "
565                 "transformation pass";
566             return allow_tensorlist_pass_through_
567                        ? rewriter.notifyMatchFailure(op, error_info)
568                        : op.emitOpError(error_info);
569           }
570           dim_value = 1;
571         }
572         new_element_shape_values.push_back(dim_value);
573       }
574 
575       auto attr = DenseIntElementsAttr::get(
576           element_shape.getType().cast<ShapedType>(), new_element_shape_values);
577       auto new_element_shape = rewriter.create<arith::ConstantOp>(
578           op.getLoc(), element_shape.getType(), attr);
579       element_shape = new_element_shape;
580     }
581 
582     int64_t result_rank = -1;  // -1 means unknown result rank.
583     Type element_dtype = op.element_dtype();
584     Type result_type = UnrankedTensorType::get(element_dtype);
585     Value leading_dim = GetNumElements(op, adaptor.getOperands(), &rewriter);
586     if (auto element_type =
587             op.element_type().template dyn_cast<RankedTensorType>()) {
588       result_rank = element_type.getRank() + 1;
589       int64_t leading_dim_v = -1;
590       ElementsAttr element_attr;
591       if (matchPattern(leading_dim, m_Constant(&element_attr))) {
592         leading_dim_v = element_attr.getValues<APInt>()[0].getSExtValue();
593       }
594       SmallVector<int64_t, 4> result_shape = {leading_dim_v};
595       ArrayRef<int64_t> shape = element_type.getShape();
596       result_shape.append(shape.begin(), shape.end());
597       result_type = RankedTensorType::get(result_shape, element_dtype);
598     }
599 
600     // Create a 1-D RankedTensorType for result's shape. Number of elements in
601     // it is equal to the rank of the result, if known. Otherwise, the number of
602     // elements are unknown and represented with -1. In both cases, we can
603     // specify dimension using rank of the result.
604     Type shape_type = RankedTensorType::get({result_rank}, shape_dtype);
605 
606     Location loc = op.getLoc();
607     // Add number of elements as the prefix to the element shape to get shape of
608     // the output tensor.
609     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
610     auto list_shape = rewriter.create<TF::ConcatOp>(
611         loc, shape_type, scalar_zero,
612         ArrayRef<Value>({leading_dim, element_shape}));
613 
614     // Create a zero-initialized constant tensor that has the same type
615     // as specified by element_dtype.
616     RankedTensorType zero_type = RankedTensorType::get({}, element_dtype);
617     Attribute zero_attr = rewriter.getZeroAttr(zero_type);
618     auto zero = rewriter.create<arith::ConstantOp>(loc, zero_type, zero_attr);
619 
620     rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
621     return success();
622   }
623 };
624 
625 struct ConvertTensorListReserve
626     : public ConvertTensorListInitOp<TF::TensorListReserveOp> {
ConvertTensorListReservemlir::__anonaaa99d6c0111::ConvertTensorListReserve627   explicit ConvertTensorListReserve(MLIRContext *context,
628                                     bool allow_tensorlist_pass_through,
629                                     bool default_to_single_batch)
630       : ConvertTensorListInitOp(context, allow_tensorlist_pass_through,
631                                 default_to_single_batch) {}
632 
GetNumElementsmlir::__anonaaa99d6c0111::ConvertTensorListReserve633   Value GetNumElements(TF::TensorListReserveOp op, ValueRange operands,
634                        PatternRewriter *rewriter) const override {
635     Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
636     Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
637     Value num_elements = operands[1];
638     IntegerAttr attr;
639     if (matchPattern(num_elements, m_Constant(&attr))) {
640       return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt());
641     }
642     if (auto const_op = num_elements.getDefiningOp<TF::ConstOp>()) {
643       return CreateI32SplatConst(op->getLoc(), rewriter, {1},
644                                  (*const_op.value()
645                                        .cast<DenseElementsAttr>()
646                                        .getValues<APInt>()
647                                        .begin())
648                                      .getSExtValue());
649     }
650     return rewriter->create<TF::ExpandDimsOp>(
651         op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
652         scalar_zero);
653   }
654 };
655 
656 // Note that we ignore the second operand `max_num_elements` as we don't have
657 // any restrictions on the number of elements we can support. So this may
658 // have a different behavior compared to TensorFlow in case of errors.
659 struct ConvertEmptyTensorList
660     : public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
ConvertEmptyTensorListmlir::__anonaaa99d6c0111::ConvertEmptyTensorList661   explicit ConvertEmptyTensorList(MLIRContext *context,
662                                   bool allow_tensorlist_pass_through,
663                                   bool default_to_single_batch)
664       : ConvertTensorListInitOp(context, allow_tensorlist_pass_through,
665                                 default_to_single_batch) {}
666 
GetNumElementsmlir::__anonaaa99d6c0111::ConvertEmptyTensorList667   Value GetNumElements(TF::EmptyTensorListOp op, ValueRange operands,
668                        PatternRewriter *rewriter) const override {
669     return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
670   }
671 };
672 
673 struct ConvertTensorListPushBack
674     : public OpConversionPattern<TF::TensorListPushBackOp> {
675   using OpConversionPattern::OpConversionPattern;
676 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertTensorListPushBack677   LogicalResult matchAndRewrite(
678       TF::TensorListPushBackOp op, OpAdaptor adaptor,
679       ConversionPatternRewriter &rewriter) const override {
680     Value input_handle = adaptor.getOperands()[0];
681     Value item = adaptor.getOperands()[1];
682 
683     // Expand the shape of the item so that it will have rank same as the input
684     // tensor and it is compatible for the Concat Op.
685     Type expanded_item_type =
686         PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
687     Location loc = op.getLoc();
688     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
689     auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
690         loc, expanded_item_type, item, scalar_zero);
691 
692     Type elem_type = getElementTypeOrSelf(item);
693     auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
694                             .cast<TF::VariantType>();
695     Type result_type =
696         GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
697 
698     // Concatenate tensor stored in the input handle with the expanded item to
699     // get a tensor equivalent to the TensorList generated by this op.
700     rewriter.replaceOpWithNewOp<TF::ConcatOp>(
701         op, result_type, scalar_zero,
702         ArrayRef<Value>({input_handle, expanded_item}));
703     return success();
704   }
705 };
706 
707 // Rewrites `TensorListResize` op into a functional `If` op and several basic
708 // TF ops to match the op semantics of Tensorflow. Basically, it does:
709 // 1) If the requested size is smaller or equal than the input tensorlist's
710 // size, rewrite it to a Slice op so that only the first 'size' rows are
711 // returned. 2) If the requested size is larger than the input tensorlist's
712 // size. We need to create an additional tensorlist with 'size - input_size'
713 // elements, and append it to the end of the input tensorlist.
714 struct ConvertTensorListResize
715     : public OpConversionPattern<TF::TensorListResizeOp> {
716   using OpConversionPattern::OpConversionPattern;
717 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertTensorListResize718   LogicalResult matchAndRewrite(
719       TF::TensorListResizeOp op, OpAdaptor adaptor,
720       ConversionPatternRewriter &rewriter) const override {
721     Value input_handle = adaptor.getOperands()[0];
722     Value size = adaptor.getOperands()[1];
723 
724     Location loc = op.getLoc();
725     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
726 
727     // Compute the input tensorlist's length and store it in `input_size`.
728     IntegerType shape_dtype = rewriter.getIntegerType(32);
729     auto input_size = rewriter.create<TF::TensorListLengthOp>(
730         loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));
731 
732     // Infer result type of this op based on TF's shape inference result.
733     Type elem_type = getElementTypeOrSelf(input_handle);
734     auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
735                             .cast<TF::VariantType>();
736     Type result_type =
737         GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
738 
739     // Compute the difference of `size` and `input_size`, and store it in
740     // `size_diff`, which is then consumed by `if_cond`.
741     auto size_diff = rewriter.create<TF::SubOp>(
742         loc, RankedTensorType::get({}, shape_dtype), size, input_size);
743     auto if_cond = rewriter.create<TF::GreaterOp>(
744         loc, RankedTensorType::get({}, rewriter.getI1Type()), size_diff,
745         scalar_zero);
746 
747     // Build the argument/result types for if branch function.
748     auto input_shape = rewriter.create<TF::ShapeOp>(
749         loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
750 
751     Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
752                                size_diff.getType(), size.getType()};
753     Type branch_result_type[] = {result_type};
754     auto func_type = FunctionType::get(rewriter.getContext(), branch_args_type,
755                                        branch_result_type);
756 
757     // Create functions in a higher scope before restoring the insertion point.
758     // Additionally, create the SymbolTable before further modifying the module.
759     auto original_point = rewriter.saveInsertionPoint();
760     rewriter.setInsertionPointAfter(op->getParentOfType<func::FuncOp>());
761     SymbolTable manager(op->getParentOfType<ModuleOp>());
762 
763     // Constructs `then_branch`, which is executed when `if_cond` evaluates to
764     // true.
765     auto then_branch_op =
766         rewriter.create<func::FuncOp>(loc, "cond_true", func_type);
767     CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
768                          &rewriter);
769     then_branch_op.setVisibility(func::FuncOp::Visibility::Private);
770 
771     // Constructs `else_branch`, which is executed when `if_cond` evaluates to
772     // false.
773     auto else_branch_op =
774         rewriter.create<func::FuncOp>(loc, "cond_false", func_type);
775     CreateCondFalseBranch(loc, shape_dtype, result_type, else_branch_op,
776                           &rewriter);
777     else_branch_op.setVisibility(func::FuncOp::Visibility::Private);
778 
779     // Inserts the two blocks' names into the symbol table held by the module.
780     // Using SymbolTable will ensure that the inserted symbol names are
781     // unique.
782     manager.insert(then_branch_op);
783     manager.insert(else_branch_op);
784 
785     rewriter.restoreInsertionPoint(original_point);
786     rewriter.replaceOpWithNewOp<TF::IfOp>(
787         op, result_type, if_cond,
788         /*input=*/
789         ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
790         /*then_branch=*/
791         mlir::SymbolRefAttr::get(then_branch_op),
792         /*else_branch=*/
793         mlir::SymbolRefAttr::get(else_branch_op),
794         /*is_stateless=*/rewriter.getBoolAttr(true));
795     return success();
796   }
797 
798  private:
799   // When the input tensorlist's size is smaller than the requested size,
800   // then branch is executed.
801   // Create a new tensorlist of size 'size - input_size' and concat it
802   // with the input tensorlist.
CreateCondTrueBranchmlir::__anonaaa99d6c0111::ConvertTensorListResize803   void CreateCondTrueBranch(TF::TensorListResizeOp resize_op, Type shape_dtype,
804                             Type result_type, func::FuncOp branch_func,
805                             ConversionPatternRewriter *rewriter) const {
806     auto guard = OpBuilder::InsertionGuard(*rewriter);
807     auto inputs = branch_func.getFunctionType().getInputs();
808     Block *block = rewriter->createBlock(
809         &branch_func.getBody(), branch_func.begin(), inputs,
810         SmallVector<Location>(inputs.size(), branch_func.getLoc()));
811 
812     auto input_shape = block->getArgument(1);
813     auto size_diff = block->getArgument(2);
814     auto input = block->getArgument(0);
815 
816     Location loc = resize_op.getLoc();
817     // Get the element shape by slicing from index 1 in the input shape.
818     Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
819     Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
820     Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
821     auto elem_shape = rewriter->create<TF::SliceOp>(
822         loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
823         slice_size);
824     auto extended_part = rewriter->create<TF::TensorListReserveOp>(
825         loc, resize_op.output_handle().getType(), elem_shape, size_diff);
826     // `ConcatOp` expects non-variant-typed input. Insert a
827     // `TensorListStackOp` here to convert type from variant to non-variant.
828     // Note that we are using the same `result_type` for both the
829     // `TensorListStackOp` and `ConcatOp`, since the first dimension of the
830     // shape specified by `result_type` is -1.
831     auto stacked_extended_part = rewriter->create<TF::TensorListStackOp>(
832         loc, result_type, extended_part,
833         /*element_shape=*/CreateI32SplatConst(loc, rewriter, {}, -1),
834         /*num_elements=*/rewriter->getI32IntegerAttr(-1));
835     auto concat_op = rewriter->create<TF::ConcatOp>(
836         loc, result_type, scalar_zero,
837         ArrayRef<Value>({input, stacked_extended_part}));
838     rewriter->create<func::ReturnOp>(loc, ArrayRef<Value>({concat_op}));
839   }
840 
CreateCondFalseBranchmlir::__anonaaa99d6c0111::ConvertTensorListResize841   void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
842                              func::FuncOp branch_func,
843                              ConversionPatternRewriter *rewriter) const {
844     // When the input tensorlist's size is larger or equal than the requested
845     // size, the else branch is executed.
846     // Slice the first 'size' rows from the input tensorlist.
847     auto guard = OpBuilder::InsertionGuard(*rewriter);
848     auto inputs = branch_func.getFunctionType().getInputs();
849     Block *block = rewriter->createBlock(
850         &branch_func.getBody(), branch_func.begin(), inputs,
851         SmallVector<Location>(inputs.size(), branch_func.getLoc()));
852 
853     Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
854     Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
855     auto input = block->getArgument(0);
856     auto size = block->getArgument(3);
857 
858     // Subtract `input_rank` by 1 to get the item's rank, which is used as
859     // `partial_position_shape`.
860     auto input_rank = rewriter->create<TF::RankOp>(
861         loc, RankedTensorType::get({}, shape_dtype), input);
862     auto partial_position_shape = rewriter->create<TF::SubOp>(
863         loc, RankedTensorType::get({1}, shape_dtype), input_rank, vector_one);
864     auto slice_op =
865         CreateSliceOpForTensorList(loc, /*input_list=*/input,
866                                    /*start_index=*/scalar_zero, /*size=*/size,
867                                    /*item_rank=*/partial_position_shape,
868                                    /*result_type=*/result_type, rewriter);
869     rewriter->create<func::ReturnOp>(loc, ArrayRef<Value>({slice_op}));
870   }
871 };
872 
873 struct ConvertTensorListGetItem
874     : public OpConversionPattern<TF::TensorListGetItemOp> {
875   using OpConversionPattern::OpConversionPattern;
876 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertTensorListGetItem877   LogicalResult matchAndRewrite(
878       TF::TensorListGetItemOp op, OpAdaptor adaptor,
879       ConversionPatternRewriter &rewriter) const override {
880     Value input = adaptor.getOperands()[0];
881     Value index = adaptor.getOperands()[1];
882     rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
883                                               rewriter.getBoolAttr(true));
884     return success();
885   }
886 };
887 
888 struct ConvertTensorListLength
889     : public OpConversionPattern<TF::TensorListLengthOp> {
890   using OpConversionPattern::OpConversionPattern;
891 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertTensorListLength892   LogicalResult matchAndRewrite(
893       TF::TensorListLengthOp op, OpAdaptor adaptor,
894       ConversionPatternRewriter &rewriter) const override {
895     Location loc = op.getLoc();
896     Value input_handle = adaptor.getOperands()[0];
897 
898     BoolAttr true_attr = rewriter.getBoolAttr(true);
899     auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
900                                               /*use_32bit=*/true_attr);
901     rewriter.replaceOpWithNewOp<TF::GatherOp>(
902         op, op.getType(), shape, CreateI32SplatConst(loc, &rewriter, {}, 0),
903         /*validate_indices=*/true_attr);
904     return success();
905   }
906 };
907 
908 struct ConvertTensorListStack
909     : public OpConversionPattern<TF::TensorListStackOp> {
910   using OpConversionPattern::OpConversionPattern;
911 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertTensorListStack912   LogicalResult matchAndRewrite(
913       TF::TensorListStackOp op, OpAdaptor adaptor,
914       ConversionPatternRewriter &rewriter) const override {
915     Location loc = op.getLoc();
916     Value input = adaptor.getOperands()[0];
917     Value element_shape = adaptor.getOperands()[1];
918 
919     // If the `element_shape` is a known constant (which is defined when calling
920     // `tensor_list_stack`) and also valid (not scalar), we rewrite this op to a
921     // trivial Reshape op (that doesn't actually change the input's shape) and
922     // also populate the shape info to the op result. The shape of the
923     // tensorlist is inferred from `num_elements` and `element_shape`.
924     auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
925     DenseIntElementsAttr dense_elem_attr;
926     if ((ranked_type && ranked_type.getRank() == 0) ||
927         !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
928       // If no constant is spotted, just forward the operand.
929       rewriter.replaceOp(op, {input});
930       return success();
931     }
932 
933     RankedTensorType shape_type =
934         RankedTensorType::get({-1}, rewriter.getIntegerType(32));
935     auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
936     SmallVector<int64_t, 8> output_shape(/*Size=*/1, op.num_elements());
937     for (const auto &dim : dense_elem_attr.getValues<APInt>())
938       output_shape.push_back(dim.getSExtValue());
939     RankedTensorType result_type =
940         RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
941     rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, result_type, input,
942                                                new_shape);
943     return success();
944   }
945 };
946 
947 // Converts `TensorListConcatV2` into Unpack and Concat. First we unpack
948 // the input tensorlist along the first dimension, which results in N (where N
949 // is the first dim's size) tensors (each with shape [element_shape]). Then
950 // we concatenate all those tensors along the first dimension.
951 // The pattern will be rejected if either `element_shape` is not constant, or
952 // the first dimension of `input` is not known.
953 struct ConvertTensorListConcatV2
954     : public TensorListOpConverterBase<TF::TensorListConcatV2Op> {
955   using TensorListOpConverterBase<
956       TF::TensorListConcatV2Op>::TensorListOpConverterBase;
957   using TensorListOpConverterBase<
958       TF::TensorListConcatV2Op>::allow_tensorlist_pass_through_;
959 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertTensorListConcatV2960   LogicalResult matchAndRewrite(
961       TF::TensorListConcatV2Op op, OpAdaptor adaptor,
962       ConversionPatternRewriter &rewriter) const override {
963     Location loc = op.getLoc();
964     Value input = adaptor.getOperands()[0];
965     Value element_shape = adaptor.getOperands()[1];
966 
967     // Only match when `element_shape` is a constant.
968     DenseIntElementsAttr dense_elem_attr;
969     if (!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
970       const char *error_info = "requires element_shape to be a constant";
971       return allow_tensorlist_pass_through_
972                  ? rewriter.notifyMatchFailure(op, error_info)
973                  : op.emitOpError(error_info);
974     }
975     llvm::SmallVector<int64_t, 4> output_shape;
976     for (const auto &dim : dense_elem_attr.getValues<APInt>()) {
977       output_shape.push_back(dim.getSExtValue());
978     }
979 
980     // First unpack the input tensor along the first dimension.
981     Type input_element_type = getElementTypeOrSelf(input);
982     int64_t num_unpacked = 0;
983     if (auto type = input.getType().dyn_cast<RankedTensorType>()) {
984       if (type.getDimSize(0) > 0) {
985         num_unpacked = type.getDimSize(0);
986       } else {
987         const char *error_info =
988             "requires the first dimension of input tensor to have > 0 "
989             "dimension";
990         return allow_tensorlist_pass_through_
991                    ? rewriter.notifyMatchFailure(op, error_info)
992                    : op.emitOpError(error_info);
993       }
994     }
995     llvm::SmallVector<Type, 1> unpack_output_type;
996     unpack_output_type.insert(
997         unpack_output_type.begin(), num_unpacked,
998         RankedTensorType::get(output_shape, input_element_type));
999     auto unpack = rewriter.create<TF::UnpackOp>(loc, unpack_output_type, input,
1000                                                 /*axis=*/0);
1001 
1002     // Concatenate the unpacked tensors along the first dimension.
1003     // Since we're concatenating along first dimension, change its dim size to
1004     // -1.
1005     output_shape[0] = -1;
1006     Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
1007     auto concat = rewriter.create<TF::ConcatOp>(
1008         loc, RankedTensorType::get(output_shape, input_element_type),
1009         scalar_zero, unpack->getResults());
1010     // `lengths` is only useful for computing gradient. For now we just return
1011     // a placeholder tensor.
1012     rewriter.replaceOp(
1013         op, {concat.getResult(), CreateI64SplatConst(loc, &rewriter, {0}, 0)});
1014     return success();
1015   }
1016 };
1017 
1018 struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
1019   using OpConversionPattern::OpConversionPattern;
1020 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertIdentity1021   LogicalResult matchAndRewrite(
1022       TF::IdentityOp op, OpAdaptor adaptor,
1023       ConversionPatternRewriter &rewriter) const override {
1024     Value input = adaptor.getOperands()[0];
1025     rewriter.replaceOpWithNewOp<TF::IdentityOp>(
1026         op, input.getType(), adaptor.getOperands(), op->getAttrs());
1027     return success();
1028   }
1029 };
1030 
1031 struct ConvertReturn : public OpConversionPattern<func::ReturnOp> {
1032   using OpConversionPattern::OpConversionPattern;
1033 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertReturn1034   LogicalResult matchAndRewrite(
1035       func::ReturnOp op, OpAdaptor adaptor,
1036       ConversionPatternRewriter &rewriter) const override {
1037     rewriter.updateRootInPlace(op,
1038                                [&] { op->setOperands(adaptor.getOperands()); });
1039     return success();
1040   }
1041 };
1042 
1043 struct ConvertYield : public OpConversionPattern<TF::YieldOp> {
1044   using OpConversionPattern::OpConversionPattern;
1045 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertYield1046   LogicalResult matchAndRewrite(
1047       TF::YieldOp op, OpAdaptor adaptor,
1048       ConversionPatternRewriter &rewriter) const override {
1049     rewriter.updateRootInPlace(op,
1050                                [&] { op->setOperands(adaptor.getOperands()); });
1051     return success();
1052   }
1053 };
1054 
1055 // Returns an unranked tensor type with an element of the same type as `value`
1056 // if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
VariantToUnrankedTensorType(Type type,Value value)1057 Type VariantToUnrankedTensorType(Type type, Value value) {
1058   TF::VariantType variant_ty =
1059       getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
1060   if (!variant_ty) {
1061     return type;
1062   }
1063   if (!variant_ty.getSubtypes().empty()) {
1064     // Short-circut if the variant type has subtype info.
1065     return UnrankedTensorType::get(
1066         variant_ty.getSubtypes()[0].getElementType());
1067   }
1068   Type value_type = value.getType();
1069   Type element_type;
1070   variant_ty = value_type.dyn_cast<TF::VariantType>();
1071   if (variant_ty && !variant_ty.getSubtypes().empty()) {
1072     element_type = variant_ty.getSubtypes()[0].getElementType();
1073   } else {
1074     element_type = getElementTypeOrSelf(value_type);
1075   }
1076   return UnrankedTensorType::get(element_type);
1077 }
1078 
1079 // Returns true if we can deduce the type is tensorlist.
IsTensorListType(Type type,llvm::Optional<Value> value)1080 bool IsTensorListType(Type type, llvm::Optional<Value> value) {
1081   TF::VariantType variant_ty =
1082       getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
1083   if (!variant_ty) {
1084     return false;
1085   }
1086   // Check there is only one subtype contained in the variant type. Note that
1087   // when `subtypes.size() == 1` does not always mean the type is actually
1088   // a tensorlist. We probably need some form of data flow analysis.
1089   if (variant_ty.getSubtypes().size() == 1) {
1090     return true;
1091   }
1092   // If subtype info is not available, check if the value is used by any of
1093   // the following TensorList operations.
1094   if (!value.has_value()) {
1095     return false;
1096   }
1097   for (const mlir::OpOperand &use : value.getValue().getUses()) {
1098     mlir::Operation *op = use.getOwner();
1099     if (llvm::isa<TF::TensorListGetItemOp>(op) ||
1100         llvm::isa<TF::TensorListLengthOp>(op) ||
1101         llvm::isa<TF::TensorListPushBackOp>(op) ||
1102         llvm::isa<TF::TensorListReserveOp>(op) ||
1103         llvm::isa<TF::TensorListSetItemOp>(op) ||
1104         llvm::isa<TF::TensorListStackOp>(op) ||
1105         llvm::isa<TF::TensorListResizeOp>(op)) {
1106       return true;
1107     }
1108   }
1109   return false;
1110 }
1111 
1112 // Returns a set of integers that correspond to the tensorlist arguments in
1113 // the function.
GetTensorListArgumentsIndex(func::FuncOp func)1114 llvm::SmallSet<int, 4> GetTensorListArgumentsIndex(func::FuncOp func) {
1115   llvm::SmallSet<int, 4> set;
1116   for (const auto &arg_and_idx : llvm::enumerate(func.getArguments())) {
1117     if (IsTensorListType(arg_and_idx.value().getType(), arg_and_idx.value())) {
1118       set.insert(arg_and_idx.index());
1119     }
1120   }
1121   return set;
1122 }
1123 
1124 // Returns a set of integers that correspond to the tensorlist results in the
1125 // function.
GetTensorListResultsIndex(func::FuncOp func)1126 llvm::SmallSet<int, 4> GetTensorListResultsIndex(func::FuncOp func) {
1127   llvm::SmallSet<int, 4> set;
1128 
1129   for (const auto &result_and_idx :
1130        llvm::enumerate(func.getFunctionType().getResults())) {
1131     if (IsTensorListType(result_and_idx.value(), llvm::None)) {
1132       set.insert(result_and_idx.index());
1133     }
1134   }
1135   return set;
1136 }
1137 
1138 // Updates the tensorlist types based on the input index. If the tensorlist's
1139 // size isn't changed(which is indicated by `resized_tensor_list_index`), then
1140 // we will use the original operand's type, otherwise update it with the
1141 // unranked tensor type.
1142 template <typename R>
UpdateTensorListTypes(const llvm::SmallSet<int,4> & tensor_list_index,const llvm::SmallSet<int,4> & resized_tensor_list_index,ArrayRef<Type> types,R && range,ValueRange operands,llvm::SmallVectorImpl<Type> * updated_types)1143 void UpdateTensorListTypes(
1144     const llvm::SmallSet<int, 4> &tensor_list_index,
1145     const llvm::SmallSet<int, 4> &resized_tensor_list_index,
1146     ArrayRef<Type> types, R &&range, ValueRange operands,
1147     llvm::SmallVectorImpl<Type> *updated_types) {
1148   int i = 0;
1149   for (const auto it : llvm::zip(types, range, operands)) {
1150     if (tensor_list_index.count(i)) {
1151       // Only change the tensorlist's type to unranked tensor if it has been
1152       // resized.
1153       if (resized_tensor_list_index.count(i)) {
1154         updated_types->push_back(
1155             VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
1156       } else {
1157         updated_types->push_back(std::get<2>(it).getType());
1158       }
1159     } else {
1160       updated_types->push_back(std::get<0>(it));
1161     }
1162     ++i;
1163   }
1164 }
1165 
1166 // Updates the tensorlist types to unranked tensor types based on the input
1167 // index.
1168 template <typename R>
ChangeVariantToUnrankedTensorType(const llvm::SmallSet<int,4> & tensor_list_index,ArrayRef<Type> types,R && range,llvm::SmallVectorImpl<Type> * updated_types)1169 void ChangeVariantToUnrankedTensorType(
1170     const llvm::SmallSet<int, 4> &tensor_list_index, ArrayRef<Type> types,
1171     R &&range, llvm::SmallVectorImpl<Type> *updated_types) {
1172   int i = 0;
1173   for (const auto it : llvm::zip(types, range)) {
1174     if (tensor_list_index.count(i)) {
1175       updated_types->push_back(
1176           VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
1177     } else {
1178       updated_types->push_back(std::get<0>(it));
1179     }
1180     ++i;
1181   }
1182 }
1183 
1184 // Updates the specified function's type and region signature.
UpdateFunctionAndRegionType(ConversionPatternRewriter & rewriter,func::FuncOp func,llvm::ArrayRef<Type> updated_argument_types,llvm::ArrayRef<Type> updated_result_types)1185 void UpdateFunctionAndRegionType(ConversionPatternRewriter &rewriter,
1186                                  func::FuncOp func,
1187                                  llvm::ArrayRef<Type> updated_argument_types,
1188                                  llvm::ArrayRef<Type> updated_result_types) {
1189   // Change `func`'s argument type to `unranked_argument_types`. If its
1190   // return types contain a `DT_VARIANT`, change it to the unranked type
1191   // derived from the corresponding argument.
1192   rewriter.updateRootInPlace(func, [&] {
1193     func.setType(FunctionType::get(func.getContext(), updated_argument_types,
1194                                    updated_result_types));
1195   });
1196   Region &entry = func.getRegion();
1197   TypeConverter::SignatureConversion signature_conversion(
1198       entry.getNumArguments());
1199   for (const BlockArgument &arg : entry.getArguments()) {
1200     signature_conversion.addInputs(arg.getArgNumber(),
1201                                    updated_argument_types[arg.getArgNumber()]);
1202   }
1203   rewriter.applySignatureConversion(&entry, signature_conversion);
1204 }
1205 
1206 // Changes the function type of `cond_func` and `body_func` for the given While
1207 // op.
UpdateFunctionTypesForWhileOp(ConversionPatternRewriter & rewriter,TF::WhileOp op,ValueRange operands,const llvm::SmallSet<int,4> & tensor_list_args,const llvm::SmallSet<int,4> & resized_tensor_lists)1208 LogicalResult UpdateFunctionTypesForWhileOp(
1209     ConversionPatternRewriter &rewriter, TF::WhileOp op, ValueRange operands,
1210     const llvm::SmallSet<int, 4> &tensor_list_args,
1211     const llvm::SmallSet<int, 4> &resized_tensor_lists) {
1212   int func_index = 0;
1213   for (func::FuncOp func : {op.cond_function(), op.body_function()}) {
1214     ++func_index;
1215     if (!func) continue;
1216 
1217     FunctionType func_type = func.getFunctionType();
1218     int num_inputs = func_type.getNumInputs();
1219     int num_results = func_type.getNumResults();
1220 
1221     // For each argument type in function's arguments, change it to uranked
1222     // tensor type if it's a variant type.
1223     SmallVector<Type, 8> updated_argument_types;
1224     updated_argument_types.reserve(num_inputs);
1225     UpdateTensorListTypes<mlir::OperandRange>(
1226         tensor_list_args, resized_tensor_lists, func_type.getInputs(),
1227         op.getOperands(), operands, &updated_argument_types);
1228 
1229     // Change all DT_VARIANT result types in function results to unranked tensor
1230     // type with element type derived from the corresponding input operand. This
1231     // is correct because while body's inputs and results have the same type.
1232     SmallVector<Type, 8> updated_result_types;
1233     updated_result_types.reserve(num_results);
1234     if (func_index == 1) {
1235       // We only update the result types for the body function.
1236       for (Type ty : func_type.getResults()) {
1237         updated_result_types.push_back(ty);
1238       }
1239     } else {
1240       UpdateTensorListTypes<mlir::OperandRange>(
1241           tensor_list_args, resized_tensor_lists, func_type.getResults(),
1242           op.getOperands(), operands, &updated_result_types);
1243     }
1244 
1245     UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
1246                                 updated_result_types);
1247   }
1248   return success();
1249 }
1250 
1251 // Changes the function type of `then_function` and `else_function` for the
1252 // given If op.
UpdateFunctionTypesForIfOp(ConversionPatternRewriter & rewriter,TF::IfOp op,ValueRange operands,const llvm::SmallSet<int,4> & tensor_list_args,const llvm::SmallSet<int,4> & resized_tensor_lists,llvm::ArrayRef<Type> updated_result_types)1253 LogicalResult UpdateFunctionTypesForIfOp(
1254     ConversionPatternRewriter &rewriter, TF::IfOp op, ValueRange operands,
1255     const llvm::SmallSet<int, 4> &tensor_list_args,
1256     const llvm::SmallSet<int, 4> &resized_tensor_lists,
1257     llvm::ArrayRef<Type> updated_result_types) {
1258   for (func::FuncOp func : {op.else_function(), op.then_function()}) {
1259     if (!func) continue;
1260 
1261     FunctionType func_type = func.getFunctionType();
1262     int num_inputs = func_type.getNumInputs();
1263 
1264     // Update the argument types of the function. If it's a tensorlist and
1265     // is not resized inside the function, we will use the corresponding
1266     // operand's type, otherwise change its type to unranked tensor type.
1267     SmallVector<Type, 8> updated_argument_types;
1268     updated_argument_types.reserve(num_inputs);
1269     UpdateTensorListTypes<mlir::OperandRange>(
1270         tensor_list_args, resized_tensor_lists, func_type.getInputs(),
1271         op.getOperands().drop_front(), operands.drop_front(),
1272         &updated_argument_types);
1273 
1274     UpdateFunctionAndRegionType(rewriter, func, updated_argument_types,
1275                                 updated_result_types);
1276   }
1277   return success();
1278 }
1279 
1280 // Returns a `llvm::DenseMap` which maps from the index of tensorlist in the
1281 // result, to the index of the same tensorlist in the arguments. For `If` op's
1282 // branch functions, the results and arguments are not usually matched 1-1. This
1283 // will let us konw which tensorlist result maps to which tensorlist in the
1284 // arguments. Once we know this info it will help us decide the types of the
1285 // result tensorlist based on the operand's of the `If` op.
MapTensorListResultToArgument(func::FuncOp func)1286 llvm::DenseMap<int, int> MapTensorListResultToArgument(func::FuncOp func) {
1287   // `map_fn` will trace upwards along the use-def chain of the ssa value. It
1288   // starts from the last ssa value (returned by the function), and check its
1289   // parent op iteratively. If the root ssa value appears in the function's
1290   // argument list, it will return the index of the corresponding argument,
1291   // otherwise it will return -1.
1292   auto map_fn = [](Value value) -> int {
1293     Value parent = value;
1294     while (true) {
1295       if (auto identity = parent.getDefiningOp<TF::IdentityOp>()) {
1296         parent = identity.input();
1297       } else if (auto set_item =
1298                      parent.getDefiningOp<TF::TensorListSetItemOp>()) {
1299         parent = set_item.input_handle();
1300       } else {
1301         break;
1302       }
1303     }
1304     if (auto block_arg = parent.dyn_cast<mlir::BlockArgument>()) {
1305       return block_arg.getArgNumber();
1306     }
1307     // Returns -1 if we don't find which this result maps to.
1308     return -1;
1309   };
1310 
1311   llvm::SmallVector<Value, 4> returns;
1312   for (auto res : func.getBody().back().getTerminator()->getOperands()) {
1313     returns.push_back(res);
1314   }
1315   llvm::DenseMap<int, int> result;
1316   for (const auto &result_and_idx : llvm::enumerate(returns)) {
1317     if (IsTensorListType(result_and_idx.value().getType(),
1318                          result_and_idx.value())) {
1319       int arg_idx = map_fn(result_and_idx.value());
1320       if (arg_idx != -1) {
1321         result.insert({result_and_idx.index(), arg_idx});
1322       }
1323     }
1324   }
1325   return result;
1326 }
1327 
1328 // Updates the tensorlist result types for the `If` Op. If the tensorlist result
1329 // maps to a specific argument (indicated by `tensor_list_map`), and also that
1330 // tensorlist argument's shape isn't changed (indicated by
1331 // `resized_tensor_list_index`), we will update this tensorlist's result type to
1332 // the corresponding operand's type. In all other cases we change the
1333 // tensorlist's type to unranked tensor type.
1334 template <typename R>
UpdateTensorListResultTypesForIf(const llvm::SmallSet<int,4> & tensor_list_index,const llvm::SmallSet<int,4> & resized_tensor_list_index,const llvm::DenseMap<int,int> & tensor_list_map,ArrayRef<Type> types,R && range,ValueRange operands,llvm::SmallVectorImpl<Type> * updated_types)1335 void UpdateTensorListResultTypesForIf(
1336     const llvm::SmallSet<int, 4> &tensor_list_index,
1337     const llvm::SmallSet<int, 4> &resized_tensor_list_index,
1338     const llvm::DenseMap<int, int> &tensor_list_map, ArrayRef<Type> types,
1339     R &&range, ValueRange operands,
1340     llvm::SmallVectorImpl<Type> *updated_types) {
1341   int i = 0;
1342   for (const auto it : llvm::zip(types, range)) {
1343     if (!tensor_list_index.count(i)) {
1344       updated_types->push_back(std::get<0>(it));
1345       ++i;
1346       continue;
1347     }
1348     auto iter = tensor_list_map.find(i);
1349     if (iter != tensor_list_map.end()) {
1350       int arg_idx = iter->second;
1351       if (!resized_tensor_list_index.count(arg_idx)) {
1352         // If the mapped tensorlist argument's size isn't changed, we will
1353         // use the corresponding `operand` type.
1354         updated_types->push_back(operands[arg_idx].getType());
1355         ++i;
1356         continue;
1357       }
1358     }
1359     updated_types->push_back(
1360         VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
1361     ++i;
1362   }
1363 }
1364 
1365 struct ConvertIf : public OpConversionPattern<TF::IfOp> {
1366   using OpConversionPattern::OpConversionPattern;
1367 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertIf1368   LogicalResult matchAndRewrite(
1369       TF::IfOp op, OpAdaptor adaptor,
1370       ConversionPatternRewriter &rewriter) const override {
1371     // Find all Tensor List arugments.
1372     auto tensor_list_args = GetTensorListArgumentsIndex(op.else_function());
1373     auto tensor_list_results = GetTensorListResultsIndex(op.else_function());
1374     auto tensor_list_map = MapTensorListResultToArgument(op.else_function());
1375     llvm::SmallSet<int, 4> resized_tensor_lists =
1376         GetResizedTensorListIndexes(op.else_function(), tensor_list_args);
1377 
1378     llvm::SmallVector<Type, 8> result_types;
1379     result_types.reserve(op.getNumResults());
1380     llvm::SmallVector<Type, 4> op_result_types;
1381     for (Type ty : op.getResultTypes()) {
1382       op_result_types.push_back(ty);
1383     }
1384 
1385     UpdateTensorListResultTypesForIf<mlir::ResultRange>(
1386         tensor_list_results, resized_tensor_lists, tensor_list_map,
1387         op_result_types, op->getResults(), adaptor.getOperands().drop_front(),
1388         &result_types);
1389 
1390     // Create a new if op with new operands and updated result types.
1391     auto converted = rewriter.create<TF::IfOp>(
1392         op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());
1393     converted->removeAttr("T");
1394     (void)UpdateFunctionTypesForIfOp(rewriter, converted, adaptor.getOperands(),
1395                                      tensor_list_args, resized_tensor_lists,
1396                                      result_types);
1397     rewriter.replaceOp(op, converted.getResults());
1398     return success();
1399   }
1400 };
1401 
1402 struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
1403   using OpConversionPattern::OpConversionPattern;
1404 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertWhile1405   LogicalResult matchAndRewrite(
1406       TF::WhileOp op, OpAdaptor adaptor,
1407       ConversionPatternRewriter &rewriter) const override {
1408     // Find all Tensor List arugments.
1409     auto tensor_list_args = GetTensorListArgumentsIndex(op.body_function());
1410 
1411     llvm::SmallVector<Type, 8> result_types;
1412     result_types.reserve(op.getNumOperands());
1413     // Change all DT_VARIANT result types to unranked tensor type.
1414     llvm::SmallVector<Type, 4> op_result_types;
1415     for (Type ty : op.getResultTypes()) {
1416       op_result_types.push_back(ty);
1417     }
1418 
1419     llvm::SmallSet<int, 4> resized_tensor_lists =
1420         GetResizedTensorListIndexes(op.body_function(), tensor_list_args);
1421     UpdateTensorListTypes<mlir::OperandRange>(
1422         tensor_list_args, resized_tensor_lists, op_result_types,
1423         op.getOperands(), adaptor.getOperands(), &result_types);
1424 
1425     // Create a new while op with new operands and updated result types.
1426     auto converted = rewriter.create<TF::WhileOp>(
1427         op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());
1428     converted->removeAttr("T");
1429     (void)UpdateFunctionTypesForWhileOp(rewriter, converted,
1430                                         adaptor.getOperands(), tensor_list_args,
1431                                         resized_tensor_lists);
1432 
1433     rewriter.replaceOp(op, converted.getResults());
1434     return success();
1435   }
1436 };
1437 
1438 struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
1439   using OpConversionPattern::OpConversionPattern;
1440 
matchAndRewritemlir::__anonaaa99d6c0111::ConvertWhileRegion1441   LogicalResult matchAndRewrite(
1442       TF::WhileRegionOp op, OpAdaptor adaptor,
1443       ConversionPatternRewriter &rewriter) const override {
1444     llvm::SmallVector<Type, 8> result_types;
1445     result_types.reserve(op.getNumOperands());
1446     // Change all DT_VARIANT result types to unranked tensor type.
1447     for (auto it : llvm::zip(op.getResultTypes(), adaptor.getOperands()))
1448       result_types.push_back(
1449           VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
1450 
1451     // Create a new while op with new operands and updated result types.
1452     auto converted = rewriter.create<TF::WhileRegionOp>(
1453         op.getLoc(), result_types, adaptor.getOperands(), op->getAttrs());
1454 
1455     // Inline the regions from the old while into the new one, and apply
1456     // signature conversion to inlined region.
1457     for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
1458       Region &old_region = *std::get<0>(it);
1459       Region &new_region = *std::get<1>(it);
1460 
1461       Block &entry = old_region.front();
1462       // Build signature conversion for the region.
1463       TypeConverter::SignatureConversion signature_conversion(
1464           adaptor.getOperands().size());
1465       for (auto it : llvm::zip(entry.getArguments(), adaptor.getOperands())) {
1466         BlockArgument arg = std::get<0>(it);
1467         signature_conversion.addInputs(
1468             arg.getArgNumber(),
1469             VariantToUnrankedTensorType(arg.getType(), std::get<1>(it)));
1470       }
1471 
1472       rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
1473       rewriter.applySignatureConversion(&new_region, signature_conversion);
1474     }
1475 
1476     rewriter.replaceOp(op, converted.getResults());
1477     return success();
1478   }
1479 };
1480 
1481 #include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
1482 
runOnOperation()1483 void LowerStaticTensorListPass::runOnOperation() {
1484   auto *context = &getContext();
1485 
1486   // TensorFlow operations that doesn't have operands and results of type
1487   // variant are legal. Here, we don't distinguish between variants encoding
1488   // TensorList or some other type as that information is not available here.
1489   // Partial legalization is used below to still allow ops with variant types
1490   // still.
1491   auto is_legal = [](Operation *op) {
1492     auto is_not_variant = [](Type ty) {
1493       return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
1494     };
1495     return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
1496            llvm::all_of(op->getResultTypes(), is_not_variant);
1497   };
1498 
1499   ConversionTarget target(*context);
1500   target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(is_legal);
1501   target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
1502                       TF::TensorListGetItemOp, TF::TensorListLengthOp,
1503                       TF::TensorListPushBackOp, TF::TensorListReserveOp,
1504                       TF::TensorListSetItemOp, TF::TensorListStackOp,
1505                       TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
1506   // TODO(hinsu): Use TFLite constant op for constants.
1507   target.addLegalOp<arith::ConstantOp>();
1508   target.addLegalOp<func::FuncOp>();
1509   target.addDynamicallyLegalOp<func::ReturnOp>(is_legal);
1510   target.addDynamicallyLegalOp<TF::YieldOp>(is_legal);
1511   target.addLegalOp<TFL::CustomOp>();
1512   // Register fused LSTM/RNN ops as legal.
1513   target.addLegalOp<TFL::LSTMOp>();
1514   target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
1515   target.addLegalOp<TFL::UnidirectionalSequenceRNNOp>();
1516   target.addLegalOp<TFL::BidirectionalSequenceLSTMOp>();
1517 
1518   RewritePatternSet patterns(&getContext());
1519   populateWithGenerated(patterns);
1520   patterns.add<ConvertConst, ConvertIdentity, ConvertTensorListGetItem,
1521                ConvertTensorListLength, ConvertTensorListPushBack,
1522                ConvertTensorListStack, ConvertTensorListResize, ConvertWhile,
1523                ConvertWhileRegion, ConvertIf, ConvertReturn, ConvertYield>(
1524       context);
1525   patterns.add<ConvertTensorListSetItem>(context,
1526                                          this->enable_dynamic_update_slice_);
1527   patterns.add<ConvertEmptyTensorList, ConvertTensorListConcatV2,
1528                ConvertTensorListReserve>(context,
1529                                          this->allow_tensorlist_pass_through_,
1530                                          this->default_to_single_batch_);
1531   ModuleOp module = getOperation();
1532   if (!this->allow_tensorlist_pass_through_) {
1533     if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
1534       module.emitError(
1535           "Lowering tensor list ops is failed. Please consider using Select TF "
1536           "ops and disabling `_experimental_lower_tensor_list_ops` flag in the "
1537           "TFLite converter object. For example, "
1538           "converter.target_spec.supported_ops = "
1539           "[tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]\\n "
1540           "converter._experimental_lower_tensor_list_ops = False");
1541       signalPassFailure();
1542     }
1543   } else {
1544     // If `allow_tensorlist_pass_through` is set to true, if legalization fails
1545     // we should not leak the diagnostic info outside this pass. Hence we use
1546     // a `StatusScopedDiagnosticHandler` here to capture diagnostics generated
1547     // within this pass.
1548     StatusScopedDiagnosticHandler handler(context);
1549     if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
1550       auto _ = handler.ConsumeStatus();
1551     }
1552   }
1553 }
1554 
1555 }  // namespace
1556 
1557 /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
1558 /// pass.
CreateLowerStaticTensorListPass(bool allow_tensorlist_pass_through,bool default_to_single_batch,bool enable_dynamic_update_slice)1559 std::unique_ptr<OperationPass<ModuleOp>> TFL::CreateLowerStaticTensorListPass(
1560     bool allow_tensorlist_pass_through, bool default_to_single_batch,
1561     bool enable_dynamic_update_slice) {
1562   return std::make_unique<LowerStaticTensorListPass>(
1563       allow_tensorlist_pass_through, default_to_single_batch,
1564       enable_dynamic_update_slice);
1565 }
1566 
1567 std::unique_ptr<OperationPass<ModuleOp>>
CreateLowerStaticTensorListPass()1568 TFL::CreateLowerStaticTensorListPass() {
1569   return std::make_unique<LowerStaticTensorListPass>();
1570 }
1571 
1572 }  // namespace mlir
1573