1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h"
16
17 #include "absl/strings/str_split.h"
18 #include "absl/strings/strip.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/platform/status.h"
24 #include "tensorflow/core/runtime_fallback/kernel/attr_util.h"
25 #include "tensorflow/core/tfrt/utils/error_util.h"
26 #include "tfrt/host_context/async_value.h" // from @tf_runtime
27 #include "tfrt/host_context/kernel_frame.h" // from @tf_runtime
28
29 namespace tensorflow {
30
31 //////////////////////////////////////////////////////////////////////
32 // OpKernel interface.
33 //////////////////////////////////////////////////////////////////////
TFRTOpKernelConstruction(const tfrt::OpAttrsRef & attributes)34 TFRTOpKernelConstruction::TFRTOpKernelConstruction(
35 const tfrt::OpAttrsRef& attributes)
36 : attributes_(std::move(attributes)) {}
37
MissingAttributeError(StringPiece attr_name)38 Status MissingAttributeError(StringPiece attr_name) {
39 return errors::InvalidArgument("Missing attribute: ", attr_name);
40 }
41
42 template <>
GetAttr(StringPiece attr_name,std::string * value) const43 Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
44 std::string* value) const {
45 tfrt::string_view view;
46 bool success = attributes_.GetString(
47 llvm::StringRef(attr_name.data(), attr_name.size()), &view);
48 if (!success) {
49 return MissingAttributeError(attr_name);
50 }
51 *value = view.str();
52 return OkStatus();
53 }
54
55 template <>
GetAttr(StringPiece attr_name,DataType * value) const56 Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
57 DataType* value) const {
58 tfrt::OpAttrType attrtype;
59 bool success = attributes_.Get<tfrt::OpAttrType>(
60 llvm::StringRef(attr_name.data(), attr_name.size()), &attrtype);
61 if (!success) {
62 return MissingAttributeError(attr_name);
63 }
64 *value = tfd::ConvertToTfDataType(attrtype);
65 return OkStatus();
66 }
67
68 template <>
GetAttr(StringPiece attr_name,Padding * value) const69 Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
70 Padding* value) const {
71 std::string padding_str;
72 TF_RETURN_IF_ERROR(GetAttr<std::string>(attr_name, &padding_str));
73 return GetPaddingFromString(padding_str, value);
74 }
75
76 template <>
GetAttr(StringPiece attr_name,std::vector<int32> * value) const77 Status TFRTOpKernelConstruction::GetAttr(StringPiece attr_name,
78 std::vector<int32>* value) const {
79 llvm::ArrayRef<int32> arrayref;
80 bool success = attributes_.GetArray<int32>(
81 llvm::StringRef(attr_name.data(), attr_name.size()), &arrayref);
82 if (!success) {
83 return MissingAttributeError(attr_name);
84 }
85 *value = arrayref;
86 return OkStatus();
87 }
88
CtxFailure(const Status & s)89 void TFRTOpKernelConstruction::CtxFailure(const Status& s) {
90 error_ = tfrt::MakeStatusString(s);
91 }
92
CtxFailureWithWarning(const Status & s)93 void TFRTOpKernelConstruction::CtxFailureWithWarning(const Status& s) {
94 CtxFailure(s);
95 }
96
97 namespace {
FillFailureMessage(const char * file,int line,const Status & s)98 std::string FillFailureMessage(const char* file, int line, const Status& s) {
99 std::string error;
100 llvm::raw_string_ostream sstr(error);
101 sstr << "OP_REQUIRES failed at " << file << ":" << line << " : "
102 << tfrt::MakeStatusString(s);
103 sstr.str();
104 return error;
105 }
106 } // namespace
107
CtxFailure(const char * file,int line,const Status & s)108 void TFRTOpKernelConstruction::CtxFailure(const char* file, int line,
109 const Status& s) {
110 error_ = FillFailureMessage(file, line, s);
111 }
112
CtxFailureWithWarning(const char * file,int line,const Status & s)113 void TFRTOpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
114 const Status& s) {
115 CtxFailure(file, line, s);
116 }
117
error()118 const llvm::Optional<std::string>& TFRTOpKernelConstruction::error() {
119 return error_;
120 }
121
// Creates a kernel context over `inputs`.
//
// `num_outputs` sizes the output slots (outputs_), which are later filled by
// set_output()/allocate_output(). `op_meta` supplies the expected output
// dtypes and is stored as a raw pointer. NOTE(review): presumably it must
// outlive this context; confirm with callers. `host` is used to build
// eigen_host_context_, which backs eigen_device().
//
// `inputs` is kept as an ArrayRef. NOTE(review): if inputs_ is also an
// ArrayRef (header not visible), the caller's array must outlive this
// context. TODO confirm.
TFRTOpKernelContext::TFRTOpKernelContext(
    llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> inputs, int num_outputs,
    const TFRTOpMeta* op_meta, tfrt::HostContext* host)
    : inputs_(inputs),
      op_meta_(op_meta),
      outputs_(num_outputs),
      eigen_host_context_(host) {}
129
output(int index)130 const Tensor& TFRTOpKernelContext::output(int index) { return outputs_[index]; }
131
error()132 const llvm::Optional<std::string>& TFRTOpKernelContext::error() {
133 return error_;
134 }
135
ValidateInputsAreSameShape(TFRTOpKernel * op)136 bool TFRTOpKernelContext::ValidateInputsAreSameShape(TFRTOpKernel* op) {
137 // TODO(lauj) Check shapes.
138 return true;
139 }
140
input(int index)141 const Tensor& TFRTOpKernelContext::input(int index) {
142 return inputs_[index]->get<Tensor>();
143 }
144
num_inputs() const145 int TFRTOpKernelContext::num_inputs() const { return inputs_.size(); }
146
num_outputs() const147 int TFRTOpKernelContext::num_outputs() const { return outputs_.size(); }
148
set_output(int index,const Tensor & tensor)149 void TFRTOpKernelContext::set_output(int index, const Tensor& tensor) {
150 outputs_[index] = tensor;
151 }
152
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp)153 Status TFRTOpKernelContext::allocate_temp(DataType type,
154 const TensorShape& shape,
155 Tensor* out_temp) {
156 *out_temp = Tensor(type, shape);
157 return OkStatus();
158 }
159
allocate_output(int index,const TensorShape & shape,Tensor ** tensor)160 Status TFRTOpKernelContext::allocate_output(int index, const TensorShape& shape,
161 Tensor** tensor) {
162 // Fetch output DataType from the op's TFRTOpMeta.
163 DataType output_type = op_meta_->output_type(index);
164 outputs_[index] = Tensor(output_type, shape);
165 *tensor = &outputs_[index];
166 return OkStatus();
167 }
168
expected_output_dtype(int i) const169 DataType TFRTOpKernelContext::expected_output_dtype(int i) const {
170 return op_meta_->output_type(i);
171 }
172
CtxFailure(const Status & s)173 void TFRTOpKernelContext::CtxFailure(const Status& s) {
174 error_ = s.error_message();
175 }
CtxFailureWithWarning(const Status & s)176 void TFRTOpKernelContext::CtxFailureWithWarning(const Status& s) {
177 CtxFailure(s);
178 }
CtxFailure(const char * file,int line,const Status & s)179 void TFRTOpKernelContext::CtxFailure(const char* file, int line,
180 const Status& s) {
181 error_ = FillFailureMessage(file, line, s);
182 }
CtxFailureWithWarning(const char * file,int line,const Status & s)183 void TFRTOpKernelContext::CtxFailureWithWarning(const char* file, int line,
184 const Status& s) {
185 CtxFailure(file, line, s);
186 }
187
188 template <>
eigen_device() const189 const Eigen::ThreadPoolDevice& TFRTOpKernelContext::eigen_device() const {
190 return eigen_host_context_.Device();
191 }
192
193 //////////////////////////////////////////////////////////////////////
194 // Forwarding op metadata.
195 //////////////////////////////////////////////////////////////////////
TFRTOpMeta(std::vector<DataType> output_types)196 TFRTOpMeta::TFRTOpMeta(std::vector<DataType> output_types)
197 : output_types_(std::move(output_types)) {}
198
output_type(int index) const199 DataType TFRTOpMeta::output_type(int index) const {
200 return output_types_[index];
201 }
202
TFRTOpMetaBuilder(StringPiece op_name)203 TFRTOpMetaBuilder::TFRTOpMetaBuilder(StringPiece op_name) : op_name_(op_name) {}
204
205 namespace {
206
ParseInputOutputSpec(StringPiece spec)207 DataType ParseInputOutputSpec(StringPiece spec) {
208 std::vector<absl::string_view> name_type =
209 absl::StrSplit(spec, absl::MaxSplits(':', 2));
210 DataType data_type;
211 bool success =
212 DataTypeFromString(absl::StripAsciiWhitespace(name_type[1]), &data_type);
213 assert(success && "Failed to parse DataType");
214 (void)success;
215 return data_type;
216 }
217
218 } // anonymous namespace
219
Output(StringPiece output_spec)220 TFRTOpMetaBuilder& TFRTOpMetaBuilder::Output(StringPiece output_spec) {
221 output_types_.push_back(ParseInputOutputSpec(output_spec));
222 return *this;
223 }
224
Input(StringPiece input_spec)225 TFRTOpMetaBuilder& TFRTOpMetaBuilder::Input(StringPiece input_spec) {
226 return *this;
227 }
228
Attr(StringPiece attr_spec)229 TFRTOpMetaBuilder& TFRTOpMetaBuilder::Attr(StringPiece attr_spec) {
230 return *this;
231 }
232
op_name() const233 const string& TFRTOpMetaBuilder::op_name() const { return op_name_; }
234
BuildMeta() const235 TFRTOpMeta TFRTOpMetaBuilder::BuildMeta() const {
236 return TFRTOpMeta(output_types_);
237 }
238
TFRTOpMetaMap()239 TFRTOpMetaMap::TFRTOpMetaMap() {}
240
RegisterOpMeta(const TFRTOpMetaBuilder & op_builder)241 void TFRTOpMetaMap::RegisterOpMeta(const TFRTOpMetaBuilder& op_builder) {
242 auto insert_result = op_metas_.insert(
243 std::make_pair(op_builder.op_name(), op_builder.BuildMeta()));
244 assert(insert_result.second && "Multiple registrations for the same op_name");
245 (void)insert_result;
246 }
247
GetOpMeta(StringPiece op_name) const248 const TFRTOpMeta* TFRTOpMetaMap::GetOpMeta(StringPiece op_name) const {
249 auto it = op_metas_.find(llvm::StringRef(op_name.data(), op_name.size()));
250 if (it == op_metas_.end()) return nullptr;
251
252 return &it->second;
253 }
254
TFRTOpRegisterer(const TFRTOpMetaBuilder & op_builder)255 TFRTOpRegisterer::TFRTOpRegisterer(const TFRTOpMetaBuilder& op_builder) {
256 tfrt_forwarding_op_meta_map->RegisterOpMeta(op_builder);
257 }
258
// Process-wide registry of forwarding op metadata, filled in by
// TFRTOpRegisterer. llvm::ManagedStatic creates the object on first access.
llvm::ManagedStatic<TFRTOpMetaMap> tfrt_forwarding_op_meta_map;

// Process-wide registry of forwarding kernel factories, keyed by kernel class
// name (see TFRTOpKernelFactories::RegisterFactory / CreateKernel).
llvm::ManagedStatic<TFRTOpKernelFactories> tfrt_forwarding_kernel_factories;
262
263 //////////////////////////////////////////////////////////////////////
264 // Forwarding kernel registration.
265 //////////////////////////////////////////////////////////////////////
266
TFRTOpKernelFactories()267 TFRTOpKernelFactories::TFRTOpKernelFactories() {}
268
RegisterFactory(StringPiece kernel_class_name,TFRTOpKernelReg kernel_info)269 void TFRTOpKernelFactories::RegisterFactory(StringPiece kernel_class_name,
270 TFRTOpKernelReg kernel_info) {
271 factories_[std::string(kernel_class_name)].push_back(kernel_info);
272 }
273
274 // Returns true if kernel attributes match given type constraints.
ValidKernelAttr(StringPiece kernel_class_name,TFRTOpKernelConstruction * construction,const llvm::StringMap<DataType> & constraints)275 Status ValidKernelAttr(StringPiece kernel_class_name,
276 TFRTOpKernelConstruction* construction,
277 const llvm::StringMap<DataType>& constraints) {
278 for (const auto& constraint : constraints) {
279 auto attr_name = std::string(constraint.first());
280 DataType type;
281 Status s = construction->GetAttr(attr_name, &type);
282 if (!s.ok()) {
283 return errors::InvalidArgument(
284 "Kernel ", kernel_class_name,
285 " has constraint for unset tfdtype attribute ", attr_name, ".");
286 }
287 if (type != constraint.second) {
288 return errors::InvalidArgument(
289 "Kernel ", kernel_class_name, " with type constraint ", attr_name,
290 ": ", DataTypeString(constraint.second),
291 " does not match attribute type ", DataTypeString(type), ".");
292 }
293 }
294 return OkStatus();
295 }
296
CreateKernel(StringPiece kernel_class_name,TFRTOpKernelConstruction * op_kernel_construction) const297 std::unique_ptr<TFRTOpKernel> TFRTOpKernelFactories::CreateKernel(
298 StringPiece kernel_class_name,
299 TFRTOpKernelConstruction* op_kernel_construction) const {
300 auto it = factories_.find(std::string(kernel_class_name));
301 if (it == factories_.end()) {
302 // Could not find kernel in the registry
303 op_kernel_construction->CtxFailure(errors::NotFound(
304 "Could not find kernel ", kernel_class_name, " in the registry."));
305 return std::unique_ptr<TFRTOpKernel>(nullptr);
306 }
307 Status status;
308 for (const auto& kernel_info : it->second) {
309 Status s = ValidKernelAttr(kernel_class_name, op_kernel_construction,
310 kernel_info.type_constraints);
311 if (s.ok()) {
312 return kernel_info.callback(op_kernel_construction);
313 }
314 status.Update(s);
315 }
316 // No valid kernel found
317 op_kernel_construction->CtxFailure(status);
318 return std::unique_ptr<TFRTOpKernel>(nullptr);
319 }
320
321 } // namespace tensorflow
322