xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
16 #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
17 
18 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
20 
21 namespace mlir {
22 namespace TFL {
23 
24 // Converts all the tfl.quantize/tfl.dequantize ops to the ops in the mlir.quant
25 // dialect ones in the function.
26 void ConvertTFLQuantOpsToMlirQuantOps(func::FuncOp func);
27 
28 // Converts all the mlir.quant dialect ops to the tfl.quantize/tfl.dequantize
29 // ops in the function.
30 void ConvertMlirQuantOpsToTFLQuantOps(func::FuncOp func);
31 
32 // A helper class to convert target function to another representation using
33 // `ConvertForward` function during construction and convert target function
34 // back to the original representation using `ConvertBackward` function during
35 // deconstruction.
36 template <void (*ConvertForward)(func::FuncOp),
37           void (*ConvertBackward)(func::FuncOp)>
38 class ScopedOpsConverter {
39  public:
ScopedOpsConverter(func::FuncOp func)40   explicit ScopedOpsConverter(func::FuncOp func) : func_(func) {
41     ConvertForward(func_);
42   }
43 
44   ScopedOpsConverter(const ScopedOpsConverter&) = delete;
45   ScopedOpsConverter operator=(const ScopedOpsConverter&) = delete;
46   ScopedOpsConverter(const ScopedOpsConverter&&) = delete;
47   ScopedOpsConverter operator=(const ScopedOpsConverter&&) = delete;
48 
~ScopedOpsConverter()49   ~ScopedOpsConverter() { ConvertBackward(func_); }
50 
51  private:
52   func::FuncOp func_;
53 };
54 
55 using ScopedTFLQuantOpsToMlirQuantOpsConverter =
56     ScopedOpsConverter<ConvertTFLQuantOpsToMlirQuantOps,
57                        ConvertMlirQuantOpsToTFLQuantOps>;
58 using ScopedMlirQuantOpsToTFLQuantOpsConverter =
59     ScopedOpsConverter<ConvertMlirQuantOpsToTFLQuantOps,
60                        ConvertTFLQuantOpsToMlirQuantOps>;
61 }  // namespace TFL
62 }  // namespace mlir
63 
64 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
65