/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Compatibility layer for calling directly into a TensorFlow kernel via TFRT,
// bypassing the existing TensorFlow runtime. This file defines:
//
//   TFRTOpKernel
//   TFRTOpKernelConstruction
//   TFRTOpKernelContext
//
// Note that these are standalone objects that do not share a base class with
// TF's corresponding OpKernel, OpKernelConstruction, and OpKernelContext
// types; a common base class is deliberately avoided to eliminate virtual call
// overhead. Kernels that support these fallback types must be templated: see
// core/kernels/aggregate_ops.h for an example.
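//
// A minimal sketch of what such a templated kernel might look like. The names
// below (ExampleAddOp, Device, T) are illustrative only, loosely modeled on
// AddNOp in core/kernels/aggregate_ops.h rather than copied from it:
//
//   template <typename Device, typename T, typename OpKernelT,
//             typename OpKernelConstructionT, typename OpKernelContextT>
//   class ExampleAddOp : public OpKernelT {
//    public:
//     explicit ExampleAddOp(OpKernelConstructionT* ctx) : OpKernelT(ctx) {}
//     void Compute(OpKernelContextT* ctx) override {
//       // Compiles against either the TF types or the TFRT* fallback types,
//       // since both provide the methods used here.
//     }
//   };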

#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_

#include <string>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ManagedStatic.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/runtime_fallback/kernel/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tfrt/common/compat/eigen/thread_pool_device.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime

namespace tfrt {
class AsyncKernelFrame;
}  // namespace tfrt

namespace tensorflow {

class Status;
class TFRTOpKernel;
class TFRTOpMeta;
class Tensor;
class TensorShape;

//////////////////////////////////////////////////////////////////////
// OpKernel interface.
//////////////////////////////////////////////////////////////////////
class TFRTOpKernelConstruction {
 public:
  explicit TFRTOpKernelConstruction(const tfrt::OpAttrsRef& attributes);

  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs) {
    // TODO(annarev): Move MatchSignatureHelper out of op_kernel.h
    // and call it here.
    return OkStatus();
  }

  const llvm::Optional<std::string>& error();

 private:
  const tfrt::OpAttrsRef& attributes_;
  // If an error occurs, the error message is stored here.
  llvm::Optional<std::string> error_;
};

template <>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         std::string* value) const;

template <>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         DataType* value) const;

template <>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         Padding* value) const;

template <>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         std::vector<int32>* value) const;

Status MissingAttributeError(StringPiece attr_name);

template <class T>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         T* value) const {
  bool success = attributes_.Get<T>(
      llvm::StringRef(attr_name.data(), attr_name.size()), value);
  if (!success) {
    return MissingAttributeError(attr_name);
  }
  return OkStatus();
}
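
// For illustration, a kernel constructor might read an attribute as follows.
// The attribute name "data_format" and the variable names are hypothetical:
//
//   std::string data_format;
//   Status s = construction->GetAttr("data_format", &data_format);
//   if (!s.ok()) construction->CtxFailure(s);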

// An implementation of OpKernelContext that fetches inputs from a
// tfrt::AsyncKernelFrame. Outputs and errors are stored internally.
class TFRTOpKernelContext {
 public:
  explicit TFRTOpKernelContext(
      llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> inputs,
      int num_outputs, const TFRTOpMeta* op_meta, tfrt::HostContext* host);
  const Tensor& output(int index);
  const llvm::Optional<std::string>& error();

  // OpKernelContext interface implementation.
  bool ValidateInputsAreSameShape(TFRTOpKernel* op);
  const Tensor& input(int index);
  int num_inputs() const;
  void set_output(int index, const Tensor& tensor);
  int num_outputs() const;
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) {
    return false;
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor);
  DataType expected_output_dtype(int i) const;

  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;

  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

 private:
  llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> inputs_;
  const TFRTOpMeta* op_meta_;

  // The kernel's outputs are kept here. We can't store outputs directly in the
  // AsyncKernelFrame because allocate_output's Tensor must be held somewhere
  // until it is initialized. If we stored the allocated Tensor directly in the
  // AsyncKernelFrame, the frame's output would become available and downstream
  // kernels might use the allocated (but still uninitialized) Tensor.
  std::vector<Tensor> outputs_;

  // If an error occurs, the error message is stored here.
  llvm::Optional<std::string> error_;

  tfrt::compat::EigenHostContext eigen_host_context_;
};
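
// For illustration, a fallback kernel's Compute method might use this context
// as follows (tensor shapes and names are hypothetical):
//
//   void Compute(TFRTOpKernelContext* ctx) override {
//     const Tensor& in = ctx->input(0);
//     Tensor* out = nullptr;
//     Status s = ctx->allocate_output(0, in.shape(), &out);
//     if (!s.ok()) {
//       ctx->CtxFailure(s);
//       return;
//     }
//     // ... compute into *out ...
//   }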

class TFRTOpKernel {
 public:
  explicit TFRTOpKernel(TFRTOpKernelConstruction* context) {}
  virtual ~TFRTOpKernel() {}
  virtual void Compute(TFRTOpKernelContext* context) = 0;
};

inline void CheckNotInComputeAsync(TFRTOpKernelConstruction*, const char*) {}
inline void CheckNotInComputeAsync(TFRTOpKernelContext*, const char*) {}

//////////////////////////////////////////////////////////////////////
// Forwarding op metadata.
//////////////////////////////////////////////////////////////////////

// Op metadata. For now TFRTOpMeta only stores the op's output types.
class TFRTOpMeta {
 public:
  explicit TFRTOpMeta(std::vector<DataType> output_types);
  DataType output_type(int index) const;

 private:
  std::vector<DataType> output_types_;
};

// Constructs a TFRTOpMeta from .Input(), .Output(), and .Attr()
// specifications. This accepts the same syntax as TF's REGISTER_OP macro, but
// only a subset of the full language is supported.
//
// Currently, only single-tensor outputs with a fixed type are supported.
// TODO(lauj) Support attribute outputs and compound attribute types as used by
// AddN.
class TFRTOpMetaBuilder {
 public:
  explicit TFRTOpMetaBuilder(StringPiece op_name);
  TFRTOpMetaBuilder& Output(StringPiece output_spec);
  TFRTOpMetaBuilder& Input(StringPiece input_spec);
  TFRTOpMetaBuilder& Attr(StringPiece attr_spec);

  const string& op_name() const;
  TFRTOpMeta BuildMeta() const;

 private:
  string op_name_;
  std::vector<DataType> output_types_;
};

// Map from op name to TFRTOpMeta.
class TFRTOpMetaMap {
 public:
  TFRTOpMetaMap();
  void RegisterOpMeta(const TFRTOpMetaBuilder& op_builder);

  // Returns nullptr if there is no metadata for op_name.
  const TFRTOpMeta* GetOpMeta(StringPiece op_name) const;

 private:
  llvm::StringMap<TFRTOpMeta> op_metas_;
};

extern llvm::ManagedStatic<TFRTOpMetaMap> tfrt_forwarding_op_meta_map;

// Implementation detail for REGISTER_KERNEL_FALLBACK_OP. This helps with
// evaluating the .Input()/.Output()/.Attr() clauses in the REGISTER_OP syntax
// before calling BuildMeta().
class TFRTOpRegisterer {
 public:
  TFRTOpRegisterer(  // NOLINT(google-explicit-constructor)
      const TFRTOpMetaBuilder& op_builder);
};

#define REGISTER_KERNEL_FALLBACK_OP(name) \
  REGISTER_KERNEL_FALLBACK_OP_UNIQ_HELPER(__COUNTER__, name)

#define REGISTER_KERNEL_FALLBACK_OP_UNIQ_HELPER(ctr, name) \
  REGISTER_KERNEL_FALLBACK_OP_UNIQ(ctr, name)

#define REGISTER_KERNEL_FALLBACK_OP_UNIQ(ctr, name)                         \
  static TFRTOpRegisterer global_tfrt_forwarding_op_meta_builder_##ctr##_ = \
      TFRTOpMetaBuilder(name)
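
// For illustration, an op might be registered like this. The op name
// "ExampleAddN" and its output spec are hypothetical; the chained clauses use
// REGISTER_OP syntax, of which only a subset is supported (see
// TFRTOpMetaBuilder above):
//
//   REGISTER_KERNEL_FALLBACK_OP("ExampleAddN").Output("sum: float");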

//////////////////////////////////////////////////////////////////////
// Forwarding kernel registration.
//////////////////////////////////////////////////////////////////////

// Represents Kernel Fallback kernel registration information.
struct TFRTOpKernelReg {
  using CallbackT =
      std::unique_ptr<TFRTOpKernel> (*)(TFRTOpKernelConstruction*);

  explicit TFRTOpKernelReg(CallbackT callback) : callback(callback) {}

  // Callback that creates a kernel.
  CallbackT callback;
  // Map from attribute name to the type it must match.
  // For example, foo: DT_FLOAT indicates that the foo attribute
  // must be a tfdtype attribute with type float.
  llvm::StringMap<DataType> type_constraints;
};

class TFRTOpKernelFactories {
 public:
  TFRTOpKernelFactories();
  void RegisterFactory(StringPiece kernel_class_name,
                       TFRTOpKernelReg kernel_info);

  // Creates a kernel with the given name and passes op_kernel_construction
  // to the kernel's constructor.
  // Returns the constructed kernel on success.
  // On failure, returns nullptr. Kernel creation can fail in one
  // of the following cases:
  //   1. No kernel with the given name is registered.
  //   2. The attributes in op_kernel_construction don't match the type
  //      constraints of any kernel registered under this name.
  //      Note that a constraint is considered "not matched" if the attribute
  //      it applies to is not present in op_kernel_construction.
  std::unique_ptr<TFRTOpKernel> CreateKernel(
      StringPiece kernel_class_name,
      TFRTOpKernelConstruction* op_kernel_construction) const;

 private:
  llvm::StringMap<std::vector<TFRTOpKernelReg>> factories_;
};

// TODO(lauj) Should we move these kernel registrations to tfrt::KernelRegistry?
extern llvm::ManagedStatic<TFRTOpKernelFactories>
    tfrt_forwarding_kernel_factories;
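
// For illustration, runtime-side kernel lookup might resemble the following
// sketch. The op name "ExampleAddN" and the surrounding variables are
// hypothetical:
//
//   TFRTOpKernelConstruction construction(attrs);
//   std::unique_ptr<TFRTOpKernel> kernel =
//       tfrt_forwarding_kernel_factories->CreateKernel("ExampleAddN",
//                                                      &construction);
//   if (kernel) kernel->Compute(&context);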

#define REGISTER_KERNEL_FALLBACK_KERNEL(name, ...) \
  REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ_HELPER(__COUNTER__, name, __VA_ARGS__)

#define REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ_HELPER(ctr, name, ...) \
  REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ(ctr, name, __VA_ARGS__)

#define REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ(ctr, name, ...)             \
  static bool global_tfrt_forwarding_kernel_##ctr##_registered_ = []() { \
    ::tensorflow::tfrt_forwarding_kernel_factories->RegisterFactory(     \
        name, TFRTOpKernelReg([](TFRTOpKernelConstruction* construction) \
                                  -> std::unique_ptr<TFRTOpKernel> {     \
          return std::make_unique<__VA_ARGS__>(construction);            \
        }));                                                             \
    return true;                                                         \
  }();
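
// For illustration, a kernel class might be registered under an op name like
// this. The names are the hypothetical ones used above; the variadic argument
// is the concrete kernel type and may contain commas from template arguments:
//
//   REGISTER_KERNEL_FALLBACK_KERNEL(
//       "ExampleAddN",
//       ExampleAddOp<Eigen::ThreadPoolDevice, float, TFRTOpKernel,
//                    TFRTOpKernelConstruction, TFRTOpKernelContext>);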

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_