/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_

#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/model.h"

namespace tflite {
namespace gpu {

// Validates which operations are supported and returns an array of operations
// to replace with GPU kernels. The caller must free the returned
// TfLiteIntArray (e.g. with TfLiteIntArrayFree).
// 'max_delegated_partitions' limits the number of partitions to delegate, as a
// graph may contain multiple partitions (each consisting of a subset of ops)
// that could be replaced.
// 'excluded_ops', if not null, specifies a set of ops that should not be
// replaced with GPU kernels.
TfLiteIntArray* GetOpsToReplace(
    TfLiteContext* context, bool allow_quant_ops = false,
    int max_delegated_partitions = 1,
    const absl::flat_hash_set<TfLiteBuiltinOperator>* excluded_ops = nullptr);

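// Example usage (a minimal sketch, not part of this header): a delegate's
// Prepare callback would typically pass the returned array to
// ReplaceNodeSubsetsWithDelegateKernels and then free it. Here,
// 'kernel_registration' is a hypothetical TfLiteRegistration for the GPU
// delegate kernel, assumed to be defined elsewhere.
//
//   TfLiteStatus DelegatePrepare(TfLiteContext* context,
//                                TfLiteDelegate* delegate) {
//     TfLiteIntArray* ops_to_replace = GetOpsToReplace(
//         context, /*allow_quant_ops=*/false, /*max_delegated_partitions=*/1);
//     const TfLiteStatus status =
//         context->ReplaceNodeSubsetsWithDelegateKernels(
//             context, kernel_registration, ops_to_replace, delegate);
//     TfLiteIntArrayFree(ops_to_replace);
//     return status;
//   }
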
// Extracts the TFLite delegate execution plan from the input TFLite context
// and converts it into the generic graph format.
//
// If the model is quantized, 'quant_conversion_map' maps each dequantized
// (floating-point) tensor to its original (fixed-point) tensor, and
// vice versa.
// NOTE: Not all of these new tensors will have data or need memory allocated
// for them; that is only required for the overall GPU graph inputs & outputs.
// The delegate should do this by setting the appropriate
// TfLiteNode->temporaries.
absl::Status BuildModel(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    GraphFloat32* graph,
    absl::flat_hash_map<int, int>* quant_conversion_map = nullptr);

// Same as BuildModel, but enforces the user-provided input/output indices
// instead of using delegate_params->inputs and delegate_params->outputs when
// preallocating inputs/outputs.
absl::Status BuildModelEnforceIO(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    const std::vector<int>& input_ids, const std::vector<int>& output_ids,
    GraphFloat32* graph,
    absl::flat_hash_map<int, int>* quant_conversion_map = nullptr);

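// Example usage (a minimal sketch, not part of this header): the tensor
// indices below are purely illustrative and refer to tensors in 'context'.
//
//   GraphFloat32 graph;
//   absl::Status status = BuildModelEnforceIO(
//       context, delegate_params, /*input_ids=*/{0}, /*output_ids=*/{17},
//       &graph);
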
// Same as above, but also applies all transformations on the final graph.
// Prefer this method over BuildModel.
//
// If the model is quantized, 'quant_conversion_map' maps each dequantized
// (floating-point) tensor to its original TFLite (fixed-point) tensor, and
// vice versa.
// NOTE: Not all of these new tensors will have data or need memory allocated
// for them; that is only required for the overall GPU graph inputs & outputs.
// The delegate should do this by setting the appropriate
// TfLiteNode->temporaries.
absl::Status BuildFinalModel(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    GraphFloat32* graph,
    absl::flat_hash_map<int, int>* quant_conversion_map = nullptr);

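// Example usage (a minimal sketch, not part of this header): a delegate
// kernel's init step passes the TfLiteDelegateParams it received from the
// runtime straight through; the same call shape applies to BuildModel.
// 'quant_conversion_map' is only useful when quantized ops were delegated.
//
//   absl::Status InitializeGraph(TfLiteContext* context,
//                                const TfLiteDelegateParams* delegate_params,
//                                bool quantized_model_is_delegated) {
//     GraphFloat32 graph;
//     absl::flat_hash_map<int, int> quant_conversion_map;
//     absl::Status status = BuildFinalModel(
//         context, delegate_params, &graph,
//         quantized_model_is_delegated ? &quant_conversion_map : nullptr);
//     if (!status.ok()) return status;
//     // ... hand 'graph' over to the GPU backend ...
//     return absl::OkStatus();
//   }
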
// Convenience wrapper that builds a GraphFloat32 from the provided
// FlatBufferModel.
absl::Status BuildFromFlatBuffer(const FlatBufferModel& flatbuffer,
                                 const OpResolver& op_resolver,
                                 GraphFloat32* graph,
                                 bool allow_quant_ops = false);

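// Example usage (a minimal sketch, not part of this header): the model path is
// illustrative, and BuiltinOpResolver (from
// tensorflow/lite/kernels/register.h) is one possible OpResolver choice.
//
//   auto flatbuffer = FlatBufferModel::BuildFromFile("model.tflite");
//   ops::builtin::BuiltinOpResolver op_resolver;
//   GraphFloat32 graph;
//   absl::Status status =
//       BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph);
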
// Module-internal converter, exposed for unit-testing purposes only.
absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
                                            TensorRef<BHWC>* tensor_ref);

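// Example usage (a minimal sketch of a test, not part of this header; only the
// fields needed for the conversion are populated, and the shape is arbitrary):
//
//   TfLiteTensor tflite_tensor{};
//   tflite_tensor.type = kTfLiteFloat32;
//   tflite_tensor.dims = TfLiteIntArrayCreate(4);
//   tflite_tensor.dims->data[0] = 1;  // B
//   tflite_tensor.dims->data[1] = 8;  // H
//   tflite_tensor.dims->data[2] = 8;  // W
//   tflite_tensor.dims->data[3] = 3;  // C
//   TensorRef<BHWC> tensor_ref;
//   absl::Status status =
//       ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
//   TfLiteIntArrayFree(tflite_tensor.dims);
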
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_