1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/util/attr_util.h"
16
17 #include <cstdlib>
18
19 #include "absl/strings/numbers.h"
20 #include "absl/strings/str_split.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/op_def.pb.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.pb.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/status.h"
30 #include "tensorflow/core/tfrt/utils/tensor_util.h"
31 #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime
32 #include "tfrt/host_context/attribute_utils.h" // from @tf_runtime
33 #include "tfrt/support/error_util.h" // from @tf_runtime
34 #include "tfrt/support/forward_decls.h" // from @tf_runtime
35 #include "tfrt/support/logging.h" // from @tf_runtime
36 #include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime
37 #include "tfrt/tensor/tensor_serialize_utils.h" // from @tf_runtime
38
39 namespace tensorflow {
40 namespace tfd {
41 namespace {
42
43 using ::tensorflow::protobuf::RepeatedFieldBackInserter;
44 using ::tfrt::AggregateAttr;
45 using ::tfrt::BEFAttributeType;
46 using ::tfrt::DenseAttr;
47 using ::tfrt::DenseHostTensor;
48 using ::tfrt::HostContext;
49 using ::tfrt::OpAttrsRawEntry;
50 using ::tfrt::OpAttrsRef;
51 using ::tfrt::OpAttrType;
52 using ::tfrt::string_view;
53
DecodeDenseAttrToTfTensor(const DenseAttr & dense_attr,HostContext * host)54 llvm::Expected<tensorflow::Tensor> DecodeDenseAttrToTfTensor(
55 const DenseAttr& dense_attr, HostContext* host) {
56 llvm::Expected<DenseHostTensor> dht =
57 tfrt::DeserializeDenseHostTensorFromDenseAttr(dense_attr, host);
58 if (!dht) {
59 return tfrt::MakeStringError(
60 "Cannot create DenseHostTensor in DecodeDenseAttrToTensorInterface: ",
61 dht.takeError());
62 }
63
64 return tfrt::TFRTTensorToTFTensor(*dht, host);
65 }
66
FillAttrValueMapUsingArray(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,const OpAttrsRef & attrs)67 llvm::Error FillAttrValueMapUsingArray(const OpAttrsRawEntry& entry,
68 AttrValue& attr_tmp,
69 const OpAttrsRef& attrs) {
70 attr_tmp.mutable_list()->Clear();
71 if (entry.element_count == 0) {
72 if (entry.type == OpAttrType::CHAR) {
73 // Empty string.
74 attr_tmp.set_s("");
75 }
76 // Empty array of other types.
77 return llvm::Error::success();
78 }
79 switch (entry.type) {
80 case OpAttrType::CHAR: {
81 string_view attr_value = attrs.GetStringAsserting(entry.name);
82 attr_tmp.set_s(attr_value.data(), attr_value.size());
83 return llvm::Error::success();
84 }
85
86 case OpAttrType::FUNC: {
87 string_view attr_value = attrs.GetFuncNameAsserting(entry.name);
88 attr_tmp.mutable_func()->set_name(attr_value.data(), attr_value.size());
89 return llvm::Error::success();
90 }
91 case OpAttrType::I64: {
92 llvm::ArrayRef<int64_t> int_array =
93 attrs.GetArrayAsserting<int64_t>(entry.name);
94 auto* mutable_i = attr_tmp.mutable_list()->mutable_i();
95 std::copy(int_array.begin(), int_array.end(),
96 RepeatedFieldBackInserter(mutable_i));
97 return llvm::Error::success();
98 }
99 case OpAttrType::F32: {
100 llvm::ArrayRef<float> float_array =
101 attrs.GetArrayAsserting<float>(entry.name);
102 auto* mutable_f = attr_tmp.mutable_list()->mutable_f();
103 std::copy(float_array.begin(), float_array.end(),
104 RepeatedFieldBackInserter(mutable_f));
105 return llvm::Error::success();
106 }
107 case OpAttrType::BOOL: {
108 llvm::ArrayRef<bool> bool_array =
109 attrs.GetArrayAsserting<bool>(entry.name);
110 auto mutable_b = attr_tmp.mutable_list()->mutable_b();
111 std::copy(bool_array.begin(), bool_array.end(),
112 RepeatedFieldBackInserter(mutable_b));
113 return llvm::Error::success();
114 }
115 case OpAttrType::DTYPE: {
116 const auto& op_attr = attrs.GetRawAsserting(entry.name);
117 assert(op_attr.IsArray());
118
119 // DTypes in BEF attributes are tfrt::DType enums. So we need
120 // to convert then to tensorflow data types first.
121 auto bef_dtypes =
122 llvm::makeArrayRef(static_cast<const tfrt::DType*>(op_attr.GetData()),
123 op_attr.element_count);
124
125 llvm::SmallVector<tensorflow::DataType, 4> tf_dtypes;
126 tf_dtypes.reserve(bef_dtypes.size());
127 for (auto bef_dtype : bef_dtypes) {
128 tf_dtypes.push_back(ConvertBefAttrTypeToTfDataType(bef_dtype));
129 }
130 auto* mutable_type = attr_tmp.mutable_list()->mutable_type();
131 std::copy(tf_dtypes.begin(), tf_dtypes.end(),
132 RepeatedFieldBackInserter(mutable_type));
133 return llvm::Error::success();
134 }
135 default:
136 return tfrt::MakeStringError("unsupported array attribute type");
137 }
138 }
139
FillAttrValueMapUsingAggregate(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,const OpAttrsRef & attrs)140 llvm::Error FillAttrValueMapUsingAggregate(const OpAttrsRawEntry& entry,
141 AttrValue& attr_tmp,
142 const OpAttrsRef& attrs) {
143 AggregateAttr list_attr = attrs.GetAsserting<AggregateAttr>(entry.name);
144 int num_values = list_attr.GetNumElements();
145 if (num_values == 0) {
146 // Create an empty list.
147 attr_tmp.mutable_list();
148 return llvm::Error::success();
149 }
150 // It is guaranteed that items in one list attribute have the same
151 // type, though their sizes can be different. In particular,
152 // list(TensorShape) and list(Tensor) attribute types have to be
153 // encoded as AggregateAttr.
154 auto attr_base = list_attr.GetAttribute(0);
155 auto* mutable_list = attr_tmp.mutable_list();
156 mutable_list->Clear();
157 if (IsDataTypeAttribute(attr_base.type()) &&
158 GetDataType(attr_base.type()) == tfrt::DType::String) {
159 // Handle list(string).
160 auto* mutable_s = mutable_list->mutable_s();
161 mutable_s->Reserve(num_values);
162 for (int i = 0; i < num_values; ++i) {
163 auto string_attr = list_attr.GetAttributeOfType<tfrt::StringAttr>(i);
164 mutable_list->add_s(string_attr.GetValue().data(),
165 string_attr.GetValue().size());
166 }
167 } else if (attr_base.type() == BEFAttributeType::kFunc) {
168 // Handle list(Function).
169 auto* mutable_f = mutable_list->mutable_func();
170 mutable_f->Reserve(num_values);
171 for (int i = 0; i < num_values; ++i) {
172 auto func_attr = list_attr.GetAttributeOfType<tfrt::FuncAttr>(i);
173 auto mutable_func = mutable_list->add_func();
174 mutable_func->set_name(func_attr.GetFunctionName().str());
175 }
176 } else if (attr_base.type() == BEFAttributeType::kShape) {
177 // Handle list(TensorShape).
178 auto* mutable_list = attr_tmp.mutable_list();
179 auto* mutable_shape = mutable_list->mutable_shape();
180 mutable_shape->Reserve(num_values);
181 for (int i = 0; i < num_values; ++i) {
182 auto shape_attr = list_attr.GetAttributeOfType<tfrt::ShapeAttr>(i);
183 auto* added_shape = mutable_list->add_shape();
184 if (shape_attr.HasRank()) {
185 int rank = shape_attr.GetRank();
186 auto shape = shape_attr.GetShape();
187 added_shape->mutable_dim()->Reserve(rank);
188 for (int d = 0; d < rank; ++d) {
189 added_shape->add_dim()->set_size(shape[d]);
190 }
191 } else {
192 added_shape->set_unknown_rank(true);
193 }
194 }
195 } else {
196 return tfrt::MakeStringError("unsupported list attribute type");
197 }
198 return llvm::Error::success();
199 }
200
FillAttrValueMapUsingScalar(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,HostContext * host,const OpAttrsRef & attrs)201 llvm::Error FillAttrValueMapUsingScalar(const OpAttrsRawEntry& entry,
202 AttrValue& attr_tmp, HostContext* host,
203 const OpAttrsRef& attrs) {
204 switch (entry.type) {
205 case OpAttrType::I64: {
206 int64_t attr_value = attrs.GetAsserting<int64_t>(entry.name);
207 attr_tmp.set_i(attr_value);
208 return llvm::Error::success();
209 }
210 case OpAttrType::F32: {
211 float attr_value = attrs.GetAsserting<float>(entry.name);
212 attr_tmp.set_f(attr_value);
213 return llvm::Error::success();
214 }
215 case OpAttrType::BOOL: {
216 bool attr_value = attrs.GetAsserting<bool>(entry.name);
217 attr_tmp.set_b(attr_value);
218 return llvm::Error::success();
219 }
220 case OpAttrType::DTYPE: {
221 OpAttrType op_attr_type = attrs.GetAsserting<OpAttrType>(entry.name);
222 DataType tf_dtype = ConvertToTfDataType(op_attr_type);
223 attr_tmp.set_type(tf_dtype);
224 return llvm::Error::success();
225 }
226 case OpAttrType::SHAPE: {
227 auto shape_attr = attrs.GetAsserting<tfrt::ShapeAttr>(entry.name);
228 auto* mutable_shape = attr_tmp.mutable_shape();
229 if (shape_attr.HasRank()) {
230 int rank = shape_attr.GetRank();
231 auto shape = shape_attr.GetShape();
232 mutable_shape->mutable_dim()->Reserve(rank);
233 for (int d = 0; d < rank; ++d) {
234 mutable_shape->add_dim()->set_size(shape[d]);
235 }
236 } else {
237 mutable_shape->set_unknown_rank(true);
238 }
239 return llvm::Error::success();
240 }
241 case OpAttrType::DENSE: {
242 auto dense_attr = attrs.GetAsserting<tfrt::DenseAttr>(entry.name);
243 llvm::Expected<tensorflow::Tensor> tf_tensor =
244 DecodeDenseAttrToTfTensor(dense_attr, host);
245 if (!tf_tensor) return tf_tensor.takeError();
246 auto* mutable_tensor = attr_tmp.mutable_tensor();
247 if (tf_tensor->NumElements() > 1) {
248 tf_tensor->AsProtoTensorContent(mutable_tensor);
249 } else {
250 tf_tensor->AsProtoField(mutable_tensor);
251 }
252 return llvm::Error::success();
253 }
254 case OpAttrType::AGGREGATE: {
255 return FillAttrValueMapUsingAggregate(entry, attr_tmp, attrs);
256 }
257 default:
258 LOG(ERROR) << "failure case";
259 return tfrt::MakeStringError("unsupported scalar attribute type");
260 }
261 }
262
263 } // namespace
264
ParseTfDataType(absl::string_view dtype,DataType * data_type)265 Status ParseTfDataType(absl::string_view dtype, DataType* data_type) {
266 if (dtype == "DT_INT8") {
267 *data_type = DataType::DT_INT8;
268 return OkStatus();
269 } else if (dtype == "DT_INT32") {
270 *data_type = DataType::DT_INT32;
271 return OkStatus();
272 } else if (dtype == "DT_INT64") {
273 *data_type = DataType::DT_INT64;
274 return OkStatus();
275 } else if (dtype == "DT_HALF") {
276 *data_type = DataType::DT_HALF;
277 return OkStatus();
278 } else if (dtype == "DT_FLOAT") {
279 *data_type = DataType::DT_FLOAT;
280 return OkStatus();
281 } else if (dtype == "DT_DOUBLE") {
282 *data_type = DataType::DT_DOUBLE;
283 return OkStatus();
284 } else {
285 return errors::InvalidArgument("Unsupported dtype, ", std::string(dtype),
286 " in ParseTfDataType.");
287 }
288 }
289
ConvertToTfDataType(tfrt::OpAttrType op_attr_type)290 DataType ConvertToTfDataType(tfrt::OpAttrType op_attr_type) {
291 switch (op_attr_type) {
292 #define OP_ATTR_TYPE(TFRT_ENUM, DT_ENUM) \
293 case tfrt::OpAttrType::TFRT_ENUM: \
294 return DataType::DT_ENUM;
295 #include "tensorflow/core/runtime_fallback/util/attr_type.def" // NOLINT
296 default:
297 TFRT_DLOG(ERROR) << "unsupported dtype" << static_cast<int>(op_attr_type)
298 << " in TFRT fallback kernel.";
299 abort();
300 }
301 }
302
ConvertFromTfDataType(DataType data_type)303 tfrt::OpAttrType ConvertFromTfDataType(DataType data_type) {
304 switch (data_type) {
305 #define OP_ATTR_TYPE(TFRT_ENUM, DT_ENUM) \
306 case DataType::DT_ENUM: \
307 return tfrt::OpAttrType::TFRT_ENUM;
308 #include "tensorflow/core/runtime_fallback/util/attr_type.def" // NOLINT
309 default:
310 TFRT_DLOG(ERROR) << "unsupported dtype " << static_cast<int>(data_type)
311 << "in TFRT fallback kernel.";
312 abort();
313 }
314 }
315
ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type)316 DataType ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type) {
317 switch (attr_type) {
318 case tfrt::DType::I1:
319 return DataType::DT_BOOL;
320 case tfrt::DType::I8:
321 return DataType::DT_INT8;
322 case tfrt::DType::I16:
323 return DataType::DT_INT16;
324 case tfrt::DType::I32:
325 return DataType::DT_INT32;
326 case tfrt::DType::I64:
327 return DataType::DT_INT64;
328 case tfrt::DType::UI8:
329 return DataType::DT_UINT8;
330 case tfrt::DType::UI16:
331 return DataType::DT_UINT16;
332 case tfrt::DType::UI32:
333 return DataType::DT_UINT32;
334 case tfrt::DType::UI64:
335 return DataType::DT_UINT64;
336 case tfrt::DType::F16:
337 return DataType::DT_HALF;
338 case tfrt::DType::BF16:
339 return DataType::DT_BFLOAT16;
340 case tfrt::DType::F32:
341 return DataType::DT_FLOAT;
342 case tfrt::DType::F64:
343 return DataType::DT_DOUBLE;
344 case tfrt::DType::Complex64:
345 return DataType::DT_COMPLEX64;
346 case tfrt::DType::Complex128:
347 return DataType::DT_COMPLEX128;
348 case tfrt::DType::String:
349 return DataType::DT_STRING;
350 case tfrt::DType::Resource:
351 return DataType::DT_RESOURCE;
352 case tfrt::DType::Variant:
353 return DataType::DT_VARIANT;
354 case tfrt::DType::QUI8:
355 return DataType::DT_QUINT8;
356 case tfrt::DType::QUI16:
357 return DataType::DT_QUINT16;
358 case tfrt::DType::QI8:
359 return DataType::DT_QINT8;
360 case tfrt::DType::QI16:
361 return DataType::DT_QINT16;
362 case tfrt::DType::QI32:
363 return DataType::DT_QINT32;
364 default:
365 TFRT_DLOG(ERROR) << "unsupported tfrt::DType"
366 << static_cast<int>(attr_type)
367 << " in TFRT fallback kernel.";
368 abort();
369 }
370 }
371
ConvertTfDataTypeToBefAttrType(DataType data_type)372 tfrt::DType ConvertTfDataTypeToBefAttrType(DataType data_type) {
373 switch (data_type) {
374 case DataType::DT_UINT8:
375 return tfrt::DType::UI8;
376 case DataType::DT_UINT16:
377 return tfrt::DType::UI16;
378 case DataType::DT_UINT32:
379 return tfrt::DType::UI32;
380 case DataType::DT_UINT64:
381 return tfrt::DType::UI64;
382 case DataType::DT_BOOL:
383 return tfrt::DType::I1;
384 case DataType::DT_INT8:
385 return tfrt::DType::I8;
386 case DataType::DT_INT16:
387 return tfrt::DType::I16;
388 case DataType::DT_INT32:
389 return tfrt::DType::I32;
390 case DataType::DT_INT64:
391 return tfrt::DType::I64;
392 case DataType::DT_HALF:
393 return tfrt::DType::F16;
394 case DataType::DT_BFLOAT16:
395 return tfrt::DType::BF16;
396 case DataType::DT_FLOAT:
397 return tfrt::DType::F32;
398 case DataType::DT_DOUBLE:
399 return tfrt::DType::F64;
400 case DataType::DT_COMPLEX64:
401 return tfrt::DType::Complex64;
402 case DataType::DT_COMPLEX128:
403 return tfrt::DType::Complex128;
404 case DataType::DT_STRING:
405 return tfrt::DType::String;
406 case DataType::DT_RESOURCE:
407 return tfrt::DType::Resource;
408 case DataType::DT_VARIANT:
409 return tfrt::DType::Variant;
410 case DataType::DT_QUINT8:
411 return tfrt::DType::QUI8;
412 case DataType::DT_QUINT16:
413 return tfrt::DType::QUI16;
414 case DataType::DT_QINT8:
415 return tfrt::DType::QI8;
416 case DataType::DT_QINT16:
417 return tfrt::DType::QI16;
418 case DataType::DT_QINT32:
419 return tfrt::DType::QI32;
420 default:
421 TFRT_DLOG(ERROR) << "unsupported DataType " << static_cast<int>(data_type)
422 << " in TFRT fallback kernel.";
423 abort();
424 }
425 }
426
ParseBoolAttrValue(absl::string_view attr_value,bool * bool_val)427 Status ParseBoolAttrValue(absl::string_view attr_value, bool* bool_val) {
428 if (attr_value == "false") {
429 *bool_val = false;
430 return OkStatus();
431 } else if (attr_value == "true") {
432 *bool_val = true;
433 return OkStatus();
434 } else {
435 return errors::InvalidArgument("Could not parse bool from \"", attr_value,
436 "\"");
437 }
438 }
439
ParseIntAttrValue(absl::string_view attr_value,int64_t * int_val)440 Status ParseIntAttrValue(absl::string_view attr_value, int64_t* int_val) {
441 bool success = absl::SimpleAtoi(attr_value, int_val);
442 if (!success) {
443 return errors::InvalidArgument("Could not parse int from \"", attr_value,
444 "\"");
445 }
446 return OkStatus();
447 }
448
ParseTensorAttrValue(absl::string_view attr_value,tensorflow::Tensor * tensor)449 Status ParseTensorAttrValue(absl::string_view attr_value,
450 tensorflow::Tensor* tensor) {
451 if (std::is_base_of<tensorflow::protobuf::Message,
452 tensorflow::TensorProto>()) {
453 tensorflow::TensorProto tensor_proto;
454 // We use reinterpret_cast here to make sure ParseFromString call
455 // below compiles if TensorProto is not a subclass of Message.
456 // At run time, we should never get to this point if TensorProto
457 // is not a subclass of message due to if-condition above.
458 auto* message = reinterpret_cast<protobuf::Message*>(&tensor_proto);
459 if (protobuf::TextFormat::ParseFromString(
460 static_cast<std::string>(attr_value), message) &&
461 tensor->FromProto(tensor_proto)) {
462 return OkStatus();
463 } else {
464 return errors::InvalidArgument("Could not parse tensor value from \"",
465 attr_value, "\"");
466 }
467 } else {
468 // TextFormat does not work with portable proto implementations.
469 return errors::InvalidArgument(
470 "Tensor attributes are not supported on mobile.");
471 }
472 }
473
ParseTensorShapeAttrValue(absl::string_view attr_value,std::vector<int64_t> * shape_val)474 Status ParseTensorShapeAttrValue(absl::string_view attr_value,
475 std::vector<int64_t>* shape_val) {
476 if (attr_value.size() < 2 || attr_value[0] != '[' ||
477 attr_value[attr_value.size() - 1] != ']') {
478 return errors::InvalidArgument(
479 "Tensor shape attribute must be a string of the form [1,2...], instead "
480 "got \"",
481 attr_value, "\"");
482 }
483 absl::string_view attr_value_trunc =
484 attr_value.substr(1, attr_value.size() - 2);
485 // `container` is an absl::strings_internal::Splitter, which is a
486 // lazy-splitting iterable. So we cannot get its size to reserve `dims`.
487 auto container = absl::StrSplit(attr_value_trunc, ',');
488 for (auto it = container.begin(); it != container.end(); ++it) {
489 int64_t int_val;
490 if (!ParseIntAttrValue(*it, &int_val).ok()) {
491 return errors::InvalidArgument("Failed to parse an integer value from ",
492 *it, " while parsing shape.");
493 }
494 shape_val->push_back(int_val);
495 }
496 return OkStatus();
497 }
498
IsUnusedAttribute(absl::string_view attr_name)499 bool IsUnusedAttribute(absl::string_view attr_name) {
500 // These are extra attributes added by TF MLIR dialect, and not needed by
501 // current TF runtime.
502 //
503 // TODO(chky): Consider removing this attribute in tf-to-tfrt
504 // lowering.
505 return absl::StrContains(attr_name, "result_segment_sizes") ||
506 absl::StrContains(attr_name, "operand_segment_sizes") ||
507 absl::EndsWith(attr_name, "_tf_data_function");
508 }
509
FillAttrValueMap(const tfrt::OpAttrsRef & attrs,tfrt::HostContext * host,tensorflow::AttrValueMap * attr_value_map)510 llvm::Error FillAttrValueMap(const tfrt::OpAttrsRef& attrs,
511 tfrt::HostContext* host,
512 tensorflow::AttrValueMap* attr_value_map) {
513 AttrValue attr_tmp;
514 llvm::Error error = llvm::Error::success();
515 attrs.IterateEntries([&error, attr_value_map, &attr_tmp, host,
516 &attrs](const OpAttrsRawEntry& entry) {
517 // TFE does not expect a device attribute.
518 assert(strcmp(entry.name, "device") != 0);
519 if (IsUnusedAttribute(entry.name)) {
520 return;
521 } else if (entry.IsArray()) {
522 error = FillAttrValueMapUsingArray(entry, attr_tmp, attrs);
523 } else {
524 error = FillAttrValueMapUsingScalar(entry, attr_tmp, host, attrs);
525 }
526 if (error) return;
527 attr_value_map->insert(AttrValueMap::value_type(entry.name, attr_tmp));
528 });
529 return error;
530 }
531
532 namespace {
533
CreateTfTensorFromDenseAttr(tfrt::DenseAttr attr)534 tensorflow::Tensor CreateTfTensorFromDenseAttr(tfrt::DenseAttr attr) {
535 tensorflow::TensorShape shape(absl::InlinedVector<int64_t, 4>(
536 attr.shape().begin(), attr.shape().end()));
537 tensorflow::DataType dtype = ConvertBefAttrTypeToTfDataType(attr.dtype());
538
539 tensorflow::Tensor tensor(dtype, shape);
540
541 std::memcpy(tensor.data(), attr.GetElements(), tensor.TotalBytes());
542
543 return tensor;
544 }
545
SetUpScalarAttr(tfrt::TypedAttrBase bef_attr,tensorflow::AttrValue * tf_attr)546 Status SetUpScalarAttr(tfrt::TypedAttrBase bef_attr,
547 tensorflow::AttrValue* tf_attr) {
548 if (auto shape_attr = bef_attr.dyn_cast<tfrt::ShapeAttr>()) {
549 if (shape_attr.HasRank()) {
550 tensorflow::PartialTensorShape tf_shape(shape_attr.GetShape());
551 tf_shape.AsProto(tf_attr->mutable_shape());
552 } else {
553 tensorflow::PartialTensorShape unranked_shape;
554 unranked_shape.AsProto(tf_attr->mutable_shape());
555 }
556 } else if (auto dense_attr = bef_attr.dyn_cast<tfrt::DenseAttr>()) {
557 auto tf_tensor = CreateTfTensorFromDenseAttr(dense_attr);
558 tf_tensor.AsProtoTensorContent(tf_attr->mutable_tensor());
559 } else if (auto type_attr = bef_attr.dyn_cast<tfrt::TypeAttr>()) {
560 tf_attr->set_type(ConvertBefAttrTypeToTfDataType(type_attr.GetValue()));
561 } else if (auto i1_attr = bef_attr.dyn_cast<tfrt::I1Attr>()) {
562 tf_attr->set_b(i1_attr.GetValue());
563 } else if (auto f32_attr = bef_attr.dyn_cast<tfrt::F32Attr>()) {
564 tf_attr->set_f(f32_attr.GetValue());
565 } else if (auto i64_attr = bef_attr.dyn_cast<tfrt::I64Attr>()) {
566 tf_attr->set_i(i64_attr.GetValue());
567 } else if (auto string_attr = bef_attr.dyn_cast<tfrt::StringAttr>()) {
568 tf_attr->set_s(string_attr.GetValue().data(),
569 string_attr.GetValue().size());
570 } else {
571 return tensorflow::errors::Internal("Failed to set up attribute.");
572 }
573
574 return OkStatus();
575 }
576
SetUpScalarFunctionAttr(tfrt::StringAttr func_attr,tensorflow::AttrValue & tf_attr)577 Status SetUpScalarFunctionAttr(tfrt::StringAttr func_attr,
578 tensorflow::AttrValue& tf_attr) {
579 tfrt::string_view func_name = func_attr.GetValue();
580 tf_attr.mutable_func()->set_name(func_name.data(), func_name.size());
581 return OkStatus();
582 }
583
AddShapeToAttrList(tfrt::ShapeAttr shape,tensorflow::AttrValue::ListValue * list)584 void AddShapeToAttrList(tfrt::ShapeAttr shape,
585 tensorflow::AttrValue::ListValue* list) {
586 if (shape.HasRank()) {
587 tensorflow::PartialTensorShape tf_shape(shape.GetShape());
588 tf_shape.AsProto(list->add_shape());
589 return;
590 }
591
592 tensorflow::PartialTensorShape unranked_shape;
593 unranked_shape.AsProto(list->add_shape());
594 }
AddTensorToAttrList(tfrt::DenseAttr dense_attr,tensorflow::AttrValue::ListValue * list)595 void AddTensorToAttrList(tfrt::DenseAttr dense_attr,
596 tensorflow::AttrValue::ListValue* list) {
597 auto tf_tensor = CreateTfTensorFromDenseAttr(dense_attr);
598 tf_tensor.AsProtoTensorContent(list->add_tensor());
599 }
600
SetUpListAttr(tfrt::AggregateAttr aggregate_attr,tensorflow::AttrValue * tf_attr)601 Status SetUpListAttr(tfrt::AggregateAttr aggregate_attr,
602 tensorflow::AttrValue* tf_attr) {
603 auto* list = tf_attr->mutable_list();
604 for (int i = 0; i < aggregate_attr.GetNumElements(); ++i) {
605 auto base = aggregate_attr.GetAttribute(i);
606 if (auto shape_attr = base.dyn_cast<tfrt::ShapeAttr>()) {
607 AddShapeToAttrList(shape_attr, list);
608 } else if (auto dense_attr = base.dyn_cast<tfrt::DenseAttr>()) {
609 AddTensorToAttrList(dense_attr, list);
610 } else if (auto string_attr = base.dyn_cast<tfrt::StringAttr>()) {
611 list->add_s(string_attr.GetValue().data(), string_attr.GetValue().size());
612 } else {
613 return tensorflow::errors::Internal("Failed to set up list attr.");
614 }
615 }
616 return OkStatus();
617 }
618
SetUpListAttr(tfrt::ArrayAttr array_attr,tensorflow::AttrValue * tf_attr)619 Status SetUpListAttr(tfrt::ArrayAttr array_attr,
620 tensorflow::AttrValue* tf_attr) {
621 auto* list = tf_attr->mutable_list();
622
623 // Handle an empty array case.
624 if (array_attr.GetNumElements() == 0) {
625 return OkStatus();
626 }
627
628 tfrt::BEFAttributeType element_type = array_attr.GetElementType();
629 if (tfrt::IsDataTypeAttribute(element_type)) {
630 tfrt::DType dtype = GetDataType(element_type);
631 switch (dtype) {
632 case tfrt::DType::I1: {
633 for (auto value : array_attr.GetValue<bool>()) {
634 list->add_b(value);
635 }
636 return OkStatus();
637 }
638 case tfrt::DType::I64: {
639 for (auto value : array_attr.GetValue<int64_t>()) {
640 list->add_i(value);
641 }
642 return OkStatus();
643 }
644 case tfrt::DType::F32: {
645 for (auto value : array_attr.GetValue<float>()) {
646 list->add_f(value);
647 }
648 return OkStatus();
649 }
650 default:
651 return tensorflow::errors::Internal(
652 StrCat("Failed to set up list attr: unsupported dtype: ",
653 tfrt::DType(dtype)));
654 }
655 } else if (element_type == tfrt::BEFAttributeType::kType) {
656 for (auto value : array_attr.GetValue<tfrt::DType>()) {
657 list->add_type(ConvertBefAttrTypeToTfDataType(value));
658 }
659 return OkStatus();
660 }
661
662 return tensorflow::errors::Internal("Failed to set up list attr.");
663 }
664
665 } // namespace
666
SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array,tfrt::AggregateAttr op_func_attr_array,tensorflow::AttrValueMap * attr_value_map)667 Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array,
668 tfrt::AggregateAttr op_func_attr_array,
669 tensorflow::AttrValueMap* attr_value_map) {
670 auto obtain_name_attr_pair =
671 [](tfrt::AggregateAttr attr_array,
672 int i) -> std::pair<std::string, tfrt::TypedAttrBase> {
673 auto pair = attr_array.GetAttributeOfType<tfrt::AggregateAttr>(i);
674 assert(pair.GetNumElements() == 2);
675 return {pair.GetAttributeOfType<tfrt::StringAttr>(0).GetValue().str(),
676 pair.GetAttribute(1)};
677 };
678
679 for (size_t i = 0, e = op_attr_array.GetNumElements(); i != e; ++i) {
680 auto name_attr_pair = obtain_name_attr_pair(op_attr_array, i);
681 if (IsUnusedAttribute(name_attr_pair.first)) continue;
682
683 AttrValue& tf_attr = (*attr_value_map)[name_attr_pair.first];
684 tfrt::TypedAttrBase attr_value = name_attr_pair.second;
685 if (auto aggregate_attr = attr_value.dyn_cast<tfrt::AggregateAttr>()) {
686 TF_RETURN_IF_ERROR(SetUpListAttr(aggregate_attr, &tf_attr));
687 } else if (auto array_attr = attr_value.dyn_cast<tfrt::ArrayAttr>()) {
688 TF_RETURN_IF_ERROR(SetUpListAttr(array_attr, &tf_attr));
689 } else {
690 TF_RETURN_IF_ERROR(SetUpScalarAttr(attr_value, &tf_attr));
691 }
692 }
693
694 for (size_t i = 0, e = op_func_attr_array.GetNumElements(); i != e; ++i) {
695 auto name_attr_pair = obtain_name_attr_pair(op_func_attr_array, i);
696 if (IsUnusedAttribute(name_attr_pair.first)) continue;
697
698 AttrValue& tf_attr = (*attr_value_map)[name_attr_pair.first];
699 auto attr_value = name_attr_pair.second.dyn_cast<tfrt::StringAttr>();
700 TF_RETURN_IF_ERROR(SetUpScalarFunctionAttr(attr_value, tf_attr));
701 }
702
703 return OkStatus();
704 }
705
706 } // namespace tfd
707 } // namespace tensorflow
708