xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/delegates/delegate_provider.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
17 #define TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
18 
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/logging.h"
#include "tensorflow/lite/tools/tool_params.h"
26 
27 namespace tflite {
28 namespace tools {
29 
30 // Same w/ Interpreter::TfLiteDelegatePtr to avoid pulling
31 // tensorflow/lite/interpreter.h dependency
32 using TfLiteDelegatePtr =
33     std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
34 
35 class DelegateProvider {
36  public:
~DelegateProvider()37   virtual ~DelegateProvider() {}
38 
39   // Create a list of command-line parsable flags based on tool params inside
40   // 'params' whose value will be set to the corresponding runtime flag value.
41   virtual std::vector<Flag> CreateFlags(ToolParams* params) const = 0;
42 
43   // Log tool params. If 'verbose' is set to false, the param is going to be
44   // only logged if its value has been set, say via being parsed from
45   // commandline flags.
46   virtual void LogParams(const ToolParams& params, bool verbose) const = 0;
47 
48   // Create a TfLiteDelegate based on tool params.
49   virtual TfLiteDelegatePtr CreateTfLiteDelegate(
50       const ToolParams& params) const = 0;
51 
52   // Similar to the above, create a TfLiteDelegate based on tool params. If the
53   // same set of tool params could lead to creating multiple TfLite delegates,
54   // also return a relative rank of the delegate that indicates the order of the
55   // returned delegate that should be applied to the TfLite runtime.
56   virtual std::pair<TfLiteDelegatePtr, int> CreateRankedTfLiteDelegate(
57       const ToolParams& params) const = 0;
58 
59   virtual std::string GetName() const = 0;
60 
DefaultParams()61   const ToolParams& DefaultParams() const { return default_params_; }
62 
63  protected:
64   template <typename T>
CreateFlag(const char * name,ToolParams * params,const std::string & usage)65   Flag CreateFlag(const char* name, ToolParams* params,
66                   const std::string& usage) const {
67     return Flag(
68         name,
69         [params, name](const T& val, int argv_position) {
70           params->Set<T>(name, val, argv_position);
71         },
72         default_params_.Get<T>(name), usage, Flag::kOptional);
73   }
74   ToolParams default_params_;
75 };
76 
77 using DelegateProviderPtr = std::unique_ptr<DelegateProvider>;
78 using DelegateProviderList = std::vector<DelegateProviderPtr>;
79 
80 class DelegateProviderRegistrar {
81  public:
82   template <typename T>
83   struct Register {
RegisterRegister84     Register() {
85       auto* const instance = DelegateProviderRegistrar::GetSingleton();
86       instance->providers_.emplace_back(DelegateProviderPtr(new T()));
87     }
88   };
89 
GetProviders()90   static const DelegateProviderList& GetProviders() {
91     return GetSingleton()->providers_;
92   }
93 
94  private:
DelegateProviderRegistrar()95   DelegateProviderRegistrar() {}
96   DelegateProviderRegistrar(const DelegateProviderRegistrar&) = delete;
97   DelegateProviderRegistrar& operator=(const DelegateProviderRegistrar&) =
98       delete;
99 
GetSingleton()100   static DelegateProviderRegistrar* GetSingleton() {
101     static auto* instance = new DelegateProviderRegistrar();
102     return instance;
103   }
104   DelegateProviderList providers_;
105 };
106 
// Expands to the identifier of the registration object for provider type T:
// gDelegateProvider_<T>_.
#define REGISTER_DELEGATE_PROVIDER_VNAME(T) gDelegateProvider_##T##_
// Defines a static DelegateProviderRegistrar::Register<T> object named by
// REGISTER_DELEGATE_PROVIDER_VNAME(T); constructing that object registers a T
// with the registrar. The expansion already ends with a semicolon.
#define REGISTER_DELEGATE_PROVIDER(T)                          \
  static tflite::tools::DelegateProviderRegistrar::Register<T> \
      REGISTER_DELEGATE_PROVIDER_VNAME(T);
111 
112 // Creates a null delegate, useful for cases where no reasonable delegate can be
113 // created.
114 TfLiteDelegatePtr CreateNullDelegate();
115 
116 // A global helper function to get all registered delegate providers.
GetRegisteredDelegateProviders()117 inline const DelegateProviderList& GetRegisteredDelegateProviders() {
118   return DelegateProviderRegistrar::GetProviders();
119 }
120 
121 // A helper class to create a list of TfLite delegates based on the provided
122 // ToolParams and the global DelegateProviderRegistrar.
123 class ProvidedDelegateList {
124  public:
125   struct ProvidedDelegate {
ProvidedDelegateProvidedDelegate126     ProvidedDelegate()
127         : provider(nullptr), delegate(CreateNullDelegate()), rank(0) {}
128     const DelegateProvider* provider;
129     TfLiteDelegatePtr delegate;
130     int rank;
131   };
132 
ProvidedDelegateList()133   ProvidedDelegateList() : ProvidedDelegateList(/*params*/ nullptr) {}
134 
135   // 'params' is the ToolParams instance that this class will operate on,
136   // including adding all registered delegate parameters to it etc.
ProvidedDelegateList(ToolParams * params)137   explicit ProvidedDelegateList(ToolParams* params)
138       : providers_(GetRegisteredDelegateProviders()), params_(params) {}
139 
providers()140   const DelegateProviderList& providers() const { return providers_; }
141 
142   // Add all registered delegate params to the contained 'params_'.
143   void AddAllDelegateParams() const;
144 
145   // Append command-line parsable flags to 'flags' of all registered delegate
146   // providers, and associate the flag values at runtime with the contained
147   // 'params_'.
148   void AppendCmdlineFlags(std::vector<Flag>& flags) const;
149 
150   // Removes command-line parsable flag 'name' from 'flags'
151   void RemoveCmdlineFlag(std::vector<Flag>& flags,
152                          const std::string& name) const;
153 
154   // Return a list of TfLite delegates based on the provided 'params', and the
155   // list has been already sorted in ascending order according to the rank of
156   // the particular parameter that enables the creation of the delegate.
157   std::vector<ProvidedDelegate> CreateAllRankedDelegates(
158       const ToolParams& params) const;
159 
160   // Similar to the above, the list of TfLite delegates are created based on the
161   // contained 'params_'.
CreateAllRankedDelegates()162   std::vector<ProvidedDelegate> CreateAllRankedDelegates() const {
163     return CreateAllRankedDelegates(*params_);
164   }
165 
166  private:
167   const DelegateProviderList& providers_;
168 
169   // Represent the set of "ToolParam"s that this helper class will operate on.
170   ToolParams* const params_;  // Not own the memory.
171 };
172 }  // namespace tools
173 }  // namespace tflite
174 
175 #endif  // TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
176