1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
17 #define TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
18
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/logging.h"
#include "tensorflow/lite/tools/tool_params.h"
26
27 namespace tflite {
28 namespace tools {
29
// Same as Interpreter::TfLiteDelegatePtr; redeclared here to avoid pulling in
// the tensorflow/lite/interpreter.h dependency.
32 using TfLiteDelegatePtr =
33 std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
34
35 class DelegateProvider {
36 public:
~DelegateProvider()37 virtual ~DelegateProvider() {}
38
39 // Create a list of command-line parsable flags based on tool params inside
40 // 'params' whose value will be set to the corresponding runtime flag value.
41 virtual std::vector<Flag> CreateFlags(ToolParams* params) const = 0;
42
43 // Log tool params. If 'verbose' is set to false, the param is going to be
44 // only logged if its value has been set, say via being parsed from
45 // commandline flags.
46 virtual void LogParams(const ToolParams& params, bool verbose) const = 0;
47
48 // Create a TfLiteDelegate based on tool params.
49 virtual TfLiteDelegatePtr CreateTfLiteDelegate(
50 const ToolParams& params) const = 0;
51
52 // Similar to the above, create a TfLiteDelegate based on tool params. If the
53 // same set of tool params could lead to creating multiple TfLite delegates,
54 // also return a relative rank of the delegate that indicates the order of the
55 // returned delegate that should be applied to the TfLite runtime.
56 virtual std::pair<TfLiteDelegatePtr, int> CreateRankedTfLiteDelegate(
57 const ToolParams& params) const = 0;
58
59 virtual std::string GetName() const = 0;
60
DefaultParams()61 const ToolParams& DefaultParams() const { return default_params_; }
62
63 protected:
64 template <typename T>
CreateFlag(const char * name,ToolParams * params,const std::string & usage)65 Flag CreateFlag(const char* name, ToolParams* params,
66 const std::string& usage) const {
67 return Flag(
68 name,
69 [params, name](const T& val, int argv_position) {
70 params->Set<T>(name, val, argv_position);
71 },
72 default_params_.Get<T>(name), usage, Flag::kOptional);
73 }
74 ToolParams default_params_;
75 };
76
77 using DelegateProviderPtr = std::unique_ptr<DelegateProvider>;
78 using DelegateProviderList = std::vector<DelegateProviderPtr>;
79
80 class DelegateProviderRegistrar {
81 public:
82 template <typename T>
83 struct Register {
RegisterRegister84 Register() {
85 auto* const instance = DelegateProviderRegistrar::GetSingleton();
86 instance->providers_.emplace_back(DelegateProviderPtr(new T()));
87 }
88 };
89
GetProviders()90 static const DelegateProviderList& GetProviders() {
91 return GetSingleton()->providers_;
92 }
93
94 private:
DelegateProviderRegistrar()95 DelegateProviderRegistrar() {}
96 DelegateProviderRegistrar(const DelegateProviderRegistrar&) = delete;
97 DelegateProviderRegistrar& operator=(const DelegateProviderRegistrar&) =
98 delete;
99
GetSingleton()100 static DelegateProviderRegistrar* GetSingleton() {
101 static auto* instance = new DelegateProviderRegistrar();
102 return instance;
103 }
104 DelegateProviderList providers_;
105 };
106
// Name of the static registration object generated for 'ProviderType'. The
// argument is token-pasted into an identifier, so it must be a single
// identifier (no namespace qualification or template arguments).
#define REGISTER_DELEGATE_PROVIDER_VNAME(ProviderType) \
  gDelegateProvider_##ProviderType##_

// Defines a static DelegateProviderRegistrar::Register<ProviderType> object
// whose constructor adds a ProviderType instance to the global registrar.
// The expansion already ends with a semicolon.
#define REGISTER_DELEGATE_PROVIDER(ProviderType)                          \
  static tflite::tools::DelegateProviderRegistrar::Register<ProviderType> \
      REGISTER_DELEGATE_PROVIDER_VNAME(ProviderType);
111
112 // Creates a null delegate, useful for cases where no reasonable delegate can be
113 // created.
114 TfLiteDelegatePtr CreateNullDelegate();
115
116 // A global helper function to get all registered delegate providers.
GetRegisteredDelegateProviders()117 inline const DelegateProviderList& GetRegisteredDelegateProviders() {
118 return DelegateProviderRegistrar::GetProviders();
119 }
120
121 // A helper class to create a list of TfLite delegates based on the provided
122 // ToolParams and the global DelegateProviderRegistrar.
123 class ProvidedDelegateList {
124 public:
125 struct ProvidedDelegate {
ProvidedDelegateProvidedDelegate126 ProvidedDelegate()
127 : provider(nullptr), delegate(CreateNullDelegate()), rank(0) {}
128 const DelegateProvider* provider;
129 TfLiteDelegatePtr delegate;
130 int rank;
131 };
132
ProvidedDelegateList()133 ProvidedDelegateList() : ProvidedDelegateList(/*params*/ nullptr) {}
134
135 // 'params' is the ToolParams instance that this class will operate on,
136 // including adding all registered delegate parameters to it etc.
ProvidedDelegateList(ToolParams * params)137 explicit ProvidedDelegateList(ToolParams* params)
138 : providers_(GetRegisteredDelegateProviders()), params_(params) {}
139
providers()140 const DelegateProviderList& providers() const { return providers_; }
141
142 // Add all registered delegate params to the contained 'params_'.
143 void AddAllDelegateParams() const;
144
145 // Append command-line parsable flags to 'flags' of all registered delegate
146 // providers, and associate the flag values at runtime with the contained
147 // 'params_'.
148 void AppendCmdlineFlags(std::vector<Flag>& flags) const;
149
150 // Removes command-line parsable flag 'name' from 'flags'
151 void RemoveCmdlineFlag(std::vector<Flag>& flags,
152 const std::string& name) const;
153
154 // Return a list of TfLite delegates based on the provided 'params', and the
155 // list has been already sorted in ascending order according to the rank of
156 // the particular parameter that enables the creation of the delegate.
157 std::vector<ProvidedDelegate> CreateAllRankedDelegates(
158 const ToolParams& params) const;
159
160 // Similar to the above, the list of TfLite delegates are created based on the
161 // contained 'params_'.
CreateAllRankedDelegates()162 std::vector<ProvidedDelegate> CreateAllRankedDelegates() const {
163 return CreateAllRankedDelegates(*params_);
164 }
165
166 private:
167 const DelegateProviderList& providers_;
168
169 // Represent the set of "ToolParam"s that this helper class will operate on.
170 ToolParams* const params_; // Not own the memory.
171 };
172 } // namespace tools
173 } // namespace tflite
174
175 #endif // TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
176