xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <iterator>
21 #include <limits>
22 #include <memory>
23 #include <numeric>
24 #include <string>
25 
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Casting.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
30 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
34 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
35 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
36 #include "mlir/Support/LLVM.h"  // from @llvm-project
37 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
39 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
40 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h"
41 #include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h"
42 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
43 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
44 #include "tensorflow/lite/tools/optimize/quantization_utils.h"
45 
46 namespace mlir {
47 
48 // This includes the interface class definition. It couldn't be in a namespace
49 // because TableGen doesn't emit the namespace when it is used.
50 #include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.cc.inc"
51 
52 namespace quant {
53 
54 namespace {
55 constexpr double kSmallestHalfRange = kNearZeroTolerance / 2;
56 using QType = quant::QuantizedType;
57 
58 // This method expands the range to be larger than or equal to 1.0e-6, if it
59 // is very small (< 1.0e-6). This is to prevent producing very large quantized
60 // values from such a narrow range.
61 void ExpandVerySmallRange(ArrayRef<double> mins, ArrayRef<double> maxs,
62                           SmallVectorImpl<double>* effective_mins,
63                           SmallVectorImpl<double>* effective_maxs) {
64   for (auto arg : llvm::zip(mins, maxs)) {
65     double min = std::get<0>(arg);
66     double max = std::get<1>(arg);
67     // The range is wide enough, so keep the original min/max.
68     if ((max - min) > kNearZeroTolerance) {
69       effective_mins->push_back(min);
70       effective_maxs->push_back(max);
71       continue;
72     }
73 
74     // The range is small. Expand the range to straddle 0.0 and span at least
75     // 1.0e-6.
76     effective_mins->push_back(std::min(min, -kSmallestHalfRange));
77     effective_maxs->push_back(std::max(max, kSmallestHalfRange));
78   }
79 }
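// Illustrative behavior (the exact bounds depend on kNearZeroTolerance from
// the header): a near-zero range such as [-1.0e-9, 1.0e-9] is widened to at
// least [-kSmallestHalfRange, kSmallestHalfRange], while a wide range such as
// [0.0, 6.0] is passed through unchanged.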
80 
81 // Resets the min/max, scale, and zero points based on the fake-quant num_bits
82 // attribute coming from QAT.
83 QuantizedType ResetMinMaxFromNumBits(QuantizedType type, int num_bits,
84                                      bool narrow_range, bool is_signed) {
85   if (num_bits >= 8) {
86     return type;
87   }
88   int64_t qmin = QType::getDefaultMinimumForInteger(is_signed, num_bits);
89   int64_t qmax = QType::getDefaultMaximumForInteger(is_signed, num_bits);
90   if (narrow_range) {
91     qmin += 1;
92   }
93   const int64_t storage_type_min = type.getStorageTypeMin();
94   const int64_t storage_type_max = type.getStorageTypeMax();
95   const double rate =
96       static_cast<double>(storage_type_max - storage_type_min) / (qmax - qmin);
97   const auto& recalculate_scale = [&](double scale) -> double {
98     return scale * rate;
99   };
100   const auto& recalculate_zero_point = [&](int64_t zero_point) -> int64_t {
101     return qmax - std::round((storage_type_max - zero_point) / rate);
102   };
103   if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
104     const double scale = recalculate_scale(q_type.getScale());
105     const double zero_point = recalculate_zero_point(q_type.getZeroPoint());
106     return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(),
107                                      q_type.getExpressedType(), scale,
108                                      zero_point, qmin, qmax);
109   } else if (auto q_type =
110                  type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
111     const int size = q_type.getScales().size();
112     SmallVector<double, 4> scales(size);
113     SmallVector<int64_t, 4> zero_points(size);
114     for (int i = 0; i < size; ++i) {
115       scales[i] = recalculate_scale(q_type.getScales()[i]);
116       zero_points[i] = recalculate_zero_point(q_type.getZeroPoints()[i]);
117     }
118     return quant::UniformQuantizedPerAxisType::get(
119         q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(),
120         scales, zero_points, q_type.getQuantizedDimension(), qmin, qmax);
121   } else {
122     llvm_unreachable("Unsupported QuantizedType in ResetMinMaxFromNumBits");
123   }
124   return type;
125 }
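// For illustration, with a hypothetical unsigned 8-bit type (storage range
// [0, 255]) reset to num_bits = 4 without narrow_range, qmin = 0 and
// qmax = 15, so rate = 255 / 15 = 17: each scale is multiplied by 17 and the
// zero points are remapped into [0, 15] accordingly.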
126 
127 // Repeats the content of `data` multiple times to resize to `target_size`.
128 // Note that this only broadcasts across one dimension.
129 template <typename T>
130 bool BroadcastVector(int target_size, SmallVectorImpl<T>& data) {
131   int size = data.size();
132   if (size != target_size) {
133     if (target_size % size != 0) return true;
134     data.reserve(target_size);
135     for (int i = 1, e = target_size / size; i != e; ++i) {
136       data.insert(data.end(), data.begin(), data.begin() + size);
137     }
138   }
139   return false;
140 }
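// For example, broadcasting {1.0, 2.0} to a target_size of 6 yields
// {1.0, 2.0, 1.0, 2.0, 1.0, 2.0}, while a target_size of 5 fails (returns
// true) because 5 is not a multiple of the current size.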
141 
142 // Changes the axis of the input per-channel quantized type to match the
143 // dimension of the target type. Returns nullptr if it fails.
144 quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
145     ArrayRef<int64_t> shape, quant::UniformQuantizedPerAxisType qtype,
146     Type target, int quant_dim) {
147   auto shaped = target.dyn_cast<RankedTensorType>();
148   if (!shaped) return {};
149   ArrayRef<int64_t> new_shape = shaped.getShape();
150 
151   SmallVector<double, 4> scales(qtype.getScales().begin(),
152                                 qtype.getScales().end());
153   SmallVector<int64_t, 4> zero_points(qtype.getZeroPoints().begin(),
154                                       qtype.getZeroPoints().end());
155 
156   if (new_shape.size() == shape.size()) {  // same rank
157     // Broadcast the scales and zero points to match the target size, which is
158     // usually the axis-th dimension of the target type. Currently, it covers
159     // two cases:
160     // - for Transpose, the data layout is changed, so `dim[axis]` still
161     // equals `scales_size` and the broadcast is skipped;
162     // - for Reshape, the data layout isn't changed but the innermost dimension
163     // is expanded to cover the last two original dimensions. Thus we just need
164     // to repeat the `scales` dim[2] times to cover the new dimension length.
165     //
166     // TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we
167     // have to repeat each element in `scales` locally dim[3] times.
168     if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
169         BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
170       return {};
171     }
172   } else if ((new_shape.size() == shape.size() + 1) && new_shape.front() == 1) {
173     // Handle the [A, B, C] -> [1, A, B, C] reshape case.
174     if (!(std::equal(shape.begin(), shape.end(), new_shape.begin() + 1) &&
175           quant_dim == new_shape.size() - 1)) {
176       return {};
177     }
178   } else {
179     return {};
180   }
181 
182   return quant::UniformQuantizedPerAxisType::get(
183       qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
184       scales, zero_points, quant_dim, qtype.getStorageTypeMin(),
185       qtype.getStorageTypeMax());
186 }
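// For illustration, reshaping a per-axis quantized [A, B, C] tensor to
// [1, A, B, C] keeps the scales as-is when quant_dim is the last dimension of
// the new shape; any other rank change is rejected by returning a null type.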
187 
188 }  // namespace
189 
190 bool IsOpNotQuantizable(Operation* op) {
191   // If it is a terminator, not quantizable, or any op from the MLIR quant
192   // ops dialect, we shouldn't rewrite.
193   bool attr_enforced_quantizable =
194       op->hasAttrOfType<StringAttr>(kQuantTraitAttrName) &&
195       op->getAttrOfType<StringAttr>(kQuantTraitAttrName).getValue().str() ==
196           QuantTraitValues[QuantizationTrait::FullyQuantizable];
197 
198   // Constant ops do not have the QuantizableResult attribute but they can
199   // deal with quantized tensors.
200   if (llvm::isa<func::ConstantOp, arith::ConstantOp, quantfork::StatisticsOp>(
201           op))
202     return false;
203 
204   bool prop_enforced_quantizable =
205       op->hasTrait<OpTrait::quant::QuantizableResult>();
206 
207   return op->hasTrait<OpTrait::IsTerminator>() ||
208          llvm::isa<quantfork::QuantizeCastOp, quantfork::DequantizeCastOp>(
209              op) ||
210          (!attr_enforced_quantizable && !prop_enforced_quantizable);
211 }
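// For example, an arith::ConstantOp is reported as quantizable (returns
// false), whereas a terminator such as func.return, or an op without the
// quantization trait/attribute, returns true.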
212 
213 // Returns the quantized type for the
214 // input_type/min/max/storage_type_width/narrow_range.
215 // This is the entry point to the Quant dialect and is used for quantizing
216 // both activations and weights.
217 Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
218                       ArrayRef<double> max, int quant_dim,
219                       int storage_type_width, bool narrow_range, bool is_signed,
220                       bool legacy_float_scale, bool use_fake_quant_num_bits) {
221   auto converter =
222       quantfork::ExpressedToQuantizedConverter::forInputType(input_type);
223 
224   // Expand the range to prevent extremely small scales and large quantized
225   // integers, which can cause overflow. The smallest expanded range leads to a
226   // scale of 7.843137254901961e-9 with 8 bits.
227   SmallVector<double, 4> effective_mins, effective_maxs;
228   ExpandVerySmallRange(min, max, &effective_mins, &effective_maxs);
229 
230   quant::QuantizedType quantizedEleType;
231   if (min.size() == 1 && max.size() == 1 && quant_dim == -1) {
232     quantizedEleType = quantfork::fakeQuantAttrsToType(
233         builder.getUnknownLoc(), storage_type_width, effective_mins[0],
234         effective_maxs[0], narrow_range, converter.expressedType, is_signed);
235     if (legacy_float_scale) {
236       quantizedEleType =
237           DownCastScale(quantizedEleType, effective_mins[0], effective_maxs[0],
238                         builder.getUnknownLoc());
239     }
240   } else if (min.size() == max.size()) {
241     auto shape = input_type.dyn_cast<ShapedType>();
242     if (!shape || shape.getRank() <= quant_dim ||
243         static_cast<int64_t>(min.size()) != shape.getDimSize(quant_dim)) {
244       return {};
245     }
246     // The quantization dim is set to the last dimension.
247     quantizedEleType = quantfork::fakeQuantAttrsToType(
248         builder.getUnknownLoc(), storage_type_width, quant_dim, effective_mins,
249         effective_maxs, narrow_range, converter.expressedType, is_signed);
250     if (legacy_float_scale) {
251       quantizedEleType = DownCastScale(quantizedEleType, effective_mins,
252                                        effective_maxs, builder.getUnknownLoc());
253     }
254   }
255   if (!quantizedEleType) return {};
256   // Use the fake-quant configured bit widths (only supported for
257   // 1 < num_bits < 8 bits) instead of the 8-bit defaults.
258   if (use_fake_quant_num_bits && (storage_type_width > 1) &&
259       (storage_type_width < 8) &&
260       (quantizedEleType.getStorageTypeMax() >
261        QType::getDefaultMinimumForInteger(is_signed, storage_type_width))) {
262     auto resetEleType = ResetMinMaxFromNumBits(
263         quantizedEleType, storage_type_width, narrow_range, is_signed);
264     return converter.convert(resetEleType);
265   }
266   return converter.convert(quantizedEleType);
267 }
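// A minimal per-tensor usage sketch (hypothetical values): a single min/max
// pair with quant_dim = -1 produces a per-tensor quantized type, e.g.
//   GetQuantizedType(builder, input_type, /*min=*/{-1.0}, /*max=*/{1.0},
//                    /*quant_dim=*/-1, /*storage_type_width=*/8,
//                    /*narrow_range=*/false, /*is_signed=*/true,
//                    /*legacy_float_scale=*/false,
//                    /*use_fake_quant_num_bits=*/false);
// Multiple min/max pairs with a valid quant_dim produce a per-axis type.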
268 
269 // TODO(fengliuai): promote this utility method to mlir QuantOps.
270 TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
271   auto factor_values = factor.dyn_cast_or_null<DenseFPElementsAttr>();
272   if (!factor_values) return {};
273   auto ele_type = quant::QuantizedType::getQuantizedElementType(input);
274   if (!ele_type) return {};
275   if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
276     ArrayRef<double> scales = qtype.getScales();
277     // Broadcasting hasn't been implemented yet.
278     if (static_cast<int64_t>(scales.size()) != factor_values.getNumElements())
279       return {};
280     SmallVector<double, 4> new_scales;
281     new_scales.reserve(scales.size());
282     auto scales_iter = scales.begin();
283     for (const auto& f : factor_values) {
284       new_scales.push_back(*(scales_iter++) *
285                            std::fabs(FloatAttr::getValueAsDouble(f)));
286     }
287     // We are assuming symmetric quantization.
288     auto new_ele_type = quant::UniformQuantizedPerAxisType::get(
289         qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
290         new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(),
291         qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
292     if (auto new_type = new_ele_type.castFromExpressedType(
293             quant::QuantizedType::castToExpressedType(input))) {
294       return TypeAttr::get(new_type);
295     }
296   }
297   // Currently, we only support per-axis quantized type.
298   return {};
299 }
300 
301 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
302                               Attribute max, int quant_dim,
303                               IntegerAttr num_bits, BoolAttr narrow_range,
304                               bool is_signed, bool legacy_float_scale,
305                               bool use_fake_quant_num_bits) {
306   SmallVector<double, 4> min_value, max_value;
307   auto mins = min.dyn_cast<DenseFPElementsAttr>();
308   auto maxs = max.dyn_cast<DenseFPElementsAttr>();
309   if (mins && maxs) {
310     min_value.reserve(mins.getNumElements());
311     max_value.reserve(maxs.getNumElements());
312     for (auto it = mins.begin(), e = mins.end(); it != e; ++it) {
313       min_value.push_back(FloatAttr::getValueAsDouble(*it));
314     }
315     for (auto it = maxs.begin(), e = maxs.end(); it != e; ++it) {
316       max_value.push_back(FloatAttr::getValueAsDouble(*it));
317     }
318   } else {
319     auto fmin = min.dyn_cast<FloatAttr>();
320     auto fmax = max.dyn_cast<FloatAttr>();
321     if (fmin && fmax) {
322       min_value.push_back(fmin.getValueAsDouble());
323       max_value.push_back(fmax.getValueAsDouble());
324     } else {
325       return {};
326     }
327   }
328   Type final_type =
329       GetQuantizedType(builder, input_type, min_value, max_value, quant_dim,
330                        num_bits.getInt(), narrow_range.getValue(), is_signed,
331                        legacy_float_scale, use_fake_quant_num_bits);
332   if (!final_type) return {};
333   return TypeAttr::get(final_type);
334 }
335 
336 TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
337                                                 TypeAttr source, Type target,
338                                                 int axis) {
339   auto source_type = source.getValue().dyn_cast_or_null<ShapedType>();
340   if (!source_type) return {};
341   auto src_ele_type = source_type.getElementType();
342   auto qtype = src_ele_type.dyn_cast<quant::QuantizedType>();
343 
344   // Reset the quantization dimensions if it is per-axis.
345   if (auto per_axis =
346           qtype.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()) {
347     // For the pass-through ops, we don't know which the dimension will be the
348     // new quantization dimension. Only if the new quantization dimension can
349     // be inferred, it is safe to reset the per-axis quantized type.
350     if (axis == -1) return {};
351     qtype =
352         ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis);
353   }
354   if (!qtype) return {};
355   Type final_type = qtype.castFromExpressedType(target);
356   if (!final_type) return {};
357   return TypeAttr::get(final_type);
358 }
359 
360 void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
361                            int slice_size, bool symmetric,
362                            SmallVectorImpl<double>& mins,
363                            SmallVectorImpl<double>& maxs) {
364   // If all the element values are the same, we don't need to scan the content.
365   if (values.isSplat()) {
366     double single_value =
367         FloatAttr::getValueAsDouble(values.getSplatValue<llvm::APFloat>());
368 
369     // When the single value isn't 0.0, we expand it to a range that includes
370     // this single value and 0.0. This will give us a scale and zero point
371     // that work for both this value and 0.0.
372     if (single_value < 0.0) {
373       mins[0] = single_value;
374       maxs[0] = symmetric ? -single_value : 0.0;
375     } else if (single_value > 0.0) {
376       mins[0] = symmetric ? -single_value : 0.0;
377       maxs[0] = single_value;
378     } else {
379       mins[0] = maxs[0] = single_value;
380     }
381     for (int i = 1; i < dim_size; ++i) {
382       mins[i] = mins[0];
383       maxs[i] = maxs[0];
384     }
385   } else {
386     int64_t flatten_index = 0;
387     for (auto it = values.begin(), e = values.end(); it != e;
388          ++it, ++flatten_index) {
389       double ele_value = FloatAttr::getValueAsDouble(*it);
390       int slice_index = flatten_index / slice_size;
391       int channel_index = slice_index % dim_size;
392       mins[channel_index] = std::min(mins[channel_index], ele_value);
393       maxs[channel_index] = std::max(maxs[channel_index], ele_value);
394     }
395     // Expand range to include 0.
396     for (int i = 0; i < dim_size; ++i) {
397       maxs[i] = std::max(maxs[i], 0.0);
398       mins[i] = std::min(mins[i], 0.0);
399     }
400     if (symmetric) {
401       for (int i = 0; i < dim_size; ++i) {
402         maxs[i] = std::max(std::abs(mins[i]), std::abs(maxs[i]));
403         mins[i] = -maxs[i];
404       }
405     }
406   }
407 }
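// For example, a splat value of 2.5 with symmetric=true produces
// mins[i] = -2.5 and maxs[i] = 2.5 for every channel; with symmetric=false it
// produces mins[i] = 0.0 and maxs[i] = 2.5 so that 0.0 stays representable.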
408 
409 Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
410                                       unsigned num_bits, bool is_signed,
411                                       bool narrow_range,
412                                       bool legacy_float_scale,
413                                       bool use_fake_quant_num_bits) {
414   Builder builder(attr.getContext());
415   // `symmetric` can only be used when it is `signed` and `narrow_range`.
416   if (symmetric && (!is_signed || !narrow_range)) return {};
417 
418   SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
419   SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
420   auto fp = attr.dyn_cast<DenseFPElementsAttr>();
421   if (!fp) return {};
422 
423   // Computes the effective min/max values of the attribute values.
424   ExtractMinMaxFromAttr(fp, /*dim_size=*/1, /*slice_size=*/1, symmetric, mins,
425                         maxs);
426 
427   auto type =
428       GetQuantizedType(builder, attr.getType(), mins[0], maxs[0],
429                        /*quant_dim=*/-1, num_bits, narrow_range, is_signed,
430                        legacy_float_scale, use_fake_quant_num_bits);
431   if (auto ele_type = type.dyn_cast_or_null<TensorType>())
432     return ele_type.getElementType();
433 
434   return {};
435 }
436 
437 Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
438                                              bool symmetric, unsigned num_bits,
439                                              bool is_signed, bool narrow_range,
440                                              bool legacy_float_scale,
441                                              bool use_fake_quant_num_bits) {
442   Builder builder(attr.getContext());
443   auto shape = attr.getType().cast<ShapedType>().getShape();
444   if (static_cast<int>(shape.size()) <= quant_dim) return {};
445   // `symmetric` can only be used when it is `signed` and `narrow_range`.
446   if (symmetric && (!is_signed || !narrow_range)) return {};
447 
448   int dim_size = shape[quant_dim];
449   int slice_size = std::accumulate(std::next(shape.begin(), quant_dim + 1),
450                                    shape.end(), 1, std::multiplies<int64_t>());
451   SmallVector<double, 4> mins(dim_size, std::numeric_limits<double>::max());
452   SmallVector<double, 4> maxs(dim_size, std::numeric_limits<double>::min());
453   auto fp = attr.dyn_cast<DenseFPElementsAttr>();
454   if (!fp) return {};
455 
456   // Computes the effective min/max values of the attribute values.
457   ExtractMinMaxFromAttr(fp, dim_size, slice_size, symmetric, mins, maxs);
458 
459   auto type = GetQuantizedType(builder, attr.getType(), mins, maxs, quant_dim,
460                                num_bits, narrow_range, is_signed,
461                                legacy_float_scale, use_fake_quant_num_bits);
462   if (auto ele_type = type.dyn_cast_or_null<TensorType>())
463     return ele_type.getElementType();
464 
465   return {};
466 }
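// For illustration, for a hypothetical weight of shape [64, 3, 3, 32] with
// quant_dim = 3, dim_size is 32 and slice_size is 1, so min/max values are
// collected per channel along the last dimension.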
467 
468 quant::QuantizedType GetUniformQuantizedTypeForBias(
469     const std::vector<quant::QuantizedType>& op_types,
470     bool legacy_float_scale) {
471   if (op_types.empty()) return {};
472 
473   size_t axis_size = 1;
474   int32_t quant_dim = -1;
475   Type expressed_type;
476   // Requires that all the op types are valid UniformQuantizedTypes or
477   // UniformQuantizedPerAxisTypes and have the same expressed type. For all
478   // the UniformQuantizedPerAxisTypes, the quantization dimension index and
479   // dimension sizes must be the same.
480   for (auto op_type : op_types) {
481     if (!op_type) return {};
482     if (expressed_type && expressed_type != op_type.getExpressedType()) {
483       return {};
484     }
485     expressed_type = op_type.getExpressedType();
486 
487     if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
488       if ((axis_size != 1 && axis_size != type.getScales().size())) return {};
489       if (quant_dim != -1 && quant_dim != type.getQuantizedDimension())
490         return {};
491       axis_size = type.getScales().size();
492       quant_dim = type.getQuantizedDimension();
493     } else if (!op_type.isa<quant::UniformQuantizedType>()) {
494       return {};
495     }
496   }
497 
498   // The scale from the UniformQuantizedTypes is broadcasted if there are
499   // UniformQuantizedPerAxisTypes.
500   llvm::SmallVector<double, 4> scales(axis_size, 1.0);
501   for (auto op_type : op_types) {
502     if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
503       for (const auto& index_scale : llvm::enumerate(type.getScales())) {
504         scales[index_scale.index()] *= index_scale.value();
505       }
506     } else if (auto type = op_type.dyn_cast<quant::UniformQuantizedType>()) {
507       for (int index = 0, e = axis_size; index != e; ++index) {
508         scales[index] *= type.getScale();
509       }
510     }
511   }
512   if (legacy_float_scale) {
513     for (int i = 0; i < scales.size(); ++i) {
514       scales[i] = static_cast<float>(scales[i]);
515     }
516   }
517 
518   // Builds the result quantized type, which has a signed 32-bit storage type.
519   Builder builder(expressed_type.getContext());
520   IntegerType storage_type = builder.getIntegerType(32);
521   int64_t storage_type_min =
522       quant::QuantizedType::getDefaultMinimumForInteger(/*isSigned=*/true, 32);
523   int64_t storage_type_max =
524       quant::QuantizedType::getDefaultMaximumForInteger(/*isSigned=*/true, 32);
525   if (axis_size == 1) {
526     return quant::UniformQuantizedType::getChecked(
527         builder.getUnknownLoc(),
528         /*flags=*/true, storage_type, expressed_type, scales[0],
529         /*zeroPoint=*/0, storage_type_min, storage_type_max);
530   } else {
531     llvm::SmallVector<int64_t, 4> zero_points(axis_size, 0);
532     // Assume the bias is a 1-D tensor and set the quantization dim to its last
533     // dimension, which is 0. If the bias rank is larger than 1, the returned
534     // quantized type can't be used to quantize the bias.
535     return quant::UniformQuantizedPerAxisType::getChecked(
536         builder.getUnknownLoc(),
537         /*flags=*/true, storage_type, expressed_type, scales, zero_points,
538         /*quantizedDimension=*/0, storage_type_min, storage_type_max);
539   }
540 }
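// For illustration, combining a hypothetical per-tensor input type with scale
// 0.5 and a per-axis weight type with scales {0.1, 0.2} yields a signed
// 32-bit bias type with scales {0.05, 0.1} and zero points of 0.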
541 
542 ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type) {
543   if (!real_value.isa<DenseFPElementsAttr>() ||
544       !quant::QuantizedType::getQuantizedElementType(tensor_type)) {
545     return {};
546   }
547   auto real_values_attr = real_value.cast<DenseFPElementsAttr>();
548   auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type);
549   std::vector<float> real_values;
550   llvm::SmallVector<APInt, 8> quantized_attr;
551   real_values.reserve(real_values_attr.getNumElements());
552   quantized_attr.reserve(real_values_attr.getNumElements());
553   std::transform(real_values_attr.begin(), real_values_attr.end(),
554                  std::back_inserter(real_values), [&](APFloat value) -> float {
555                    return value.convertToFloat();
556                  });
557   ShapedType new_dense_type =
558       q_type.castExpressedToStorageType(real_values_attr.getType())
559           .dyn_cast_or_null<ShapedType>();
560   int width = q_type.getStorageType().dyn_cast<mlir::IntegerType>().getWidth();
561 
562   if (width == 8 && q_type.getStorageTypeMax() == 127 &&
563       q_type.getStorageTypeMin() == -127) {
564     std::vector<int8_t> quantized_values(real_values_attr.getNumElements());
565     if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
566       float min, max, scale;
567       tflite::tensor_utils::SymmetricQuantizeFloats(
568           real_values.data(), real_values.size(), quantized_values.data(), &min,
569           &max, &scale);
570       // The scale has been adjusted, so the adjusted scale should be respected.
571       if (std::abs(scale - uniform_type.getScale()) > 1e-3) {
572         return Quantize(real_value, tensor_type);
573       }
574     } else if (auto uniform_type =
575                    q_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
576       std::vector<float> scales_inv;
577       std::vector<int32_t> dimension;
578       dimension.insert(dimension.end(), new_dense_type.getShape().begin(),
579                        new_dense_type.getShape().end());
580       std::transform(uniform_type.getScales().begin(),
581                      uniform_type.getScales().end(),
582                      std::back_inserter(scales_inv),
583                      [](float scale) { return 1.0 / scale; });
584 
585       tflite::optimize::utils::SymmetricPerChannelQuantizeValues(
586           real_values.data(), scales_inv, dimension,
587           uniform_type.getQuantizedDimension(), &quantized_values);
588     } else {
589       return {};
590     }
591     std::transform(quantized_values.begin(), quantized_values.end(),
592                    std::back_inserter(quantized_attr),
593                    [&](int8_t value) -> APInt {
594                      return APInt(8, value, /*isSigned=*/true);
595                    });
596     return DenseElementsAttr::get(new_dense_type, quantized_attr);
597   } else if (width == 8) {
598     // This can be a state tensor, or an actual constant tensor with an
599     // asymmetric range. For a state tensor, assigning correct quantization
600     // parameters is sufficient; constants with an asymmetric range are not
601     // correctly quantized by the legacy quantizer, so call the new Quantize.
602     return Quantize(real_value, tensor_type);
603   } else if (width == 16) {
604     if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
605       auto quantized_values =
606           tflite::optimize::utils::SymmetricQuantizeFloatsToInt16(
607               real_values.data(), real_values.size(), uniform_type.getScale());
608       std::transform(quantized_values.begin(), quantized_values.end(),
609                      std::back_inserter(quantized_attr),
610                      [&](int16_t value) -> APInt {
611                        return APInt(16, value, /*isSigned=*/true);
612                      });
613       return DenseElementsAttr::get(new_dense_type, quantized_attr);
614     }
615   } else if (width == 32) {
616     std::vector<float> scales;
617     if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
618       scales.push_back(uniform_type.getScale());
619     } else if (auto uniform_type =
620                    q_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
621       scales.insert(scales.end(), uniform_type.getScales().begin(),
622                     uniform_type.getScales().end());
623     } else {
624       return {};
625     }
626     auto quantized_bias =
627         tflite::optimize::utils::SymmetricBiasQuantize<std::int32_t>(
628             real_values.data(), real_values.size(), scales);
629     std::transform(quantized_bias.begin(), quantized_bias.end(),
630                    std::back_inserter(quantized_attr),
631                    [&](int32_t value) -> APInt {
632                      return APInt(32, value, /*isSigned=*/true);
633                    });
634     return DenseElementsAttr::get(new_dense_type, quantized_attr);
635   }
636   return {};
637 }
638 
639 ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
640   if (auto q_type =
641           quant::QuantizedType::getQuantizedElementType(tensor_type)) {
642     Type converted_type;
643     return quantfork::quantizeAttr(real_value, q_type, converted_type)
644         .dyn_cast<ElementsAttr>();
645   }
646   return {};
647 }
648 
649 quant::QuantizedType DownCastScale(QuantizedType type, double min, double max,
650                                    Location loc) {
651   SmallVector<double, 1> mins = {min};
652   SmallVector<double, 1> maxs = {max};
653   return DownCastScale(type, mins, maxs, loc);
654 }
655 
656 quant::QuantizedType DownCastScale(QuantizedType type,
657                                    const SmallVectorImpl<double>& mins,
658                                    const SmallVectorImpl<double>& maxs,
659                                    Location loc) {
660   // The given type can be null, for example when it was built from an invalid
661   // scale; in that case, return it unchanged.
662   if (!type) return type;
663   SmallVector<double, 4> scales(mins.size());
664   SmallVector<int64_t, 4> zero_points(mins.size());
665   if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
666     zero_points.push_back(q_type.getZeroPoint());
667   } else if (auto q_type =
668                  type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
669     zero_points = {q_type.getZeroPoints().begin(),
670                    q_type.getZeroPoints().end()};
671   }
672   for (int i = 0; i < mins.size(); ++i) {
673     scales[i] = (static_cast<float>(maxs[i]) - static_cast<float>(mins[i])) /
674                 (type.getStorageTypeMax() - type.getStorageTypeMin());
675     if (type.getStorageTypeMax() != -type.getStorageTypeMin()) {
676       // Only applies for asymmetric quantized range with original scale.
677       float zero_point_from_min =
678           type.getStorageTypeMin() - mins[i] / scales[i];
679       if (zero_point_from_min < type.getStorageTypeMin()) {
680         zero_points[i] = static_cast<int64_t>(type.getStorageTypeMin());
681       } else if (zero_point_from_min > type.getStorageTypeMax()) {
682         zero_points[i] = static_cast<int64_t>(type.getStorageTypeMax());
683       } else {
684         zero_points[i] = static_cast<int64_t>(std::round(zero_point_from_min));
685       }
686     }
687   }
688   if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
689     return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(),
690                                      q_type.getExpressedType(), scales[0],
691                                      zero_points[0], q_type.getStorageTypeMin(),
692                                      q_type.getStorageTypeMax());
693   } else if (auto q_type =
694                  type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
695     return quant::UniformQuantizedPerAxisType::get(
696         q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(),
697         scales, zero_points, q_type.getQuantizedDimension(),
698         q_type.getStorageTypeMin(), q_type.getStorageTypeMax());
699   }
700   return type;
701 }
702 
703 // A heuristic to determine whether the scales need to come from the operands
704 // or from the results for ops with the `SameOperandsAndResultsScale` property.
705 // The current implementation is based on the number of float operands.
706 static bool PreferResultScale(Operation* op) {
707   int float_operands = 0;
708   for (auto operand : op->getOperands()) {
709     if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
710       if (operand_type.getElementType().isa<FloatType>()) {
711         if (++float_operands > 1) return true;
712       }
713     }
714   }
715   return false;
716 }
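// For example, an op like `concatenation` with two or more float operands
// returns true (prefer the result scale), while a single-float-operand op
// returns false.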
717 
718 std::unique_ptr<OpQuantScaleSpec> GetDefaultQuantScaleSpec(Operation* op) {
719   auto spec = std::make_unique<OpQuantScaleSpec>();
720   if (llvm::isa<SameScalesOpInterface>(op)) {
721     spec->has_same_scale_requirement = true;
722     spec->required_same_scale_func = [op](bool sign, int bit_width) {
723       return llvm::cast<SameScalesOpInterface>(op)
724           .RequiredSameOperandsAndResultsScale(sign, bit_width);
725     };
726     spec->required_same_quantized_axes_func = [op]() {
727       return llvm::cast<SameScalesOpInterface>(op).RequiredSameQuantizedAxes();
728     };
729   }
730   if (llvm::isa<FixedOutputRangeInterface>(op)) {
731     spec->has_fixed_output_range = true;
732     spec->fixed_output_range_func = [op](bool sign, int bit_width) {
733       return llvm::cast<FixedOutputRangeInterface>(op).GetFixedOutputRange(
734           sign, bit_width);
735     };
736   }
737   return spec;
738 }
739 
740 // The stats op attached to some ops can be redundant. The current
741 // implementation only considers ops with restricted output params.
742 static bool IsStatsRedundant(
743     Operation* op, OpQuantSpecGetter op_quant_spec_getter,
744     OpQuantScaleSpecGetter op_quant_scale_spec_getter) {
745   // If it has FixedOutputRangeInterface, no need to manually create spec.
746   return llvm::isa<FixedOutputRangeInterface>(op) ||
747          op_quant_scale_spec_getter(op)->has_fixed_output_range;
748 }
749 
750 static bool IsSameScaleOp(Operation* op,
751                           OpQuantScaleSpecGetter op_quant_scale_spec_getter) {
752   // If it has SameScalesOpInterface, no need to manually create spec.
753   return llvm::dyn_cast<SameScalesOpInterface>(op) ||
754          op_quant_scale_spec_getter(op)->has_same_scale_requirement;
755 }
756 
757 bool RemoveRedundantStatsOps(
758     mlir::func::FuncOp func, OpQuantSpecGetter op_quant_spec_getter,
759     OpQuantScaleSpecGetter op_quant_scale_spec_getter) {
760   llvm::SmallVector<quantfork::StatisticsOp, 16> all_stats_ops;
761   llvm::DenseSet<Operation*> redundant_stats_ops;
762 
763   // Step 0: remove the quantfork::StatisticsOp ops which are used by the
764   // quant.qcast op, in case they override the information from training
765   // FakeQuant ops.
766   func.walk([&](quantfork::QuantizeCastOp q) {
767     auto input_op = q.getArg().getDefiningOp();
768     if (auto stats =
769             llvm::dyn_cast_or_null<quantfork::StatisticsOp>(input_op)) {
770       q.setOperand(stats.getArg());
771       if (stats.use_empty()) stats.erase();
772     }
773   });
774 
775   // Step 1: forward pass: propagate any value scales which are not produced
776   // by `SameOperandsAndResultsScale`. Additionally, remove the value scales
777   // which are produced by the ops with the `FixedOutputRangeInterface`.
778   // Note that we don't propagate across the multiple-operands
779   // `SameOperandsAndResultsScale` ops like `concatenation`.
780   func.walk([&](quantfork::StatisticsOp stats_op) {
781     all_stats_ops.push_back(stats_op);
782   });
783 
784   while (!all_stats_ops.empty()) {
785     quantfork::StatisticsOp stats_op = all_stats_ops.back();
786     all_stats_ops.pop_back();
787 
788     if (auto def = stats_op.getArg().getDefiningOp()) {
789       if (IsStatsRedundant(def, op_quant_spec_getter,
790                            op_quant_scale_spec_getter)) {
791         redundant_stats_ops.insert(stats_op);
792       }
793     }
794 
795     for (auto user : stats_op.getResult().getUsers()) {
796       // We don't propagate this parameter down if it has multiple operands.
797       // We want to use the result parameter scales instead.
798       if (!IsSameScaleOp(user, op_quant_scale_spec_getter) ||
799           PreferResultScale(user)) {
800         continue;
801       }
802       for (Value res : user->getResults()) {
803         if (!res.hasOneUse()) {
804           continue;
805         }
806         if (auto next_stats = llvm::dyn_cast<quantfork::StatisticsOp>(
807                 *res.getUsers().begin())) {
808           // quantization parameters can be propagated to next_stats
809           redundant_stats_ops.insert(next_stats);
810           // add next_stats to the work list so propagation can continue.
811           all_stats_ops.push_back(next_stats);
812         }
813       }
814     }
815   }
816 
817   // Step 2: backward pass: for the ops skipped in the forward pass, propagate
818   // their result scales backwards as far as possible.
819   func.walk([&](quantfork::StatisticsOp stats_op) {
820     if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
821       all_stats_ops.push_back(stats_op);
822     }
823   });
824 
825   while (!all_stats_ops.empty()) {
826     quantfork::StatisticsOp stats_op = all_stats_ops.back();
827     all_stats_ops.pop_back();
828 
829     if (auto def = stats_op.getArg().getDefiningOp()) {
830       if (!IsSameScaleOp(def, op_quant_scale_spec_getter)) {
831         continue;
832       }
833       for (auto input : def->getOperands()) {
834         if (auto next_stats = llvm::dyn_cast_or_null<quantfork::StatisticsOp>(
835                 input.getDefiningOp())) {
836           redundant_stats_ops.insert(next_stats);
837           all_stats_ops.push_back(next_stats);
838         }
839       }
840     }
841   }
842 
843   // Step 3: remove all the redundant stats ops.
844   for (auto it : redundant_stats_ops) {
845     if (!llvm::isa<quantfork::StatisticsOp>(it)) return true;
846     auto stats_op = llvm::cast<quantfork::StatisticsOp>(it);
847     stats_op.getResult().replaceAllUsesWith(stats_op.getArg());
848     stats_op.erase();
849   }
850 
851   // Returns false if the steps finish without errors.
852   return false;
853 }
854 
855 LogicalResult VerifySameScales(Operation* op) {
856   auto same_scale_op = llvm::cast<SameScalesOpInterface>(op);
857 
858   llvm::SmallVector<QuantizedType, 4> collected_quant_params;
859   for (auto input : op->getOperands()) {
860     auto quant_params = QuantizedType::getQuantizedElementType(input.getType());
861     // Skip non-quantizable operands.
862     if (quant_params) {
863       collected_quant_params.push_back(quant_params);
864     }
865   }
866 
867   for (auto output : op->getResults()) {
868     auto quant_params =
869         QuantizedType::getQuantizedElementType(output.getType());
870     // Skip non-quantizable results.
871     if (quant_params) {
872       collected_quant_params.push_back(quant_params);
873     }
874   }
875 
876   if (collected_quant_params.size() <= 1) return success();
877   const auto& expected_params = collected_quant_params[0];
878   for (int i = 1; i < collected_quant_params.size(); i++) {
879     const auto& compared_params = collected_quant_params[i];
880     // For some ops (such as Transpose or Squeeze), the quantized axis might
881     // not be the same; in that case, this function only verifies the scale and
882     // zero point. The quantized axis should be verified in the ops' own
883     // verifier methods.
884     if (!same_scale_op.RequiredSameQuantizedAxes()) {
885       auto expected_per_axis_qtype =
886           expected_params.dyn_cast<quant::UniformQuantizedPerAxisType>();
887       auto compared_per_axis_qtype =
888           compared_params.dyn_cast<quant::UniformQuantizedPerAxisType>();
889       if (expected_per_axis_qtype && compared_per_axis_qtype &&
890           llvm::equal(expected_per_axis_qtype.getScales(),
891                       compared_per_axis_qtype.getScales()) &&
892           llvm::equal(expected_per_axis_qtype.getZeroPoints(),
893                       compared_per_axis_qtype.getZeroPoints()) &&
894           expected_params.getStorageType() ==
895               compared_params.getStorageType() &&
896           expected_params.getExpressedType() ==
897               compared_params.getExpressedType()) {
898         continue;
899       }
900     }
901     // Same quantization parameters are always ok.
902     if (expected_params == compared_params) continue;
903     // If the quantization parameters are not the same, it is still ok as long
904     // as they have the same storage type and the op interface doesn't require
905     // a same-scale constraint for this storage type.
906     if ((expected_params.isSigned() == compared_params.isSigned() &&
907          expected_params.getStorageTypeIntegralWidth() ==
908              compared_params.getStorageTypeIntegralWidth()) &&
909         !same_scale_op.RequiredSameOperandsAndResultsScale(
910             expected_params.isSigned(),
911             expected_params.getStorageTypeIntegralWidth()))
912       continue;
913 
914     std::string err_msg =
915         "quantization parameters violate the same scale constraint: ";
916     llvm::raw_string_ostream os(err_msg);
917     expected_params.print(os);
918     os << " vs. ";
919     compared_params.print(os);
920     os.flush();
921     return op->emitOpError(err_msg);
922   }
923   return success();
924 }
925 
926 quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
927                                                 Type tensor_type, double scale,
928                                                 int64_t zero_point,
929                                                 int64_t storage_min,
930                                                 int64_t storage_max) {
931   auto result_type = tensor_type.cast<ShapedType>();
932   if (!result_type.getElementType().isa<FloatType>()) return {};
933   Builder builder(result_type.getContext());
934 
935   // Only 8-bit quantization is supported.
936   if (bit_width != 8) return {};
937   IntegerType storage_type = builder.getIntegerType(bit_width);
938   if (!is_signed) {
939     zero_point += 128;
940     storage_min += 128;
941     storage_max += 128;
942   }
943   return quant::UniformQuantizedType::getChecked(
944       builder.getUnknownLoc(), is_signed, storage_type,
945       result_type.getElementType(), scale, zero_point, storage_min,
946       storage_max);
947 }
948 
949 Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc) {
950   auto qtype = QType::getQuantizedElementType(signed_tensor_type);
951   if (!qtype || !qtype.isSigned()) return {};
952 
953   int num_bits = qtype.getStorageTypeIntegralWidth();
954   // This is a negative value, and will be applied on zero points and fixed
955   // point ranges.
956   int64_t offset =
957       QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits) -
958       QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits);
959 
960   auto flags = !quant::QuantizationFlags::Signed;
961   QType new_qtype;
962   if (auto uqtype = qtype.dyn_cast<quant::UniformQuantizedType>()) {
963     new_qtype = quant::UniformQuantizedType::getChecked(
964         loc, flags, qtype.getStorageType(), qtype.getExpressedType(),
965         uqtype.getScale(), uqtype.getZeroPoint() - offset,
966         uqtype.getStorageTypeMin() - offset,
967         uqtype.getStorageTypeMax() - offset);
968   } else if (auto aqtype =
969                  qtype.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
970     auto zero_points = aqtype.getZeroPoints();
971     llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
972                                                   zero_points.end());
973     for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
974       new_zero_points[i] -= offset;
975     }
976     new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
977         loc, flags, qtype.getStorageType(), qtype.getExpressedType(),
978         aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
979         aqtype.getStorageTypeMin() - offset,
980         aqtype.getStorageTypeMax() - offset);
981   }
982   return new_qtype.castFromExpressedType(
983       QType::castToExpressedType(signed_tensor_type));
984 }
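// For illustration, converting a hypothetical signed 8-bit type with zero
// point 0 and storage range [-128, 127] uses offset = -128, yielding an
// unsigned type with zero point 128 and storage range [0, 255].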
985 
986 LogicalResult RemoveDebugAttrPattern::matchAndRewrite(
987     Operation* op, PatternRewriter& rewriter) const {
988   // removeAttr will return nullptr if the attribute did not exist. Thus we can
989   // return success(result) to indicate whether this op has changed.
990   return success(/*isSuccess=*/
991                  op->removeAttr(kDebugModeOpQuantAttrName) ||
992                  op->removeAttr(kDebugModeOpFloatAttrName));
993 }
994 
995 }  // namespace quant
996 }  // namespace mlir
997