/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Compatibility layer for calling directly into a TensorFlow kernel via TFRT,
// bypassing the existing TensorFlow runtime. This file defines:
//
//   TFRTOpKernel
//   TFRTOpKernelConstruction
//   TFRTOpKernelContext
//
// Note that these are standalone objects that do not share a base class with
// TF's corresponding OpKernel, OpKernelConstruction, and OpKernelContext
// types; the common base class is deliberately omitted to avoid virtual call
// overhead. Kernels that support these fallback types must be templated: see
// core/kernels/aggregate_ops.h for an example.

#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_
#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_

#include <string>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ManagedStatic.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/runtime_fallback/kernel/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tfrt/common/compat/eigen/thread_pool_device.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime

namespace tfrt {
class AsyncKernelFrame;
}  // namespace tfrt

namespace tensorflow {

class Status;
class TFRTOpKernel;
class TFRTOpMeta;
class Tensor;
class TensorShape;

//////////////////////////////////////////////////////////////////////
// OpKernel interface.
//////////////////////////////////////////////////////////////////////
class TFRTOpKernelConstruction {
 public:
  explicit TFRTOpKernelConstruction(const tfrt::OpAttrsRef& attributes);

  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs) {
    // TODO(annarev): Move MatchSignatureHelper out of op_kernel.h
    // and call it here.
    return OkStatus();
  }

  const llvm::Optional<std::string>& error();

 private:
  const tfrt::OpAttrsRef& attributes_;
  // If an error occurs, the error message is stored here.
  llvm::Optional<std::string> error_;
};


template <>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         std::string* value) const;

template <>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         DataType* value) const;

template <>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         Padding* value) const;

template <>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         std::vector<int32>* value) const;

Status MissingAttributeError(StringPiece attr_name);


template <class T>
Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
                                         T* value) const {
  bool success = attributes_.Get<T>(
      llvm::StringRef(attr_name.data(), attr_name.size()), value);
  if (!success) {
    return MissingAttributeError(attr_name);
  }
  return OkStatus();
}
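
// Example (a hedged sketch, not part of this header's API surface): a
// kernel's constructor would typically read its attributes through GetAttr
// and report failures via CtxFailure. The attribute name "transpose_a" below
// is purely illustrative.
//
//   bool transpose_a = false;
//   Status s = construction->GetAttr("transpose_a", &transpose_a);
//   if (!s.ok()) construction->CtxFailure(s);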

// An implementation of OpKernelContext that fetches inputs from a
// tfrt::AsyncKernelFrame. Outputs and errors are stored internally.
class TFRTOpKernelContext {
 public:
  explicit TFRTOpKernelContext(
      llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> inputs,
      int num_outputs, const TFRTOpMeta* op_meta, tfrt::HostContext* host);
  const Tensor& output(int index);
  const llvm::Optional<std::string>& error();

  // OpKernelContext interface implementation.
  bool ValidateInputsAreSameShape(TFRTOpKernel* op);
  const Tensor& input(int index);
  int num_inputs() const;
  void set_output(int index, const Tensor& tensor);
  int num_outputs() const;
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) {
    return false;
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor);
  DataType expected_output_dtype(int i) const;

  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;

  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

 private:
  llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> inputs_;
  const TFRTOpMeta* op_meta_;

  // The kernel's outputs are kept here. We can't store outputs directly in
  // the AsyncKernelFrame because we must temporarily hold allocate_output's
  // Tensor somewhere until it is initialized. If we stored the allocated
  // Tensor directly in the AsyncKernelFrame, the frame's output would become
  // available and downstream kernels might read the allocated (but
  // uninitialized) Tensor.
  std::vector<Tensor> outputs_;

  // If an error occurs, the error message is stored here.
  llvm::Optional<std::string> error_;

  tfrt::compat::EigenHostContext eigen_host_context_;
};

class TFRTOpKernel {
 public:
  explicit TFRTOpKernel(TFRTOpKernelConstruction* context) {}
  virtual ~TFRTOpKernel() {}
  virtual void Compute(TFRTOpKernelContext* context) = 0;
};
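
// A hedged sketch (for illustration only; not declared anywhere in
// TensorFlow) of the kind of templated kernel these fallback types are meant
// to support. Real examples live in core/kernels/aggregate_ops.h. The kernel
// is templated on the kernel/construction/context types so the same code
// compiles against both TF's OpKernel machinery and this fallback layer,
// without a shared virtual base class. The name ExampleIdentityOp is
// hypothetical.
//
//   template <class OpKernelT, class OpKernelConstructionT,
//             class OpKernelContextT>
//   class ExampleIdentityOp : public OpKernelT {
//    public:
//     explicit ExampleIdentityOp(OpKernelConstructionT* construction)
//         : OpKernelT(construction) {}
//
//     void Compute(OpKernelContextT* context) override {
//       // Forward input 0 to output 0; set_output stores the Tensor in the
//       // context's internal output vector.
//       context->set_output(0, context->input(0));
//     }
//   };
//
// Instantiated as ExampleIdentityOp<TFRTOpKernel, TFRTOpKernelConstruction,
// TFRTOpKernelContext>, this compiles against the declarations above.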

inline void CheckNotInComputeAsync(TFRTOpKernelConstruction*, const char*) {}
inline void CheckNotInComputeAsync(TFRTOpKernelContext*, const char*) {}

//////////////////////////////////////////////////////////////////////
// Forwarding op metadata.
//////////////////////////////////////////////////////////////////////

// Op metadata. For now TFRTOpMeta only stores the op's output types.
class TFRTOpMeta {
 public:
  explicit TFRTOpMeta(std::vector<DataType> output_types);
  DataType output_type(int index) const;

 private:
  std::vector<DataType> output_types_;
};

// Constructs a TFRTOpMeta from .Input(), .Output(), and .Attr()
// specifications. This supports the same syntax as TF's REGISTER_OP macro,
// but this implementation only supports a subset of the full language.
//
// Currently, this only supports single-tensor outputs with a fixed type.
// TODO(lauj) Support attribute outputs and compound attribute types as used
// by AddN.
class TFRTOpMetaBuilder {
 public:
  explicit TFRTOpMetaBuilder(StringPiece op_name);
  TFRTOpMetaBuilder& Output(StringPiece output_spec);
  TFRTOpMetaBuilder& Input(StringPiece input_spec);
  TFRTOpMetaBuilder& Attr(StringPiece attr_spec);

  const string& op_name() const;
  TFRTOpMeta BuildMeta() const;

 private:
  string op_name_;
  std::vector<DataType> output_types_;
};

// Map from op name to TFRTOpMeta.
class TFRTOpMetaMap {
 public:
  TFRTOpMetaMap();
  void RegisterOpMeta(const TFRTOpMetaBuilder& op_builder);

  // Returns nullptr if there is no metadata for op_name.
  const TFRTOpMeta* GetOpMeta(StringPiece op_name) const;

 private:
  llvm::StringMap<TFRTOpMeta> op_metas_;
};

extern llvm::ManagedStatic<TFRTOpMetaMap> tfrt_forwarding_op_meta_map;

// Implementation detail for REGISTER_KERNEL_FALLBACK_OP. This helps with
// evaluating the .Input()/.Output()/.Attr() clauses in the REGISTER_OP syntax
// before calling BuildMeta().
class TFRTOpRegisterer {
 public:
  TFRTOpRegisterer(  // NOLINT(google-explicit-constructor)
      const TFRTOpMetaBuilder& op_builder);
};

#define REGISTER_KERNEL_FALLBACK_OP(name) \
  REGISTER_KERNEL_FALLBACK_OP_UNIQ_HELPER(__COUNTER__, name)

#define REGISTER_KERNEL_FALLBACK_OP_UNIQ_HELPER(ctr, name) \
  REGISTER_KERNEL_FALLBACK_OP_UNIQ(ctr, name)

#define REGISTER_KERNEL_FALLBACK_OP_UNIQ(ctr, name)                         \
  static TFRTOpRegisterer global_tfrt_forwarding_op_meta_builder_##ctr##_ = \
      TFRTOpMetaBuilder(name)
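
// Example usage (a hedged sketch; "ExampleIdentity" and its specs are purely
// illustrative, not ops registered anywhere). The macro expands to a trailing
// TFRTOpMetaBuilder expression so .Input()/.Output()/.Attr() clauses can be
// chained onto it, REGISTER_OP-style; the implicit conversion to
// TFRTOpRegisterer then records the result in tfrt_forwarding_op_meta_map:
//
//   REGISTER_KERNEL_FALLBACK_OP("ExampleIdentity")
//       .Input("x: float")
//       .Output("y: float");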

//////////////////////////////////////////////////////////////////////
// Forwarding kernel registration.
//////////////////////////////////////////////////////////////////////

// Represents Kernel Fallback kernel registration information.
struct TFRTOpKernelReg {
  using CallbackT =
      std::unique_ptr<TFRTOpKernel> (*)(TFRTOpKernelConstruction*);

  explicit TFRTOpKernelReg(CallbackT callback) : callback(callback) {}

  // Callback that creates a kernel.
  CallbackT callback;
  // Map from attribute name to the type it must match.
  // For example, foo: DT_FLOAT indicates that the foo attribute
  // must be a tfdtype attribute with type float.
  llvm::StringMap<DataType> type_constraints;
};

class TFRTOpKernelFactories {
 public:
  TFRTOpKernelFactories();
  void RegisterFactory(StringPiece kernel_class_name,
                       TFRTOpKernelReg kernel_info);

  // Creates a kernel with the given name and passes op_kernel_construction
  // to the kernel's constructor.
  // Returns the constructed kernel on success.
  // In case of failure, returns a nullptr. Kernel creation can fail in one
  // of the following cases:
  //   1. No kernel with the given name is registered.
  //   2. The attributes in op_kernel_construction don't match the type
  //      constraints of any kernel registered under this name.
  //      Note that a constraint is considered "not matched" if the attribute
  //      it applies to is missing from op_kernel_construction.
  std::unique_ptr<TFRTOpKernel> CreateKernel(
      StringPiece kernel_class_name,
      TFRTOpKernelConstruction* op_kernel_construction) const;

 private:
  llvm::StringMap<std::vector<TFRTOpKernelReg>> factories_;
};
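
// A hedged sketch of how a caller might instantiate and run a fallback
// kernel. The kernel name "ExampleIdentity" and the pre-built attrs/context
// objects are hypothetical; in practice this lookup is driven by the TFRT
// fallback runtime rather than by user code:
//
//   TFRTOpKernelConstruction construction(attrs);
//   std::unique_ptr<TFRTOpKernel> kernel =
//       tfrt_forwarding_kernel_factories->CreateKernel("ExampleIdentity",
//                                                      &construction);
//   if (kernel) kernel->Compute(&context);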

// TODO(lauj) Should we move these kernel registrations to tfrt::KernelRegistry?
extern llvm::ManagedStatic<TFRTOpKernelFactories>
    tfrt_forwarding_kernel_factories;

#define REGISTER_KERNEL_FALLBACK_KERNEL(name, ...) \
  REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ_HELPER(__COUNTER__, name, __VA_ARGS__)

#define REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ_HELPER(ctr, name, ...) \
  REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ(ctr, name, __VA_ARGS__)

#define REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ(ctr, name, ...)             \
  static bool global_tfrt_forwarding_kernel_##ctr##_registered_ = []() { \
    ::tensorflow::tfrt_forwarding_kernel_factories->RegisterFactory(     \
        name, TFRTOpKernelReg([](TFRTOpKernelConstruction* construction) \
                                  -> std::unique_ptr<TFRTOpKernel> {     \
          return std::make_unique<__VA_ARGS__>(construction);            \
        }));                                                              \
    return true;                                                          \
  }();
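
// Example usage (a hedged sketch mirroring how fallback kernels register
// themselves; "ExampleIdentity" and ExampleIdentityOp are hypothetical names,
// not real registrations):
//
//   REGISTER_KERNEL_FALLBACK_KERNEL(
//       "ExampleIdentity",
//       ExampleIdentityOp<TFRTOpKernel, TFRTOpKernelConstruction,
//                         TFRTOpKernelContext>);
//
// The lambda stored in TFRTOpKernelReg then constructs an ExampleIdentityOp
// whenever CreateKernel("ExampleIdentity", ...) is called.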

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_