1 /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/container/flat_hash_set.h" 21 #include "tensorflow/lite/builtin_ops.h" 22 #include "tensorflow/lite/c/common.h" 23 #include "tensorflow/lite/core/api/op_resolver.h" 24 #include "tensorflow/lite/delegates/gpu/common/model.h" 25 #include "tensorflow/lite/delegates/gpu/common/shape.h" 26 #include "tensorflow/lite/delegates/gpu/common/status.h" 27 #include "tensorflow/lite/delegates/gpu/common/tensor.h" 28 #include "tensorflow/lite/model.h" 29 30 namespace tflite { 31 namespace gpu { 32 33 // Validates which operations are supported and returns array of operations to 34 // replace with GPU kernels. The caller must free the pointer on TfLiteIntArray. 35 // 'max_delegated_partitions' limits the maximum number of partitions to 36 // delegate as a graph could possibly have multiple partitions (each partition 37 // consists of a subset of ops) to be replaced. 38 // 'excluded_ops', if not null, specifies a set of ops that should not be 39 // replaced with GPU kernels. 40 TfLiteIntArray* GetOpsToReplace( 41 TfLiteContext* context, bool allow_quant_ops = false, 42 int max_delegated_partitions = 1, 43 const absl::flat_hash_set<TfLiteBuiltinOperator>* excluded_ops = nullptr); 44 45 // Extracts TFLite delegate execution plan from the input TFLite context and 46 // converts it into generic graph format. 47 // 48 // If model is quantized, quant_conversion_map maps the dequantized tensor 49 // (floating-point) to the original tensor (fixed-point) & vice-versa. 50 // NOTE: Not all of these new tensors will have any data and need memory 51 // allocated for them. We need to do that only for the overall GPU graph inputs 52 // & outputs. This should be done by the delegate, by setting the appropriate 53 // TfLiteNode->temporaries. 54 absl::Status BuildModel( 55 TfLiteContext* context, const TfLiteDelegateParams* delegate_params, 56 GraphFloat32* graph, 57 absl::flat_hash_map<int, int>* quant_conversion_map = nullptr); 58 59 // Same as BuildModel, but enforces user-provided input/output indices instead 60 // of using delegate_params->inputs and delegate_params->outputs for 61 // inputs/outputs preallocating. 62 absl::Status BuildModelEnforceIO( 63 TfLiteContext* context, const TfLiteDelegateParams* delegate_params, 64 const std::vector<int>& input_ids, const std::vector<int>& output_ids, 65 GraphFloat32* graph, 66 absl::flat_hash_map<int, int>* quant_conversion_map = nullptr); 67 68 // Same as above but also apply all transformations on the final graph. 69 // Prefer using this method instead of BuildModel. 70 // 71 // If model is quantized, quant_conversion_map maps the dequantized tensor 72 // (floating-point) to the original TFLite tensor (fixed-point) & vice-versa. 73 // NOTE: Not all of these new tensors will have any data and need memory 74 // allocated for them. We need to do that only for the overall GPU graph inputs 75 // & outputs. This should be done by the delegate, by setting the appropriate 76 // TfLiteNode->temporaries. 77 absl::Status BuildFinalModel( 78 TfLiteContext* context, const TfLiteDelegateParams* delegate_params, 79 GraphFloat32* graph, 80 absl::flat_hash_map<int, int>* quant_conversion_map = nullptr); 81 82 // Convenience wrapper that builds a GraphFloat32 from the provided 83 // FlatBufferModel. 84 absl::Status BuildFromFlatBuffer(const FlatBufferModel& flatbuffer, 85 const OpResolver& op_resolver, 86 GraphFloat32* graph, 87 bool allow_quant_ops = false); 88 89 // Module-internal converter, exposed for unit testing purpose only. 90 absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, 91 TensorRef<BHWC>* tensor_ref); 92 93 } // namespace gpu 94 } // namespace tflite 95 96 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ 97