xref: /aosp_15_r20/external/tensorflow/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h"
16 
17 #include "absl/strings/str_split.h"
18 #include "absl/strings/strip.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/platform/status.h"
24 #include "tensorflow/core/runtime_fallback/kernel/attr_util.h"
25 #include "tensorflow/core/tfrt/utils/error_util.h"
26 #include "tfrt/host_context/async_value.h"  // from @tf_runtime
27 #include "tfrt/host_context/kernel_frame.h"  // from @tf_runtime
28 
29 namespace tensorflow {
30 
31 //////////////////////////////////////////////////////////////////////
32 // OpKernel interface.
33 //////////////////////////////////////////////////////////////////////
TFRTOpKernelConstruction(const tfrt::OpAttrsRef & attributes)34 TFRTOpKernelConstruction::TFRTOpKernelConstruction(
35     const tfrt::OpAttrsRef& attributes)
36     : attributes_(std::move(attributes)) {}
37 
MissingAttributeError(StringPiece attr_name)38 Status MissingAttributeError(StringPiece attr_name) {
39   return errors::InvalidArgument("Missing attribute: ", attr_name);
40 }
41 
42 template <>
GetAttr(StringPiece attr_name,std::string * value) const43 Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
44                                          std::string* value) const {
45   tfrt::string_view view;
46   bool success = attributes_.GetString(
47       llvm::StringRef(attr_name.data(), attr_name.size()), &view);
48   if (!success) {
49     return MissingAttributeError(attr_name);
50   }
51   *value = view.str();
52   return OkStatus();
53 }
54 
55 template <>
GetAttr(StringPiece attr_name,DataType * value) const56 Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
57                                          DataType* value) const {
58   tfrt::OpAttrType attrtype;
59   bool success = attributes_.Get<tfrt::OpAttrType>(
60       llvm::StringRef(attr_name.data(), attr_name.size()), &attrtype);
61   if (!success) {
62     return MissingAttributeError(attr_name);
63   }
64   *value = tfd::ConvertToTfDataType(attrtype);
65   return OkStatus();
66 }
67 
68 template <>
GetAttr(StringPiece attr_name,Padding * value) const69 Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
70                                          Padding* value) const {
71   std::string padding_str;
72   TF_RETURN_IF_ERROR(GetAttr<std::string>(attr_name, &padding_str));
73   return GetPaddingFromString(padding_str, value);
74 }
75 
76 template <>
GetAttr(StringPiece attr_name,std::vector<int32> * value) const77 Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
78                                          std::vector<int32>* value) const {
79   llvm::ArrayRef<int32> arrayref;
80   bool success = attributes_.GetArray<int32>(
81       llvm::StringRef(attr_name.data(), attr_name.size()), &arrayref);
82   if (!success) {
83     return MissingAttributeError(attr_name);
84   }
85   *value = arrayref;
86   return OkStatus();
87 }
88 
CtxFailure(const Status & s)89 void TFRTOpKernelConstruction::CtxFailure(const Status& s) {
90   error_ = tfrt::MakeStatusString(s);
91 }
92 
CtxFailureWithWarning(const Status & s)93 void TFRTOpKernelConstruction::CtxFailureWithWarning(const Status& s) {
94   CtxFailure(s);
95 }
96 
97 namespace {
FillFailureMessage(const char * file,int line,const Status & s)98 std::string FillFailureMessage(const char* file, int line, const Status& s) {
99   std::string error;
100   llvm::raw_string_ostream sstr(error);
101   sstr << "OP_REQUIRES failed at " << file << ":" << line << " : "
102        << tfrt::MakeStatusString(s);
103   sstr.str();
104   return error;
105 }
106 }  // namespace
107 
CtxFailure(const char * file,int line,const Status & s)108 void TFRTOpKernelConstruction::CtxFailure(const char* file, int line,
109                                           const Status& s) {
110   error_ = FillFailureMessage(file, line, s);
111 }
112 
CtxFailureWithWarning(const char * file,int line,const Status & s)113 void TFRTOpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
114                                                      const Status& s) {
115   CtxFailure(file, line, s);
116 }
117 
error()118 const llvm::Optional<std::string>& TFRTOpKernelConstruction::error() {
119   return error_;
120 }
121 
TFRTOpKernelContext(llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> inputs,int num_outputs,const TFRTOpMeta * op_meta,tfrt::HostContext * host)122 TFRTOpKernelContext::TFRTOpKernelContext(
123     llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> inputs, int num_outputs,
124     const TFRTOpMeta* op_meta, tfrt::HostContext* host)
125     : inputs_(inputs),
126       op_meta_(op_meta),
127       outputs_(num_outputs),
128       eigen_host_context_(host) {}
129 
output(int index)130 const Tensor& TFRTOpKernelContext::output(int index) { return outputs_[index]; }
131 
error()132 const llvm::Optional<std::string>& TFRTOpKernelContext::error() {
133   return error_;
134 }
135 
ValidateInputsAreSameShape(TFRTOpKernel * op)136 bool TFRTOpKernelContext::ValidateInputsAreSameShape(TFRTOpKernel* op) {
137   // TODO(lauj) Check shapes.
138   return true;
139 }
140 
input(int index)141 const Tensor& TFRTOpKernelContext::input(int index) {
142   return inputs_[index]->get<Tensor>();
143 }
144 
num_inputs() const145 int TFRTOpKernelContext::num_inputs() const { return inputs_.size(); }
146 
num_outputs() const147 int TFRTOpKernelContext::num_outputs() const { return outputs_.size(); }
148 
set_output(int index,const Tensor & tensor)149 void TFRTOpKernelContext::set_output(int index, const Tensor& tensor) {
150   outputs_[index] = tensor;
151 }
152 
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp)153 Status TFRTOpKernelContext::allocate_temp(DataType type,
154                                           const TensorShape& shape,
155                                           Tensor* out_temp) {
156   *out_temp = Tensor(type, shape);
157   return OkStatus();
158 }
159 
allocate_output(int index,const TensorShape & shape,Tensor ** tensor)160 Status TFRTOpKernelContext::allocate_output(int index, const TensorShape& shape,
161                                             Tensor** tensor) {
162   // Fetch output DataType from the op's TFRTOpMeta.
163   DataType output_type = op_meta_->output_type(index);
164   outputs_[index] = Tensor(output_type, shape);
165   *tensor = &outputs_[index];
166   return OkStatus();
167 }
168 
expected_output_dtype(int i) const169 DataType TFRTOpKernelContext::expected_output_dtype(int i) const {
170   return op_meta_->output_type(i);
171 }
172 
CtxFailure(const Status & s)173 void TFRTOpKernelContext::CtxFailure(const Status& s) {
174   error_ = s.error_message();
175 }
CtxFailureWithWarning(const Status & s)176 void TFRTOpKernelContext::CtxFailureWithWarning(const Status& s) {
177   CtxFailure(s);
178 }
CtxFailure(const char * file,int line,const Status & s)179 void TFRTOpKernelContext::CtxFailure(const char* file, int line,
180                                      const Status& s) {
181   error_ = FillFailureMessage(file, line, s);
182 }
CtxFailureWithWarning(const char * file,int line,const Status & s)183 void TFRTOpKernelContext::CtxFailureWithWarning(const char* file, int line,
184                                                 const Status& s) {
185   CtxFailure(file, line, s);
186 }
187 
188 template <>
eigen_device() const189 const Eigen::ThreadPoolDevice& TFRTOpKernelContext::eigen_device() const {
190   return eigen_host_context_.Device();
191 }
192 
193 //////////////////////////////////////////////////////////////////////
194 // Forwarding op metadata.
195 //////////////////////////////////////////////////////////////////////
TFRTOpMeta(std::vector<DataType> output_types)196 TFRTOpMeta::TFRTOpMeta(std::vector<DataType> output_types)
197     : output_types_(std::move(output_types)) {}
198 
output_type(int index) const199 DataType TFRTOpMeta::output_type(int index) const {
200   return output_types_[index];
201 }
202 
TFRTOpMetaBuilder(StringPiece op_name)203 TFRTOpMetaBuilder::TFRTOpMetaBuilder(StringPiece op_name) : op_name_(op_name) {}
204 
205 namespace {
206 
ParseInputOutputSpec(StringPiece spec)207 DataType ParseInputOutputSpec(StringPiece spec) {
208   std::vector<absl::string_view> name_type =
209       absl::StrSplit(spec, absl::MaxSplits(':', 2));
210   DataType data_type;
211   bool success =
212       DataTypeFromString(absl::StripAsciiWhitespace(name_type[1]), &data_type);
213   assert(success && "Failed to parse DataType");
214   (void)success;
215   return data_type;
216 }
217 
218 }  // anonymous namespace
219 
Output(StringPiece output_spec)220 TFRTOpMetaBuilder& TFRTOpMetaBuilder::Output(StringPiece output_spec) {
221   output_types_.push_back(ParseInputOutputSpec(output_spec));
222   return *this;
223 }
224 
Input(StringPiece input_spec)225 TFRTOpMetaBuilder& TFRTOpMetaBuilder::Input(StringPiece input_spec) {
226   return *this;
227 }
228 
Attr(StringPiece attr_spec)229 TFRTOpMetaBuilder& TFRTOpMetaBuilder::Attr(StringPiece attr_spec) {
230   return *this;
231 }
232 
op_name() const233 const string& TFRTOpMetaBuilder::op_name() const { return op_name_; }
234 
BuildMeta() const235 TFRTOpMeta TFRTOpMetaBuilder::BuildMeta() const {
236   return TFRTOpMeta(output_types_);
237 }
238 
TFRTOpMetaMap()239 TFRTOpMetaMap::TFRTOpMetaMap() {}
240 
RegisterOpMeta(const TFRTOpMetaBuilder & op_builder)241 void TFRTOpMetaMap::RegisterOpMeta(const TFRTOpMetaBuilder& op_builder) {
242   auto insert_result = op_metas_.insert(
243       std::make_pair(op_builder.op_name(), op_builder.BuildMeta()));
244   assert(insert_result.second && "Multiple registrations for the same op_name");
245   (void)insert_result;
246 }
247 
GetOpMeta(StringPiece op_name) const248 const TFRTOpMeta* TFRTOpMetaMap::GetOpMeta(StringPiece op_name) const {
249   auto it = op_metas_.find(llvm::StringRef(op_name.data(), op_name.size()));
250   if (it == op_metas_.end()) return nullptr;
251 
252   return &it->second;
253 }
254 
TFRTOpRegisterer(const TFRTOpMetaBuilder & op_builder)255 TFRTOpRegisterer::TFRTOpRegisterer(const TFRTOpMetaBuilder& op_builder) {
256   tfrt_forwarding_op_meta_map->RegisterOpMeta(op_builder);
257 }
258 
259 llvm::ManagedStatic<TFRTOpMetaMap> tfrt_forwarding_op_meta_map;
260 
261 llvm::ManagedStatic<TFRTOpKernelFactories> tfrt_forwarding_kernel_factories;
262 
263 //////////////////////////////////////////////////////////////////////
264 // Forwarding kernel registration.
265 //////////////////////////////////////////////////////////////////////
266 
TFRTOpKernelFactories()267 TFRTOpKernelFactories::TFRTOpKernelFactories() {}
268 
RegisterFactory(StringPiece kernel_class_name,TFRTOpKernelReg kernel_info)269 void TFRTOpKernelFactories::RegisterFactory(StringPiece kernel_class_name,
270                                             TFRTOpKernelReg kernel_info) {
271   factories_[std::string(kernel_class_name)].push_back(kernel_info);
272 }
273 
274 // Returns true if kernel attributes match given type constraints.
ValidKernelAttr(StringPiece kernel_class_name,TFRTOpKernelConstruction * construction,const llvm::StringMap<DataType> & constraints)275 Status ValidKernelAttr(StringPiece kernel_class_name,
276                        TFRTOpKernelConstruction* construction,
277                        const llvm::StringMap<DataType>& constraints) {
278   for (const auto& constraint : constraints) {
279     auto attr_name = std::string(constraint.first());
280     DataType type;
281     Status s = construction->GetAttr(attr_name, &type);
282     if (!s.ok()) {
283       return errors::InvalidArgument(
284           "Kernel ", kernel_class_name,
285           " has constraint for unset tfdtype attribute ", attr_name, ".");
286     }
287     if (type != constraint.second) {
288       return errors::InvalidArgument(
289           "Kernel ", kernel_class_name, " with type constraint ", attr_name,
290           ": ", DataTypeString(constraint.second),
291           " does not match attribute type ", DataTypeString(type), ".");
292     }
293   }
294   return OkStatus();
295 }
296 
CreateKernel(StringPiece kernel_class_name,TFRTOpKernelConstruction * op_kernel_construction) const297 std::unique_ptr<TFRTOpKernel> TFRTOpKernelFactories::CreateKernel(
298     StringPiece kernel_class_name,
299     TFRTOpKernelConstruction* op_kernel_construction) const {
300   auto it = factories_.find(std::string(kernel_class_name));
301   if (it == factories_.end()) {
302     // Could not find kernel in the registry
303     op_kernel_construction->CtxFailure(errors::NotFound(
304         "Could not find kernel ", kernel_class_name, " in the registry."));
305     return std::unique_ptr<TFRTOpKernel>(nullptr);
306   }
307   Status status;
308   for (const auto& kernel_info : it->second) {
309     Status s = ValidKernelAttr(kernel_class_name, op_kernel_construction,
310                                kernel_info.type_constraints);
311     if (s.ok()) {
312       return kernel_info.callback(op_kernel_construction);
313     }
314     status.Update(s);
315   }
316   // No valid kernel found
317   op_kernel_construction->CtxFailure(status);
318   return std::unique_ptr<TFRTOpKernel>(nullptr);
319 }
320 
321 }  // namespace tensorflow
322