/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Helpers for the prepare-quantize transform pass on LSTM and SVDF ops.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"

//===----------------------------------------------------------------------===//
// The prepare-quantize Pass for LSTM.
//
namespace mlir {
namespace TFL {

// 2^15, i.e. -llvm::minIntN(16); used when computing power-of-two scales for
// 16-bit quantization.
constexpr double power_of_two_scale = 32768.0;

// Names of the LSTM intermediate tensor attributes. The ordering matches
// //tensorflow/compiler/mlir/lite/ir/tfl_ops.td.
constexpr const char* intermediate_attributes[] = {
    "input_to_input_intermediate", "input_to_forget_intermediate",
    "input_to_cell_intermediate", "input_to_output_intermediate",
    "effective_hidden_scale_intermediate"};

// Calculates the minimum power of two that is not less than the value.
inline double PowerOfTwoBound(double value) {
  return std::pow(2, std::ceil(std::log2(value)));
}
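// A few illustrative values (assuming a strictly positive input, since log2 of
// a non-positive value is undefined):
//   PowerOfTwoBound(0.7) == 1.0   // 2^0
//   PowerOfTwoBound(5.0) == 8.0   // 2^3
//   PowerOfTwoBound(4.0) == 4.0   // already a power of two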

// Returns the quantized element type of the LSTM intermediate tensor
// designated by `tensor_index`, or nullptr if the index is out of range or the
// intermediate has no type annotation.
template <typename LstmOp>
inline QuantizedType GetIntermediateElementType(LstmOp op, int tensor_index) {
  if (tensor_index < 0 || tensor_index > 4) return nullptr;
  TypeAttr attr = op->template getAttrOfType<TypeAttr>(
      intermediate_attributes[tensor_index]);
  if (!attr) {
    return nullptr;
  }
  return QuantizedType::getQuantizedElementType(attr.getValue());
}
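// Illustrative usage (hypothetical `lstm_op`); index 4 reads the type recorded
// in the "effective_hidden_scale_intermediate" attribute:
//   QuantizedType qtype =
//       GetIntermediateElementType(lstm_op, /*tensor_index=*/4);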

namespace operator_property = ::tflite::optimize::operator_property;
using Q = quantfork::QuantizeCastOp;
using DQ = quantfork::DequantizeCastOp;

template <typename LstmOp>
LogicalResult GetLstmProperty(
    LstmOp op, operator_property::OpVariant* lstm_variant,
    operator_property::OperatorProperty* op_property) {
  if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
    lstm_variant->op_code = tflite::BuiltinOperator_LSTM;
  } else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(op.getOperation())) {
    lstm_variant->op_code =
        tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
  } else {
    op.emitError("GetLstmProperty only supports LSTM and "
                 "UnidirectionalSequenceLSTM ops.");
    return failure();
  }
  lstm_variant->use_projection =
      !op.projection_weights().getType().template isa<NoneType>();
  lstm_variant->use_peephole =
      !op.cell_to_output_weights().getType().template isa<NoneType>();
  lstm_variant->use_layer_norm =
      !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();

  *op_property = operator_property::GetOperatorProperty(*lstm_variant);

  // TODO(b/176258587): move this to operator_property.cc if this is needed in
  // other components, too.
  bool use_cifg =
      op.input_to_input_weights().getType().template isa<NoneType>();
  if (use_cifg) {
    // With coupled input and forget gates (CIFG), the input-gate operands
    // (input_to_input_weights, recurrent_to_input_weights,
    // cell_to_input_weights, input_gate_bias, input_layer_norm_coefficients)
    // and the input_to_input intermediate are absent, so drop them from the
    // operator property.
    const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
    const int cifg_non_intermediate = 0;
    op_property->inputs.erase(
        std::remove_if(
            op_property->inputs.begin(), op_property->inputs.end(),
            [&](std::pair<int, operator_property::TensorProperty> input) {
              return cifg_non_inputs.find(input.first) != cifg_non_inputs.end();
            }),
        op_property->inputs.end());
    op_property->intermediates.erase(
        std::remove_if(op_property->intermediates.begin(),
                       op_property->intermediates.end(),
                       [&](std::pair<int, operator_property::TensorProperty>
                               intermediate) {
                         return intermediate.first == cifg_non_intermediate;
                       }),
        op_property->intermediates.end());
  }
  return success();
}

// Merges the layer statistics of the input/output pair named in the LSTM
// operator property's restrict_scale requirement, so that both tensors end up
// quantized with the same scale.
template <typename SourceOp>
class PrepareLstmOutputScale : public OpRewritePattern<SourceOp> {
 public:
  explicit PrepareLstmOutputScale(MLIRContext* context)
      : OpRewritePattern<SourceOp>(context) {}
  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant lstm_variant;
    operator_property::OperatorProperty lstm_property;

    if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
      return failure();
    }
    if (lstm_property.restrict_scale.size() != 1) {
      op.emitError() << "The LSTM's operator property expects exactly one "
                     << "restrict scale requirement. Got "
                     << lstm_property.restrict_scale.size()
                     << " restrict scale requirements.";
      return failure();
    }

    // Use the same scale for the input/output pair specified in
    // restrict_scale.
    const std::vector<int>& tensors = lstm_property.restrict_scale[0];
    if (tensors.size() != 2) {
      op.emitError(
          "Unexpected restrict_scale from operator property."
          " It should contain exactly one pair of indices.");
      return failure();
    }
    return processRestrictScale(op, tensors[0], tensors[1], rewriter);
  }

 private:
  // LSTM's recurrent input activation and its output are quantized using the
  // collective range of both tensors, because the input activation of the
  // very first inference is not reflected in the output and would otherwise
  // not be captured.
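  // For example (illustrative numbers): if the input quant.stats are
  // [-1.0, 2.0] and the output quant.stats are [-3.0, 1.5], both stats ops are
  // rewritten with the merged layer statistics [-3.0, 2.0].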
  LogicalResult processRestrictScale(SourceOp op, int input_index,
                                     int output_index,
                                     PatternRewriter& rewriter) const {
    assert(output_index == 0);
    if (!op.getResult().hasOneUse()) {
      op.emitError()
          << "output " << output_index
          << " should have only one use, which should be quant.stats.";
      return failure();
    }

    llvm::SmallVector<quantfork::StatisticsOp, 2> stats_ops = {
        llvm::dyn_cast_or_null<quantfork::StatisticsOp>(
            op.getOperand(input_index).getDefiningOp()),
        llvm::dyn_cast_or_null<quantfork::StatisticsOp>(
            *op.getResult().getUsers().begin()),
    };

    if (!stats_ops[0] || !stats_ops[1]) {
      return failure();  // Already converted to Q-DQ pair.
    }

    llvm::SmallVector<llvm::APFloat, 4> min_max_values;

    for (auto& stats_op : stats_ops) {
      auto values = stats_op.getLayerStats()
                        .dyn_cast<DenseFPElementsAttr>()
                        .getValues<llvm::APFloat>();
      min_max_values.insert(min_max_values.end(), values.begin(), values.end());
    }

    // If the min/max values of the two stats ops are already identical, there
    // is nothing to merge.
    if (min_max_values[0] == min_max_values[2] &&
        min_max_values[1] == min_max_values[3]) {
      return failure();
    }

    mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
        mlir::RankedTensorType::get({2}, rewriter.getF32Type()),
        {llvm::minimum(min_max_values[0], min_max_values[2]),
         llvm::maximum(min_max_values[1], min_max_values[3])});
    mlir::ElementsAttr axis_stats;
    mlir::IntegerAttr axis;
    for (auto& stats_op : stats_ops) {
      rewriter.setInsertionPointAfter(stats_op);
      rewriter.replaceOpWithNewOp<quantfork::StatisticsOp>(
          stats_op, stats_op.getArg(), layer_stats, axis_stats, axis);
    }
    return success();
  }
};

// Base pattern that converts a `SourceOp`'s quantifiable operands (constants
// and quant.stats results) into quantize/dequantize (Q-DQ) pairs according to
// the operator property.
template <typename SourceOp>
class ConvertOpStatsToQDQs : public OpRewritePattern<SourceOp> {
 public:
  explicit ConvertOpStatsToQDQs(MLIRContext* context,
                                const quant::QuantizationSpecs& quant_specs,
                                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit),
        quant_specs_(quant_specs) {}

 protected:
  quant::QuantizationSpecs quant_specs_;

  LogicalResult processInputs(
      SourceOp op, const operator_property::OpVariant& op_variant,
      const operator_property::OperatorProperty& op_property,
      PatternRewriter& rewriter) const {
    for (auto& enumerated_inputs : op_property.inputs) {
      int index = enumerated_inputs.first;
      auto& tensor_property = enumerated_inputs.second;

      Value input = op.getOperand(index);

      if (input.getDefiningOp() == nullptr) continue;

      // TODO(b/172517537): make this work with non-PTQ case.
      if (llvm::isa<func::ConstantOp, arith::ConstantOp, TFL::ConstOp>(
              input.getDefiningOp())) {
        // Tensors with a derived scale are biases and are handled during scale
        // propagation.
        if (tensor_property.use_derived_scale) continue;
        // For weights, use the quantization scale inferred from the values.
        if (failed(processConstantOp(op, input.getDefiningOp(), index,
                                     tensor_property, rewriter))) {
          return failure();
        }
      } else {
        if (auto stats_op = llvm::dyn_cast<quantfork::StatisticsOp>(
                input.getDefiningOp())) {
          if (failed(replaceStatsOp(op, stats_op, index, tensor_property,
                                    rewriter))) {
            return failure();
          }
        } else if (!llvm::isa<DQ>(input.getDefiningOp()) &&
                   !llvm::isa<SameScalesOpInterface, FixedOutputRangeInterface>(
                       input.getDefiningOp())) {
          // At this point the stats op has not been converted to a Q-DQ pair,
          // and the defining op neither shares scales with its operands nor
          // has a fixed output range, so quantization parameters cannot be
          // derived for this input yet.
          // TODO(b/172517537): make this work with non-PTQ case.
          return failure();
        }
      }
    }
    return success();
  }

  LogicalResult processConstantOp(
      SourceOp op, Operation* const_op, int input_index,
      const operator_property::TensorProperty& tensor_property,
      PatternRewriter& rewriter) const {
    // Non-float tensors are not weights and do not require quantization.
    auto type = const_op->getResult(0).getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return success();

    DenseFPElementsAttr attr;
    if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
      const_op->emitError("Not a constant op.");
      return failure();
    }

    UniformQuantizedType quant_type = nullptr;
    // When the number of bits is 10 (instead of 16), quantize the tensor to
    // [-512, 512], instead of [-32767, 32767].
    // For now this behavior is specific to SVDF, where 6 bits are reserved for
    // the reduce operation after the element-wise multiplication between the
    // state and time weights.
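    // Worked example (illustrative numbers): if the extracted symmetric
    // maximum is 4.0, the scale becomes 4.0 / 512 = 0.0078125 and the values
    // are stored in an int16 type restricted to the range [-512, 512].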
    if (tensor_property.number_of_bits == 10) {
      SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
      SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
      // Computes the effective min/max values of the attribute values.
      quant::ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1,
                                   /*symmetric=*/true, mins, maxs);
      double scale = maxs[0] / -llvm::minIntN(tensor_property.number_of_bits);
      quant_type = UniformQuantizedType::getChecked(
          const_op->getLoc(), quant::QuantizationFlags::Signed,
          rewriter.getIntegerType(16), attr.getType().getElementType(), scale,
          /*zeroPoint=*/0, llvm::minIntN(10), -llvm::minIntN(10));
    } else {
      quant_type =
          quant::GetUniformQuantizedTypeForWeight(
              attr, /*symmetric=*/true,
              /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true,
              /*narrow_range=*/true, quant_specs_.legacy_float_scale)
              .template dyn_cast<quant::UniformQuantizedType>();
    }
    if (!quant_type) {
      const_op->emitError("Failed to get quantized type");
      return failure();
    }

    // TODO(b/172517537): duplicate the constant when the bias is shared.
    Type expressed_type = const_op->getResult(0).getType();
    Type cast_type = quant_type.castFromExpressedType(expressed_type);
    rewriter.setInsertionPointAfter(const_op);
    auto q = rewriter.create<Q>(const_op->getLoc(), cast_type,
                                const_op->getResult(0));
    auto dq = rewriter.create<DQ>(const_op->getLoc(), expressed_type, q);
    op.setOperand(input_index, dq.getResult());
    return success();
  }

  LogicalResult replaceStatsOp(
      SourceOp op, quantfork::StatisticsOp stats_op, int input_index,
      const operator_property::TensorProperty& tensor_property,
      PatternRewriter& rewriter) const {
    if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) {
      // TODO(b/172517537): check if other tensors should go through this
      // check too.
      op.emitError() << "Input tensor [" << input_index
                     << "] is a state tensor, but has more than one use.";
      return failure();
    }
    auto stats = stats_op.getLayerStats().dyn_cast<DenseFPElementsAttr>();
    if (!stats || stats.getNumElements() != 2) {
      stats_op.emitError("Stats should have 2 values.");
      return failure();
    }
    quant::QuantizedType quant_type;
    double min = FloatAttr::getValueAsDouble(stats.getValues<APFloat>()[0]);
    double max = FloatAttr::getValueAsDouble(stats.getValues<APFloat>()[1]);
    // Make sure the range includes zero.
    min = std::min(min, 0.0);
    max = std::max(max, 0.0);
    Type expressed = getElementTypeOrSelf(stats_op.getType());

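    // Worked example for the power-of-two branch below (illustrative numbers):
    // for stats min = -3.1 and max = 2.5, the bound is PowerOfTwoBound(3.1) =
    // 4.0, so the scale becomes 4.0 / 32768 = 2^-13 over the int16 storage
    // range [-32768, 32767].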
    if (tensor_property.extend_to_power_of_two) {
      if (tensor_property.number_of_bits != 16) {
        op.emitError(
            "extended power of 2 scale is only supported for 16-bit"
            " quantization.");
        return failure();
      }

      double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max)));
      // QuantizationFlags::Signed marks the storage type as signed.
      quant_type = UniformQuantizedType::getChecked(
          op.getLoc(), quant::QuantizationFlags::Signed,
          rewriter.getIntegerType(tensor_property.number_of_bits), expressed,
          /*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits),
          /*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits),
          llvm::maxIntN(tensor_property.number_of_bits));
    } else {
      // int16 uses the narrow range [-32767, 32767].
      if (tensor_property.number_of_bits == 16) {
        max = std::max(std::abs(min), std::abs(max));
        min = -max;
        quant_type = quantfork::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, min, max,
            /*narrowRange=*/true, expressed,
            /*isSigned=*/true);
      } else {
        quant_type = quantfork::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, min, max,
            /*narrowRange=*/false, expressed,
            /*isSigned=*/true);
      }
      if (quant_specs_.legacy_float_scale) {
        quant_type = quant::DownCastScale(quant_type, min, max, op.getLoc());
      }
    }
    rewriter.setInsertionPointAfter(stats_op);
    Type result_type = quant_type.castFromExpressedType(stats_op.getType());
    auto q =
        rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.getArg());
    rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
    return success();
  }
};

// Quantizes an LSTM op according to its quantization recipe.
template <typename SourceOp>
class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs<SourceOp> {
 public:
  ConvertLstmStatsToQDQs(MLIRContext* context,
                         const quant::QuantizationSpecs& quant_specs)
      : ConvertOpStatsToQDQs<SourceOp>(context, quant_specs) {}
  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant lstm_variant;
    operator_property::OperatorProperty lstm_property;
    if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
      return failure();
    }

    if (failed(processIntermediates(op, lstm_variant, lstm_property)) ||
        failed(ConvertOpStatsToQDQs<SourceOp>::processInputs(
            op, lstm_variant, lstm_property, rewriter))) {
      return failure();
    }

    return success();
  }

 private:
  LogicalResult processIntermediates(
      SourceOp op, const operator_property::OpVariant& lstm_variant,
      const operator_property::OperatorProperty& lstm_property) const {
    for (auto& enumerated_intermediates : lstm_property.intermediates) {
      int index = enumerated_intermediates.first;
      auto& tensor_property = enumerated_intermediates.second;
      // Intermediate tensors 0, 1, 2, and 3 are only used with layer
      // normalization.
      if (!lstm_variant.use_layer_norm && index != 4) {
        continue;
      }

      TypeAttr attr =
          op->template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
      auto quant_type = GetIntermediateElementType<SourceOp>(op, index);
      if (!quant_type) {
        // Intermediate tensor 4 is optional, unless the LSTM uses projection.
        if (index == 4 && !lstm_variant.use_projection) {
          return success();
        }
        op.emitError() << intermediate_attributes[index]
                       << " is not quantized.";
        return failure();
      }
      auto calibrated_type =
          quant_type.template dyn_cast<quant::CalibratedQuantizedType>();
      if (!calibrated_type) {
        int num_storage_bits = quant_type.getStorageTypeIntegralWidth();
        if (tensor_property.number_of_bits != num_storage_bits) {
          op.emitError() << intermediate_attributes[index]
                         << " is expected to be quantized with "
                         << tensor_property.number_of_bits << " bits, but got "
                         << num_storage_bits << " bits instead.";
          return failure();
        }
        continue;  // Skip if it is already quantized.
      }
      quant::UniformQuantizedType qtype;
      if (tensor_property.number_of_bits == 8) {
        qtype = quantfork::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits,
            calibrated_type.getMin(), calibrated_type.getMax(),
            /*narrowRange=*/false, calibrated_type.getExpressedType(),
            /*isSigned=*/this->quant_specs_.IsSignedInferenceType());
        if (this->quant_specs_.legacy_float_scale) {
          qtype = quant::DownCastScale(qtype, calibrated_type.getMin(),
                                       calibrated_type.getMax(), op.getLoc())
                      .template cast<UniformQuantizedType>();
        }
      } else if (tensor_property.number_of_bits == 16) {
        double max = std::max(std::abs(calibrated_type.getMin()),
                              std::abs(calibrated_type.getMax()));
        qtype = quantfork::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, -max, max,
            /*narrowRange=*/true, calibrated_type.getExpressedType(),
            /*isSigned=*/true);
      } else {
        op.emitError() << "Unsupported quantization bits: "
                       << tensor_property.number_of_bits;
        return failure();
      }
      op->setAttr(intermediate_attributes[index],
                  TypeAttr::get(qtype.castFromExpressedType(
                      qtype.castToExpressedType(attr.getValue()))));
    }
    return success();
  }
};

// Returns a function that computes the quantized type for a bias input. The
// bias scale is the product of the given `scale` and the scales taken from the
// quantization types of the other operands.
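// For example (illustrative numbers): if the other operands contribute a
// combined scale of 0.5 * 0.25 and `scale` is 2.0, the bias is assigned a
// uniform quantized type whose scale is 0.5 * 0.25 * 2.0 = 0.25.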
inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale(
    double scale) {
  return [=](const std::vector<quant::QuantParams>& quant_params,
             bool legacy_float_scale) -> quant::QuantParams {
    if (auto qtype = quant::GetUniformQuantizedTypeForBias(quant_params,
                                                           legacy_float_scale)
                         .dyn_cast_or_null<UniformQuantizedType>()) {
      return quant::UniformQuantizedType::get(
          qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
          qtype.getScale() * scale, qtype.getZeroPoint(),
          qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
    }
    return {};
  };
}

// Returns the quantization spec for an LSTM op based on its operator property.
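// For derived-scale biases, the effective scale is the product of the scales
// of the listed intermediate tensors and the listed constant factors; e.g.
// (illustrative numbers) intermediate scales {2^-12, 2^-7} with factor 1.0
// yield a scale of 2^-19, which is then combined with the scales of the listed
// input tensors via GetUniformQuantizedTypeForBiasWithScale.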
template <typename LstmOp>
std::unique_ptr<quant::OpQuantSpec> GetLstmOpQuantSpec(LstmOp op) {
  operator_property::OpVariant lstm_variant;
  operator_property::OperatorProperty lstm_property;
  if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
    return nullptr;
  }

  auto spec = std::make_unique<quant::OpQuantSpec>();

  for (const auto& enumerated_inputs : lstm_property.inputs) {
    int index = enumerated_inputs.first;
    auto& tensor_property = enumerated_inputs.second;
    if (tensor_property.use_derived_scale) {
      double scale = 1.0;
      for (int tensor_index :
           tensor_property.derived_scale.intermediate_tensors) {
        auto quant_type = GetIntermediateElementType<LstmOp>(op, tensor_index);
        if (!quant_type ||
            !quant_type.template isa<quant::UniformQuantizedType>()) {
          op->emitError() << "While processing derived scale, intermediate "
                          << intermediate_attributes[tensor_index]
                          << " is not quantized.";
          return nullptr;
        }
        scale *= quant_type.template dyn_cast<quant::UniformQuantizedType>()
                     .getScale();
      }
      for (float factor : tensor_property.derived_scale.factors) {
        scale *= factor;
      }
      spec->biases_params.emplace(
          index,
          std::make_pair(tensor_property.derived_scale.input_tensors,
                         GetUniformQuantizedTypeForBiasWithScale(scale)));
    }
  }
  return spec;
}

// Converts the operands of an SVDF op into Q-DQ pairs according to the SVDF
// operator property.
class ConvertSvdfStatsToQDQs : public ConvertOpStatsToQDQs<TFL::SVDFOp> {
 public:
  explicit ConvertSvdfStatsToQDQs(
      MLIRContext* context, const quant::QuantizationSpecs& quant_specs_param)
      : ConvertOpStatsToQDQs<TFL::SVDFOp>(context, quant_specs_param) {}
  LogicalResult matchAndRewrite(TFL::SVDFOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant op_variant;
    op_variant.op_code = tflite::BuiltinOperator_SVDF;
    auto op_property = operator_property::GetOperatorProperty(op_variant);
    return ConvertOpStatsToQDQs<TFL::SVDFOp>::processInputs(
        op, op_variant, op_property, rewriter);
  }
};
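// Illustrative wiring in a prepare-quantize style pass (a sketch; `ctx`,
// `func`, and `quant_specs` are assumed to be provided by the enclosing pass
// and are not defined in this header):
//   RewritePatternSet patterns(ctx);
//   patterns.add<PrepareLstmOutputScale<TFL::UnidirectionalSequenceLSTMOp>>(
//       ctx);
//   patterns.add<ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp>>(
//       ctx, quant_specs);
//   patterns.add<ConvertSvdfStatsToQDQs>(ctx, quant_specs);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));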

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER