/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Transform pass for LSTMs.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"

//===----------------------------------------------------------------------===//
// The prepare-quantize Pass for LSTM.
//
namespace mlir {
namespace TFL {

constexpr double power_of_two_scale = 32768.0;
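// Note: 32768 = 2^15, i.e. the magnitude of the int16 storage minimum.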

// Same as the ordering in //tensorflow/compiler/mlir/lite/ir/tfl_ops.td.
constexpr const char* intermediate_attributes[] = {
    "input_to_input_intermediate", "input_to_forget_intermediate",
    "input_to_cell_intermediate", "input_to_output_intermediate",
    "effective_hidden_scale_intermediate"};

// Calculates the minimum power of two that is not less than the value.
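// For example, PowerOfTwoBound(5.3) == 8.0 and PowerOfTwoBound(0.25) == 0.25.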
inline double PowerOfTwoBound(double value) {
  return std::pow(2, std::ceil(std::log2(value)));
}

// Returns the element type of LSTM's intermediate tensor designated by the
// index.
template <typename LstmOp>
inline QuantizedType GetIntermediateElementType(LstmOp op, int tensor_index) {
  if (tensor_index < 0 || tensor_index > 4) return nullptr;
  TypeAttr attr = op->template getAttrOfType<TypeAttr>(
      intermediate_attributes[tensor_index]);
  if (!attr) {
    return nullptr;
  }
  return QuantizedType::getQuantizedElementType(attr.getValue());
}

namespace operator_property = ::tflite::optimize::operator_property;
using Q = quantfork::QuantizeCastOp;
using DQ = quantfork::DequantizeCastOp;

template <typename LstmOp>
LogicalResult GetLstmProperty(
    LstmOp op, operator_property::OpVariant* lstm_variant,
    operator_property::OperatorProperty* op_property) {
  if (llvm::isa<TFL::LSTMOp>(op.getOperation())) {
    lstm_variant->op_code = tflite::BuiltinOperator_LSTM;
  } else if (llvm::isa<TFL::UnidirectionalSequenceLSTMOp>(op.getOperation())) {
    lstm_variant->op_code =
        tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
  } else {
    op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs.");
    return failure();
  }
  lstm_variant->use_projection =
      !op.projection_weights().getType().template isa<NoneType>();
  lstm_variant->use_peephole =
      !op.cell_to_output_weights().getType().template isa<NoneType>();
  lstm_variant->use_layer_norm =
      !op.forget_layer_norm_coefficients().getType().template isa<NoneType>();

  *op_property = operator_property::GetOperatorProperty(*lstm_variant);

  // TODO(b/176258587) move this to operator_property.cc if this is needed in
  // other components, too.
  bool use_cifg =
      op.input_to_input_weights().getType().template isa<NoneType>();
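  // With CIFG (coupled input and forget gate), the input-gate tensors are
  // absent from the op, so their entries are dropped from the quantization
  // properties below.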
  if (use_cifg) {
    const absl::flat_hash_set<int> cifg_non_inputs = {1, 5, 9, 12, 20};
    const int cifg_non_intermediate = 0;
    op_property->inputs.erase(
        std::remove_if(
            op_property->inputs.begin(), op_property->inputs.end(),
            [&](std::pair<int, operator_property::TensorProperty> input) {
              return cifg_non_inputs.find(input.first) != cifg_non_inputs.end();
            }),
        op_property->inputs.end());
    op_property->intermediates.erase(
        std::remove_if(op_property->intermediates.begin(),
                       op_property->intermediates.end(),
                       [&](std::pair<int, operator_property::TensorProperty>
                               intermediate) {
                         return intermediate.first == cifg_non_intermediate;
                       }),
        op_property->intermediates.end());
  }
  return success();
}

template <typename SourceOp>
class PrepareLstmOutputScale : public OpRewritePattern<SourceOp> {
 public:
  explicit PrepareLstmOutputScale(MLIRContext* context)
      : OpRewritePattern<SourceOp>(context) {}
  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant lstm_variant;
    operator_property::OperatorProperty lstm_property;

    if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
      return failure();
    }
    if (lstm_property.restrict_scale.size() != 1) {
      op.emitError() << "The LSTM's operator property expects exactly one "
                     << "restrict scale requirement. Got "
                     << lstm_property.restrict_scale.size()
                     << " restrict scale requirements.";
      return failure();
    }

    // Use same scale for input and output specified in restrict_scale.
    const std::vector<int>& tensors = lstm_property.restrict_scale[0];
    if (tensors.size() != 2) {
      op.emitError(
          "Unexpected restrict_scale from operator property."
          " Should only have a pair of indices.");
      return failure();
    }
    return processRestrictScale(op, tensors[0], tensors[1], rewriter);
  }

 private:
  // The LSTM's recurrent input activation and its output are quantized with
  // the collective range of both tensors, because the input activation of the
  // very first inference is not reflected in the output and would otherwise
  // not be captured.
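  // A sketch of the rewrite (operand/result indices and tensor types
  // simplified):
  //   %in  = "quantfork.stats"(%x)   {layerStats = dense<[-1.0, 2.0]> : ...}
  //   %out = "tfl.lstm"(..., %in, ...)
  //   %st  = "quantfork.stats"(%out) {layerStats = dense<[-3.0, 1.5]> : ...}
  // Both stats ops are replaced so that each carries the merged range
  // layerStats = [-3.0, 2.0].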
  LogicalResult processRestrictScale(SourceOp op, int input_index,
                                     int output_index,
                                     PatternRewriter& rewriter) const {
    assert(output_index == 0);
    if (!op.getResult().hasOneUse()) {
      op.emitError()
          << "output " << output_index
          << " should have only one use, which should be quant.stats.";
      return failure();
    }

    llvm::SmallVector<quantfork::StatisticsOp, 2> stats_ops = {
        llvm::dyn_cast_or_null<quantfork::StatisticsOp>(
            op.getOperand(input_index).getDefiningOp()),
        llvm::dyn_cast_or_null<quantfork::StatisticsOp>(
            *op.getResult().getUsers().begin()),
    };

    if (!stats_ops[0] || !stats_ops[1]) {
      return failure();  // Already converted to Q-DQ pair.
    }

    llvm::SmallVector<llvm::APFloat, 4> min_max_values;

    for (auto& stats_op : stats_ops) {
      auto values = stats_op.getLayerStats()
                        .dyn_cast<DenseFPElementsAttr>()
                        .getValues<llvm::APFloat>();
      min_max_values.insert(min_max_values.end(), values.begin(), values.end());
    }

    // min and max values of two stats are already the same.
    if (min_max_values[0] == min_max_values[2] &&
        min_max_values[1] == min_max_values[3]) {
      return failure();
    }

    mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get(
        mlir::RankedTensorType::get({2}, rewriter.getF32Type()),
        {llvm::minimum(min_max_values[0], min_max_values[2]),
         llvm::maximum(min_max_values[1], min_max_values[3])});
    mlir::ElementsAttr axis_stats;
    mlir::IntegerAttr axis;
    for (auto& stats_op : stats_ops) {
      rewriter.setInsertionPointAfter(stats_op);
      rewriter.replaceOpWithNewOp<quantfork::StatisticsOp>(
          stats_op, stats_op.getArg(), layer_stats, axis_stats, axis);
    }
    return success();
  }
};

template <typename SourceOp>
class ConvertOpStatsToQDQs : public OpRewritePattern<SourceOp> {
 public:
  explicit ConvertOpStatsToQDQs(MLIRContext* context,
                                const quant::QuantizationSpecs& quant_specs,
                                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit),
        quant_specs_(quant_specs) {}

 protected:
  quant::QuantizationSpecs quant_specs_;

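  // Inserts Q-DQ pairs on the op's operands according to the per-tensor
  // properties from the quantization recipe: constant (weight) operands are
  // quantized from their values, and quant.stats ops on activations are
  // rewritten into Q-DQ pairs.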
  LogicalResult processInputs(
      SourceOp op, const operator_property::OpVariant& op_variant,
      const operator_property::OperatorProperty& op_property,
      PatternRewriter& rewriter) const {
    for (auto& enumerated_inputs : op_property.inputs) {
      int index = enumerated_inputs.first;
      auto& tensor_property = enumerated_inputs.second;

      Value input = op.getOperand(index);

      if (input.getDefiningOp() == nullptr) continue;

      // TODO(b/172517537): make this work with non-PTQ case.
      if (llvm::isa<func::ConstantOp, arith::ConstantOp, TFL::ConstOp>(
              input.getDefiningOp())) {
        // Tensors with derived scale are biases, and handled in propagation.
        if (tensor_property.use_derived_scale) continue;
        // For weights, use quantization scale inferred from the values.
        if (failed(processConstantOp(op, input.getDefiningOp(), index,
                                     tensor_property, rewriter))) {
          return failure();
        }
      } else {
        if (auto stats_op = llvm::dyn_cast<quantfork::StatisticsOp>(
                input.getDefiningOp())) {
          if (failed(replaceStatsOp(op, stats_op, index, tensor_property,
                                    rewriter))) {
            return failure();
          }
        } else if (!llvm::isa<DQ>(input.getDefiningOp()) &&
                   !llvm::isa<SameScalesOpInterface, FixedOutputRangeInterface>(
                       input.getDefiningOp())) {
          // Continue if StatisticsOp is already converted to Q-DQ pair, or
          // stats op is not immediately available to the input because either
          // it's connected to ops with same scale requirements or it has
          // fixed output range.
          // TODO(b/172517537): make this work with non-PTQ case.
          return failure();
        }
      }
    }
    return success();
  }

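  // Quantizes a constant (weight) operand: a quantized type is inferred from
  // the constant's values and a Q-DQ pair is inserted between the constant
  // and the op.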
  LogicalResult processConstantOp(
      SourceOp op, Operation* const_op, int input_index,
      const operator_property::TensorProperty& tensor_property,
      PatternRewriter& rewriter) const {
    // Non-float tensors are not weights and do not require quantization.
    auto type = const_op->getResult(0).getType().dyn_cast<ShapedType>();
    if (!type || !type.getElementType().isa<FloatType>()) return success();

    DenseFPElementsAttr attr;
    if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) {
      const_op->emitError("Not a constant op.");
      return failure();
    }

    UniformQuantizedType quant_type = nullptr;
    // When the number of bits is 10 (instead of 16), quantize the tensor to
    // [-512, 512], instead of [-32767, 32767].
    // For now this behavior is specific for SVDF, where 6 bits are reserved for
    // the reduce operation after element-wise multiplication between state and
    // time weights.
    if (tensor_property.number_of_bits == 10) {
      SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
      SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
      // Computes the effective min/max values of the attribute values.
      quant::ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1,
                                   /*symmetric=*/true, mins, maxs);
      double scale = maxs[0] / -llvm::minIntN(tensor_property.number_of_bits);
      quant_type = UniformQuantizedType::getChecked(
          const_op->getLoc(), quant::QuantizationFlags::Signed,
          rewriter.getIntegerType(16), attr.getType().getElementType(), scale,
          /*zeroPoint=*/0, llvm::minIntN(10), -llvm::minIntN(10));
    } else {
      quant_type =
          quant::GetUniformQuantizedTypeForWeight(
              attr, /*symmetric=*/true,
              /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true,
              /*narrow_range=*/true, quant_specs_.legacy_float_scale)
              .template dyn_cast<quant::UniformQuantizedType>();
    }
    if (!quant_type) {
      const_op->emitError("Failed to get quantized type");
      return failure();
    }

    // TODO(b/172517537): duplicate the constant when the bias is shared.
    Type expressed_type = const_op->getResult(0).getType();
    Type cast_type = quant_type.castFromExpressedType(expressed_type);
    rewriter.setInsertionPointAfter(const_op);
    auto q = rewriter.create<Q>(const_op->getLoc(), cast_type,
                                const_op->getResult(0));
    auto dq = rewriter.create<DQ>(const_op->getLoc(), expressed_type, q);
    op.setOperand(input_index, dq.getResult());
    return success();
  }

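  // Replaces the quant.stats op feeding operand `input_index` with a Q-DQ
  // pair whose quantized type is derived from the recorded min/max range and
  // the tensor property (number of bits, power-of-two extension, etc.).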
  LogicalResult replaceStatsOp(
      SourceOp op, quantfork::StatisticsOp stats_op, int input_index,
      const operator_property::TensorProperty& tensor_property,
      PatternRewriter& rewriter) const {
    if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) {
      // TODO(b/172517537): check if other tensors should go through this
      // check too.
      op.emitError() << "Input tensor [" << input_index
                     << "] is a state tensor, but has more than one use.";
      return failure();
    }
    auto stats = stats_op.getLayerStats().dyn_cast<DenseFPElementsAttr>();
    if (!stats || stats.getNumElements() != 2) {
      stats_op.emitError("Stats should have 2 values.");
      return failure();
    }
    quant::QuantizedType quant_type;
    double min = FloatAttr::getValueAsDouble(stats.getValues<APFloat>()[0]);
    double max = FloatAttr::getValueAsDouble(stats.getValues<APFloat>()[1]);
    // Make sure the range includes zero.
    min = std::min(min, 0.0);
    max = std::max(max, 0.0);
    Type expressed = getElementTypeOrSelf(stats_op.getType());

    if (tensor_property.extend_to_power_of_two) {
      if (tensor_property.number_of_bits != 16) {
        op.emitError(
            "extended power of 2 scale is only supported for 16-bit"
            " quantization.");
        return failure();
      }

      double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max)));
      // Set flags to 1 for signed type.
      quant_type = UniformQuantizedType::getChecked(
          op.getLoc(), quant::QuantizationFlags::Signed,
          rewriter.getIntegerType(tensor_property.number_of_bits), expressed,
          /*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits),
          /*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits),
          llvm::maxIntN(tensor_property.number_of_bits));
    } else {
      // int16 uses range [-32767, 32767]
      if (tensor_property.number_of_bits == 16) {
        max = std::max(std::abs(min), std::abs(max));
        min = -max;
        quant_type = quantfork::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, min, max,
            /*narrowRange=*/true, expressed,
            /*isSigned=*/true);
      } else {
        quant_type = quantfork::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, min, max,
            /*narrowRange=*/false, expressed,
            /*isSigned=*/true);
      }
      if (quant_specs_.legacy_float_scale) {
        quant_type = quant::DownCastScale(quant_type, min, max, op.getLoc());
      }
    }
    rewriter.setInsertionPointAfter(stats_op);
    Type result_type = quant_type.castFromExpressedType(stats_op.getType());
    auto q =
        rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.getArg());
    rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
    return success();
  }
};

// Quantize LSTM according to its quantization recipe.
template <typename SourceOp>
class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs<SourceOp> {
 public:
  ConvertLstmStatsToQDQs(MLIRContext* context,
                         const quant::QuantizationSpecs& quant_specs)

      : ConvertOpStatsToQDQs<SourceOp>(context, quant_specs) {}
  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant lstm_variant;
    operator_property::OperatorProperty lstm_property;
    if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
      return failure();
    }

    if (failed(processIntermediates(op, lstm_variant, lstm_property)) ||
        failed(ConvertOpStatsToQDQs<SourceOp>::processInputs(
            op, lstm_variant, lstm_property, rewriter))) {
      return failure();
    }

    return success();
  }

 private:
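  // Converts the calibrated (min/max) types stored in the LSTM's intermediate
  // attributes into concrete uniform quantized types (8-bit or symmetric
  // 16-bit, depending on the tensor property).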
  LogicalResult processIntermediates(
      SourceOp op, const operator_property::OpVariant& lstm_variant,
      const operator_property::OperatorProperty& lstm_property) const {
    for (auto& enumerated_intermediates : lstm_property.intermediates) {
      int index = enumerated_intermediates.first;
      auto& tensor_property = enumerated_intermediates.second;
      // intermediate tensors 0, 1, 2, 3 are only used with layer
      // normalization.
      if (!lstm_variant.use_layer_norm && index != 4) {
        continue;
      }

      TypeAttr attr =
          op->template getAttrOfType<TypeAttr>(intermediate_attributes[index]);
      auto quant_type = GetIntermediateElementType<SourceOp>(op, index);
      if (!quant_type) {
        // intermediate tensor 4 is optional, unless the LSTM uses projection.
        if (index == 4 && !lstm_variant.use_projection) {
          return success();
        }
        op.emitError() << intermediate_attributes[index]
                       << " is not quantized.";
        return failure();
      }
      auto calibrated_type =
          quant_type.template dyn_cast<quant::CalibratedQuantizedType>();
      if (!calibrated_type) {
        int num_storage_bits = quant_type.getStorageTypeIntegralWidth();
        if (tensor_property.number_of_bits != num_storage_bits) {
          op.emitError() << intermediate_attributes[index]
                         << " is expected to be quantized with "
                         << tensor_property.number_of_bits << " bits, but got "
                         << num_storage_bits << " bits instead.";
          return failure();
        }
        continue;  // skip if it is already quantized.
      }
      quant::UniformQuantizedType qtype;
      if (tensor_property.number_of_bits == 8) {
        qtype = quantfork::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits,
            calibrated_type.getMin(), calibrated_type.getMax(),
            /*narrowRange=*/false, calibrated_type.getExpressedType(),
            /*isSigned=*/this->quant_specs_.IsSignedInferenceType());
        if (this->quant_specs_.legacy_float_scale) {
          qtype = quant::DownCastScale(qtype, calibrated_type.getMin(),
                                       calibrated_type.getMax(), op.getLoc())
                      .template cast<UniformQuantizedType>();
        }
      } else if (tensor_property.number_of_bits == 16) {
        double max = std::max(std::abs(calibrated_type.getMin()),
                              std::abs(calibrated_type.getMax()));
        qtype = quantfork::fakeQuantAttrsToType(
            op.getLoc(), tensor_property.number_of_bits, -max, max,
            /*narrowRange=*/true, calibrated_type.getExpressedType(),
            /*isSigned=*/true);
      } else {
        op.emitError() << "Unsupported quantization bits: "
                       << tensor_property.number_of_bits;
        return failure();
      }
      op->setAttr(intermediate_attributes[index],
                  TypeAttr::get(qtype.castFromExpressedType(
                      qtype.castToExpressedType(attr.getValue()))));
    }
    return success();
  }
};

// Returns a function that returns the quantized type of a bias input.
// The bias scale is the product of the given scale and the scales from the
// quantization types of the other operands.
inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale(
    double scale) {
  return [=](const std::vector<quant::QuantParams>& quant_params,
             bool legacy_float_scale) -> quant::QuantParams {
    if (auto qtype = quant::GetUniformQuantizedTypeForBias(quant_params,
                                                           legacy_float_scale)
                         .dyn_cast_or_null<UniformQuantizedType>()) {
      return quant::UniformQuantizedType::get(
          qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
          qtype.getScale() * scale, qtype.getZeroPoint(),
          qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
    }
    return {};
  };
}

// Returns quantization spec for LSTMs based on their operator properties.
template <typename LstmOp>
std::unique_ptr<quant::OpQuantSpec> GetLstmOpQuantSpec(LstmOp op) {
  operator_property::OpVariant lstm_variant;
  operator_property::OperatorProperty lstm_property;
  if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) {
    return nullptr;
  }

  auto spec = std::make_unique<quant::OpQuantSpec>();

  for (const auto& enumerated_inputs : lstm_property.inputs) {
    int index = enumerated_inputs.first;
    auto& tensor_property = enumerated_inputs.second;
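    // A derived bias scale is not calibrated: it is the product of the scales
    // of the listed intermediate tensors (computed below), the scales of the
    // input tensors (applied by the returned AccumulatorScaleFunc), and the
    // constant factors from the recipe.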
    if (tensor_property.use_derived_scale) {
      double scale = 1.0;
      for (int tensor_index :
           tensor_property.derived_scale.intermediate_tensors) {
        auto quant_type = GetIntermediateElementType<LstmOp>(op, tensor_index);
        if (!quant_type ||
            !quant_type.template isa<quant::UniformQuantizedType>()) {
          op->emitError() << "While processing derived scale, intermediate "
                          << intermediate_attributes[tensor_index]
                          << " is not quantized.";
          return nullptr;
        }
        scale *= quant_type.template dyn_cast<quant::UniformQuantizedType>()
                     .getScale();
      }
      for (float factor : tensor_property.derived_scale.factors) {
        scale *= factor;
      }
      spec->biases_params.emplace(
          index,
          std::make_pair(tensor_property.derived_scale.input_tensors,
                         GetUniformQuantizedTypeForBiasWithScale(scale)));
    }
  }
  return spec;
}

class ConvertSvdfStatsToQDQs : public ConvertOpStatsToQDQs<TFL::SVDFOp> {
 public:
  explicit ConvertSvdfStatsToQDQs(
      MLIRContext* context, const quant::QuantizationSpecs& quant_specs_param)
      : ConvertOpStatsToQDQs<TFL::SVDFOp>(context, quant_specs_param) {}
  LogicalResult matchAndRewrite(TFL::SVDFOp op,
                                PatternRewriter& rewriter) const override {
    operator_property::OpVariant op_variant;
    op_variant.op_code = tflite::BuiltinOperator_SVDF;
    auto op_property = operator_property::GetOperatorProperty(op_variant);
    return ConvertOpStatsToQDQs<TFL::SVDFOp>::processInputs(
        op, op_variant, op_property, rewriter);
  }
};

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER