/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_SHIM_OP_KERNEL_H_
#define TENSORFLOW_LITE_KERNELS_SHIM_OP_KERNEL_H_

// This file defines a shim layer on top of the TF and TFLite custom op APIs.
// The goal is for a custom op to be written once and used for both runtimes.
//
// It consists of two pieces:
//   * A set of *context* interfaces:
//     ** InvokeContext, InitContext, ShapeInferenceContext
//     These are passed on to the custom op implementation to read/write
//     tensors, etc.
//
//   * An OpKernelShim interface:
//     This is what a custom op needs to implement. By using that interface the
//     custom op can then be easily converted to both a TF op kernel and a
//     TFLite op kernel.
#include <cstdint>
#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/tensor_view.h"
namespace tflite {
namespace shim {

// List of the TF custom op APIs this shim library is abstracting away.
//
// This enum is used as the template parameter in various places in
// order to pick the correct set of types (e.g. TfInvokeContext vs.
// TfLiteInvokeContext) in the op implementation.
enum class Runtime { kTf, kTfLite };
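
// For example, an op written against this shim (see OpKernelShim below) can be
// instantiated for either runtime; `MyOp` is a hypothetical op used only for
// illustration:
//
//   MyOp<Runtime::kTf> tf_kernel;        // picks the TF context types
//   MyOp<Runtime::kTfLite> lite_kernel;  // picks the TFLite context types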

// TensorView or error
using TensorViewOr = absl::StatusOr<std::unique_ptr<TensorView>>;
using ConstTensorViewOr = absl::StatusOr<std::unique_ptr<const TensorView>>;

// Below are the interfaces for various "Context" objects to abstract away the
// TF and TFLite differences.
//
// The interfaces are static and use the CRTP pattern instead of virtual
// methods.

// The attribute dictionary passed to the op
using AttrValue = absl::variant<bool, int64_t, float, absl::string_view>;

// The interface for available methods during an op kernel initialization
template <typename SubType>
class InitContext {
 public:
  // Read the given attribute and populate the given value.
  template <typename AttrType>
  absl::Status GetAttr(const std::string& attr_name, AttrType* value) const;

 protected:
  // Read a given attribute or return error
  absl::StatusOr<AttrValue> GetAttr(const std::string& attr_name) const {
    return static_cast<const SubType&>(*this).GetAttr(attr_name);
  }
};
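
// Example (a minimal sketch): reading a declared attribute inside an op's
// Init(). `max_len` is a hypothetical attribute name used for illustration.
//
//   absl::Status Init(InitContext* ctx) {
//     int64_t max_len = 0;
//     const auto status = ctx->GetAttr("max_len", &max_len);
//     if (!status.ok()) return status;
//     ...
//     return absl::OkStatus();
//   }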

// The interface for available methods during an op kernel invocation
template <typename SubType>
class InvokeContext {
 public:
  // Read an input tensor
  ConstTensorViewOr GetInput(const int idx) const {
    return static_cast<const SubType&>(*this).GetInput(idx);
  }
  // Get a mutable output tensor
  TensorViewOr GetOutput(const int idx, const Shape& shape) const {
    return static_cast<const SubType&>(*this).GetOutput(idx, shape);
  }
  // Number of input tensors
  int NumInputs() const {
    return static_cast<const SubType&>(*this).NumInputs();
  }
  // Number of output tensors
  int NumOutputs() const {
    return static_cast<const SubType&>(*this).NumOutputs();
  }
};
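
// Example (an illustrative sketch): an Invoke() that copies input 0 to
// output 0, assuming the Data<T>() accessor declared in tensor_view.h.
//
//   absl::Status Invoke(InvokeContext* ctx) {
//     const auto input_or = ctx->GetInput(0);
//     if (!input_or.ok()) return input_or.status();
//     const Shape out_shape = ...;  // e.g. same as the input's shape
//     auto output_or = ctx->GetOutput(0, out_shape);
//     if (!output_or.ok()) return output_or.status();
//     // Read via (*input_or)->Data<int32_t>() and write via
//     // (*output_or)->Data<int32_t>().
//     return absl::OkStatus();
//   }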

// The interface for available methods during shape inference
template <typename SubType>
class ShapeInferenceContext {
 public:
  // Read an input tensor shape
  ShapeOr GetInputShape(const int idx) const {
    return static_cast<const SubType&>(*this).GetInputShape(idx);
  }
  // Set an output tensor shape
  absl::Status SetOutputShape(const int idx, const Shape& shape) {
    return static_cast<SubType&>(*this).SetOutputShape(idx, shape);
  }
  // Read an input tensor during shape inference
  ConstTensorViewOr GetInputTensor(const int idx) const {
    return static_cast<const SubType&>(*this).GetInputTensor(idx);
  }
  // Number of input tensors
  int NumInputs() const {
    return static_cast<const SubType&>(*this).NumInputs();
  }
  // Number of output tensors
  int NumOutputs() const {
    return static_cast<const SubType&>(*this).NumOutputs();
  }
  // Read the given attribute and populate the given value.
  template <typename AttrType>
  absl::Status GetAttr(const std::string& attr_name, AttrType* value) const;

 protected:
  // Read a given attribute or return error
  absl::StatusOr<AttrValue> GetAttr(const std::string& attr_name) const {
    return static_cast<const SubType&>(*this).GetAttr(attr_name);
  }
};
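
// Example (a minimal sketch): a ShapeInference() for an op whose output
// mirrors the shape of its first input.
//
//   static absl::Status ShapeInference(ShapeInferenceContext* ctx) {
//     const auto shape_or = ctx->GetInputShape(0);
//     if (!shape_or.ok()) return shape_or.status();
//     return ctx->SetOutputShape(0, *shape_or);
//   }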

// Maps the Runtime to the correct context types.
// e.g. ContextTypeForRuntime<Runtime::kTf>  -->
//       { TfInitContext, TfInvokeContext, TfShapeInferenceContext }
template <Runtime Rt>
struct ContextTypeForRuntime {
  // * Init
  // * Invoke
  // * ShapeInference
};
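
// Each runtime adapter is expected to specialize this struct. A sketch of what
// the TFLite specialization would look like (the actual definition lives in
// the runtime-specific headers, e.g. tflite_op_shim.h; the type names here are
// illustrative):
//
//   template <>
//   struct ContextTypeForRuntime<Runtime::kTfLite> {
//     using Init = TfLiteInitContext;
//     using Invoke = TfLiteInvokeContext;
//     using ShapeInference = TfLiteShapeInferenceContext;
//   };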

// A TensorFlow operation interface which is then adapted to both the TF and
// TFLite runtimes.
//
// Example usage:
//
//   template<Runtime R>
//   class MyOp : public OpKernelShim<MyOp, R> {
//
//     // Attributes declaration
//     // (syntax: https://www.tensorflow.org/guide/create_op)
//     static std::vector<std::string> Attrs();
//
//     // Input tensors declaration
//     // (syntax: https://www.tensorflow.org/guide/create_op)
//     static std::vector<std::string> Inputs();
//
//     // Output tensors declaration
//     // (syntax: https://www.tensorflow.org/guide/create_op)
//     static std::vector<std::string> Outputs();
//
//     // Initializes the op
//     absl::Status Init(InitContext* ctx);
//
//     // Runs the operation
//     absl::Status Invoke(InvokeContext* ctx);
//
//     // Shape inference
//     static absl::Status ShapeInference(ShapeInferenceContext* ctx);
//
//   };
//
// WARNING: Experimental interface, subject to change
template <template <Runtime> typename SubType, Runtime Rt>
class OpKernelShim {
 public:
  // Some typedefs for convenience
  using Shape = ::tflite::shim::Shape;
  using InitContext =
      ::tflite::shim::InitContext<typename ContextTypeForRuntime<Rt>::Init>;
  using InvokeContext =
      ::tflite::shim::InvokeContext<typename ContextTypeForRuntime<Rt>::Invoke>;
  using ShapeInferenceContext = ::tflite::shim::ShapeInferenceContext<
      typename ContextTypeForRuntime<Rt>::ShapeInference>;

  // Virtual destructor: instances are stored and destroyed through a pointer
  // to this base class.
  virtual ~OpKernelShim() = default;

  // If the operation has any attributes they are passed here.
  absl::Status Init(InitContext* ctx) {
    return static_cast<SubType<Rt>&>(*this).Init(ctx);
  }

  // The actual computation of the operation
  absl::Status Invoke(InvokeContext* ctx) {
    return static_cast<SubType<Rt>&>(*this).Invoke(ctx);
  }

  // Shape inference
  static absl::Status ShapeInference(ShapeInferenceContext* ctx) {
    return SubType<Rt>::ShapeInference(ctx);
  }

 protected:
  OpKernelShim() = default;
};

/////////////////////// Implementations

namespace internal {
// Extracts the given AttrType from the AttrValue variant or returns an error.
template <typename AttrType>
absl::Status GetAttr(const std::string& attr_name,
                     const absl::StatusOr<AttrValue>& attr_value_or,
                     AttrType* value) {
  if (!attr_value_or.ok()) return attr_value_or.status();
  const AttrValue& attr_value = attr_value_or.value();
  if (!absl::holds_alternative<AttrType>(attr_value)) {
    return absl::InternalError(
        absl::StrCat("The attribute type does not match the provided "
                     "type: attr_name: ",
                     attr_name));
  }
  *value = absl::get<AttrType>(attr_value);
  return absl::OkStatus();
}
}  // namespace internal

template <typename SubType>
template <typename AttrType>
absl::Status InitContext<SubType>::GetAttr(const std::string& attr_name,
                                           AttrType* value) const {
  const auto attr_value_or = GetAttr(attr_name);
  return internal::GetAttr<AttrType>(attr_name, attr_value_or, value);
}

template <typename SubType>
template <typename AttrType>
absl::Status ShapeInferenceContext<SubType>::GetAttr(
    const std::string& attr_name, AttrType* value) const {
  const auto attr_value_or = GetAttr(attr_name);
  return internal::GetAttr<AttrType>(attr_name, attr_value_or, value);
}

}  // namespace shim
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_SHIM_OP_KERNEL_H_