1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_SHIM_OP_KERNEL_H_
16 #define TENSORFLOW_LITE_KERNELS_SHIM_OP_KERNEL_H_
17
18 // This file defines a shim layer on top of TF and TFLite custom op APIs.
19 // The goal is for a custom op to be written once and used for both runtimes
20 //
21 // It consists of two pieces:
22 // * A set of *context* interfaces:
23 // ** InvokeContext, InitContext, ShapeInferenceContext
24 // These are passed on to the custom op implementation to read/write
25 // tensors, etc.
26 //
27 // * An OpKernelShim interface:
28 // This is what a custom op needs to implement. By using that interface the
29 // custom op can then be easily converted to both a TF op kernel and a
30 // TFLite op kernel.
31
#include <cstdint>
#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/tensor_view.h"
42
43 namespace tflite {
44 namespace shim {
45
// The runtimes (TF custom op APIs) that this shim library abstracts away.
//
// Used as a template parameter in various places so that an op implementation
// picks up the matching set of types (eg. TfInvokeContext vs.
// TfLiteInvokeContext).
enum class Runtime {
  kTf,
  kTfLite,
};
52
53 // TensorView or error
54 using TensorViewOr = absl::StatusOr<std::unique_ptr<TensorView>>;
55 using ConstTensorViewOr = absl::StatusOr<std::unique_ptr<const TensorView>>;
56
57 // Below are the interfaces for various "Context" objects to abstract away the
58 // TF and TFLite differences.
59 //
60 // The interfaces are static and use the CRTP pattern instead of virtual
61 // methods.
62
63 // The attribute dictionary passed to the op
64 using AttrValue = absl::variant<bool, int64_t, float, absl::string_view>;
65
66 // The interface for available methods during an op kernel initialization
67 template <typename SubType>
68 class InitContext {
69 public:
70 // Read the given attribute and populate the given value.
71 template <typename AttrType>
72 absl::Status GetAttr(const std::string& attr_name, AttrType* value) const;
73
74 protected:
75 // Read a given attribute or return error
GetAttr(const std::string & attr_name)76 absl::StatusOr<AttrValue> GetAttr(const std::string& attr_name) const {
77 return static_cast<const SubType&>(*this).GetAttr(attr_name);
78 }
79 };
80
81 // The interface for available methods during an op kernel invocation
82 template <typename SubType>
83 class InvokeContext {
84 public:
85 // Read an input tensor
GetInput(const int idx)86 ConstTensorViewOr GetInput(const int idx) const {
87 return static_cast<const SubType&>(*this).GetInput(idx);
88 }
89 // Get a mutable output tensor
GetOutput(const int idx,const Shape & shape)90 TensorViewOr GetOutput(const int idx, const Shape& shape) const {
91 return static_cast<const SubType&>(*this).GetOutput(idx, shape);
92 }
93 // Number of input tensors
NumInputs()94 int NumInputs() const {
95 return static_cast<const SubType&>(*this).NumInputs();
96 }
97 // Number of output tensors
NumOutputs()98 int NumOutputs() const {
99 return static_cast<const SubType&>(*this).NumOutputs();
100 }
101 };
102
103 // The interface for available methods during shape inference
104 template <typename SubType>
105 class ShapeInferenceContext {
106 public:
107 // Read an input tensor shape
GetInputShape(const int idx)108 ShapeOr GetInputShape(const int idx) const {
109 return static_cast<const SubType&>(*this).GetInputShape(idx);
110 }
111 // Set an output tensor shape
SetOutputShape(const int idx,const Shape & shape)112 absl::Status SetOutputShape(const int idx, const Shape& shape) {
113 return static_cast<SubType&>(*this).SetOutputShape(idx, shape);
114 }
115 // Read an input tensor during shape inference
GetInputTensor(const int idx)116 ConstTensorViewOr GetInputTensor(const int idx) const {
117 return static_cast<const SubType&>(*this).GetInputTensor(idx);
118 }
119 // Number of input tensors
NumInputs()120 int NumInputs() const {
121 return static_cast<const SubType&>(*this).NumInputs();
122 }
123 // Number of output tensors
NumOutputs()124 int NumOutputs() const {
125 return static_cast<const SubType&>(*this).NumOutputs();
126 }
127 // Read the given attribute and populate the given value.
128 template <typename AttrType>
129 absl::Status GetAttr(const std::string& attr_name, AttrType* value) const;
130
131 protected:
132 // Read a given attribute or return error
GetAttr(const std::string & attr_name)133 absl::StatusOr<AttrValue> GetAttr(const std::string& attr_name) const {
134 return static_cast<const SubType&>(*this).GetAttr(attr_name);
135 }
136 };
137
138 // Maps the Runtime to the correct context types.
139 // eg. ContextTypeForRuntime<Runtime::Tf> -->
140 // { TfInitContext, TfInvokeContext, TfShapreInferenceContext }
141 template <Runtime Rt>
142 struct ContextTypeForRuntime {
143 // * Init
144 // * Invoke
145 // * ShapeInference
146 };
147
148 // A Tensorflow operation interface which is then adapted to both TF and TFLite
149 // runtimes.
150 //
151 // Example usage:
152 //
153 // template<Runtime R>
154 // class MyOp : public OpKernelShim<MyOp, R> {
155 //
156 // // Attributes declaration
157 // // (syntax: https://www.tensorflow.org/guide/create_op)
158 // static std::vector<std::string> Attrs();
159 //
160 // // Input tensors declaration
161 // // (syntax: https://www.tensorflow.org/guide/create_op)
162 // static std::vector<std::string> Inputs();
163 //
164 // // Output tensors declaration
165 // // (syntax: https://www.tensorflow.org/guide/create_op)
166 // static std::vector<std::string> Outputs();
167 //
168 // // Initializes the op
169 // absl::Status Init(InitContext* ctx);
170 //
171 // // Runs the operation
172 // absl::Status Invoke(InvokeContext* ctx);
173 //
174 // // Shape inference
175 // static absl::Status ShapeInference(ShapeInferenceContext* ctx);
176 //
177 // };
178 //
179 // WARNING: Experimental interface, subject to change
180 template <template <Runtime> typename SubType, Runtime Rt>
181 class OpKernelShim {
182 public:
183 // Some typedefs for convenience
184 using Shape = ::tflite::shim::Shape;
185 using InitContext =
186 ::tflite::shim::InitContext<typename ContextTypeForRuntime<Rt>::Init>;
187 using InvokeContext =
188 ::tflite::shim::InvokeContext<typename ContextTypeForRuntime<Rt>::Invoke>;
189 using ShapeInferenceContext = ::tflite::shim::ShapeInferenceContext<
190 typename ContextTypeForRuntime<Rt>::ShapeInference>;
191
192 // Needed because the pointer to this class is stored
193 virtual ~OpKernelShim() = default;
194
195 // If the operation has any attributes they are passed here.
Init(InitContext * ctx)196 absl::Status Init(InitContext* ctx) {
197 return static_cast<SubType<Rt>&>(*this).Init(ctx);
198 }
199
200 // The actual computations of the operation
Invoke(InvokeContext * ctx)201 absl::Status Invoke(InvokeContext* ctx) {
202 return static_cast<SubType<Rt>&>(*this).Invoke(ctx);
203 }
204
205 // Shape inference
ShapeInference(ShapeInferenceContext * ctx)206 static absl::Status ShapeInference(ShapeInferenceContext* ctx) {
207 return SubType<Rt>::ShapeInference(ctx);
208 }
209
210 protected:
211 OpKernelShim() = default;
212 };
213
214 /////////////////////// Implementations
215
216 namespace internal {
217 // Extract the given AttrType from the AttrValue variant or returns error.
218 template <typename AttrType>
GetAttr(const std::string & attr_name,const absl::StatusOr<AttrValue> attr_value_or,AttrType * value)219 absl::Status GetAttr(const std::string& attr_name,
220 const absl::StatusOr<AttrValue> attr_value_or,
221 AttrType* value) {
222 if (!attr_value_or.ok()) return attr_value_or.status();
223 const AttrValue& attr_value = attr_value_or.value();
224 if (!absl::holds_alternative<AttrType>(attr_value)) {
225 return absl::InternalError(
226 absl::StrCat("The attribute type does not match the provided "
227 "type: attr_name: ",
228 attr_name));
229 }
230 *value = absl::get<AttrType>(attr_value);
231 return absl::OkStatus();
232 }
233 } // namespace internal
234
235 template <typename SubType>
236 template <typename AttrType>
GetAttr(const std::string & attr_name,AttrType * value)237 absl::Status InitContext<SubType>::GetAttr(const std::string& attr_name,
238 AttrType* value) const {
239 const auto attr_value_or = GetAttr(attr_name);
240 return internal::GetAttr<AttrType>(attr_name, attr_value_or, value);
241 }
242
243 template <typename SubType>
244 template <typename AttrType>
GetAttr(const std::string & attr_name,AttrType * value)245 absl::Status ShapeInferenceContext<SubType>::GetAttr(
246 const std::string& attr_name, AttrType* value) const {
247 const auto attr_value_or = GetAttr(attr_name);
248 return internal::GetAttr<AttrType>(attr_name, attr_value_or, value);
249 }
250
251 } // namespace shim
252 } // namespace tflite
253
#endif  // TENSORFLOW_LITE_KERNELS_SHIM_OP_KERNEL_H_
255