xref: /aosp_15_r20/external/tensorflow/tensorflow/core/runtime_fallback/util/attr_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/util/attr_util.h"
16 
17 #include <cstdlib>
18 
19 #include "absl/strings/numbers.h"
20 #include "absl/strings/str_split.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/op_def.pb.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.pb.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/status.h"
30 #include "tensorflow/core/tfrt/utils/tensor_util.h"
31 #include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
32 #include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
33 #include "tfrt/support/error_util.h"  // from @tf_runtime
34 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
35 #include "tfrt/support/logging.h"  // from @tf_runtime
36 #include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
37 #include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime
38 
39 namespace tensorflow {
40 namespace tfd {
41 namespace {
42 
43 using ::tensorflow::protobuf::RepeatedFieldBackInserter;
44 using ::tfrt::AggregateAttr;
45 using ::tfrt::BEFAttributeType;
46 using ::tfrt::DenseAttr;
47 using ::tfrt::DenseHostTensor;
48 using ::tfrt::HostContext;
49 using ::tfrt::OpAttrsRawEntry;
50 using ::tfrt::OpAttrsRef;
51 using ::tfrt::OpAttrType;
52 using ::tfrt::string_view;
53 
DecodeDenseAttrToTfTensor(const DenseAttr & dense_attr,HostContext * host)54 llvm::Expected<tensorflow::Tensor> DecodeDenseAttrToTfTensor(
55     const DenseAttr& dense_attr, HostContext* host) {
56   llvm::Expected<DenseHostTensor> dht =
57       tfrt::DeserializeDenseHostTensorFromDenseAttr(dense_attr, host);
58   if (!dht) {
59     return tfrt::MakeStringError(
60         "Cannot create DenseHostTensor in DecodeDenseAttrToTensorInterface: ",
61         dht.takeError());
62   }
63 
64   return tfrt::TFRTTensorToTFTensor(*dht, host);
65 }
66 
FillAttrValueMapUsingArray(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,const OpAttrsRef & attrs)67 llvm::Error FillAttrValueMapUsingArray(const OpAttrsRawEntry& entry,
68                                        AttrValue& attr_tmp,
69                                        const OpAttrsRef& attrs) {
70   attr_tmp.mutable_list()->Clear();
71   if (entry.element_count == 0) {
72     if (entry.type == OpAttrType::CHAR) {
73       // Empty string.
74       attr_tmp.set_s("");
75     }
76     // Empty array of other types.
77     return llvm::Error::success();
78   }
79   switch (entry.type) {
80     case OpAttrType::CHAR: {
81       string_view attr_value = attrs.GetStringAsserting(entry.name);
82       attr_tmp.set_s(attr_value.data(), attr_value.size());
83       return llvm::Error::success();
84     }
85 
86     case OpAttrType::FUNC: {
87       string_view attr_value = attrs.GetFuncNameAsserting(entry.name);
88       attr_tmp.mutable_func()->set_name(attr_value.data(), attr_value.size());
89       return llvm::Error::success();
90     }
91     case OpAttrType::I64: {
92       llvm::ArrayRef<int64_t> int_array =
93           attrs.GetArrayAsserting<int64_t>(entry.name);
94       auto* mutable_i = attr_tmp.mutable_list()->mutable_i();
95       std::copy(int_array.begin(), int_array.end(),
96                 RepeatedFieldBackInserter(mutable_i));
97       return llvm::Error::success();
98     }
99     case OpAttrType::F32: {
100       llvm::ArrayRef<float> float_array =
101           attrs.GetArrayAsserting<float>(entry.name);
102       auto* mutable_f = attr_tmp.mutable_list()->mutable_f();
103       std::copy(float_array.begin(), float_array.end(),
104                 RepeatedFieldBackInserter(mutable_f));
105       return llvm::Error::success();
106     }
107     case OpAttrType::BOOL: {
108       llvm::ArrayRef<bool> bool_array =
109           attrs.GetArrayAsserting<bool>(entry.name);
110       auto mutable_b = attr_tmp.mutable_list()->mutable_b();
111       std::copy(bool_array.begin(), bool_array.end(),
112                 RepeatedFieldBackInserter(mutable_b));
113       return llvm::Error::success();
114     }
115     case OpAttrType::DTYPE: {
116       const auto& op_attr = attrs.GetRawAsserting(entry.name);
117       assert(op_attr.IsArray());
118 
119       // DTypes in BEF attributes are tfrt::DType enums. So we need
120       // to convert then to tensorflow data types first.
121       auto bef_dtypes =
122           llvm::makeArrayRef(static_cast<const tfrt::DType*>(op_attr.GetData()),
123                              op_attr.element_count);
124 
125       llvm::SmallVector<tensorflow::DataType, 4> tf_dtypes;
126       tf_dtypes.reserve(bef_dtypes.size());
127       for (auto bef_dtype : bef_dtypes) {
128         tf_dtypes.push_back(ConvertBefAttrTypeToTfDataType(bef_dtype));
129       }
130       auto* mutable_type = attr_tmp.mutable_list()->mutable_type();
131       std::copy(tf_dtypes.begin(), tf_dtypes.end(),
132                 RepeatedFieldBackInserter(mutable_type));
133       return llvm::Error::success();
134     }
135     default:
136       return tfrt::MakeStringError("unsupported array attribute type");
137   }
138 }
139 
FillAttrValueMapUsingAggregate(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,const OpAttrsRef & attrs)140 llvm::Error FillAttrValueMapUsingAggregate(const OpAttrsRawEntry& entry,
141                                            AttrValue& attr_tmp,
142                                            const OpAttrsRef& attrs) {
143   AggregateAttr list_attr = attrs.GetAsserting<AggregateAttr>(entry.name);
144   int num_values = list_attr.GetNumElements();
145   if (num_values == 0) {
146     // Create an empty list.
147     attr_tmp.mutable_list();
148     return llvm::Error::success();
149   }
150   // It is guaranteed that items in one list attribute have the same
151   // type, though their sizes can be different. In particular,
152   // list(TensorShape) and list(Tensor) attribute types have to be
153   // encoded as AggregateAttr.
154   auto attr_base = list_attr.GetAttribute(0);
155   auto* mutable_list = attr_tmp.mutable_list();
156   mutable_list->Clear();
157   if (IsDataTypeAttribute(attr_base.type()) &&
158       GetDataType(attr_base.type()) == tfrt::DType::String) {
159     // Handle list(string).
160     auto* mutable_s = mutable_list->mutable_s();
161     mutable_s->Reserve(num_values);
162     for (int i = 0; i < num_values; ++i) {
163       auto string_attr = list_attr.GetAttributeOfType<tfrt::StringAttr>(i);
164       mutable_list->add_s(string_attr.GetValue().data(),
165                           string_attr.GetValue().size());
166     }
167   } else if (attr_base.type() == BEFAttributeType::kFunc) {
168     // Handle list(Function).
169     auto* mutable_f = mutable_list->mutable_func();
170     mutable_f->Reserve(num_values);
171     for (int i = 0; i < num_values; ++i) {
172       auto func_attr = list_attr.GetAttributeOfType<tfrt::FuncAttr>(i);
173       auto mutable_func = mutable_list->add_func();
174       mutable_func->set_name(func_attr.GetFunctionName().str());
175     }
176   } else if (attr_base.type() == BEFAttributeType::kShape) {
177     // Handle list(TensorShape).
178     auto* mutable_list = attr_tmp.mutable_list();
179     auto* mutable_shape = mutable_list->mutable_shape();
180     mutable_shape->Reserve(num_values);
181     for (int i = 0; i < num_values; ++i) {
182       auto shape_attr = list_attr.GetAttributeOfType<tfrt::ShapeAttr>(i);
183       auto* added_shape = mutable_list->add_shape();
184       if (shape_attr.HasRank()) {
185         int rank = shape_attr.GetRank();
186         auto shape = shape_attr.GetShape();
187         added_shape->mutable_dim()->Reserve(rank);
188         for (int d = 0; d < rank; ++d) {
189           added_shape->add_dim()->set_size(shape[d]);
190         }
191       } else {
192         added_shape->set_unknown_rank(true);
193       }
194     }
195   } else {
196     return tfrt::MakeStringError("unsupported list attribute type");
197   }
198   return llvm::Error::success();
199 }
200 
FillAttrValueMapUsingScalar(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,HostContext * host,const OpAttrsRef & attrs)201 llvm::Error FillAttrValueMapUsingScalar(const OpAttrsRawEntry& entry,
202                                         AttrValue& attr_tmp, HostContext* host,
203                                         const OpAttrsRef& attrs) {
204   switch (entry.type) {
205     case OpAttrType::I64: {
206       int64_t attr_value = attrs.GetAsserting<int64_t>(entry.name);
207       attr_tmp.set_i(attr_value);
208       return llvm::Error::success();
209     }
210     case OpAttrType::F32: {
211       float attr_value = attrs.GetAsserting<float>(entry.name);
212       attr_tmp.set_f(attr_value);
213       return llvm::Error::success();
214     }
215     case OpAttrType::BOOL: {
216       bool attr_value = attrs.GetAsserting<bool>(entry.name);
217       attr_tmp.set_b(attr_value);
218       return llvm::Error::success();
219     }
220     case OpAttrType::DTYPE: {
221       OpAttrType op_attr_type = attrs.GetAsserting<OpAttrType>(entry.name);
222       DataType tf_dtype = ConvertToTfDataType(op_attr_type);
223       attr_tmp.set_type(tf_dtype);
224       return llvm::Error::success();
225     }
226     case OpAttrType::SHAPE: {
227       auto shape_attr = attrs.GetAsserting<tfrt::ShapeAttr>(entry.name);
228       auto* mutable_shape = attr_tmp.mutable_shape();
229       if (shape_attr.HasRank()) {
230         int rank = shape_attr.GetRank();
231         auto shape = shape_attr.GetShape();
232         mutable_shape->mutable_dim()->Reserve(rank);
233         for (int d = 0; d < rank; ++d) {
234           mutable_shape->add_dim()->set_size(shape[d]);
235         }
236       } else {
237         mutable_shape->set_unknown_rank(true);
238       }
239       return llvm::Error::success();
240     }
241     case OpAttrType::DENSE: {
242       auto dense_attr = attrs.GetAsserting<tfrt::DenseAttr>(entry.name);
243       llvm::Expected<tensorflow::Tensor> tf_tensor =
244           DecodeDenseAttrToTfTensor(dense_attr, host);
245       if (!tf_tensor) return tf_tensor.takeError();
246       auto* mutable_tensor = attr_tmp.mutable_tensor();
247       if (tf_tensor->NumElements() > 1) {
248         tf_tensor->AsProtoTensorContent(mutable_tensor);
249       } else {
250         tf_tensor->AsProtoField(mutable_tensor);
251       }
252       return llvm::Error::success();
253     }
254     case OpAttrType::AGGREGATE: {
255       return FillAttrValueMapUsingAggregate(entry, attr_tmp, attrs);
256     }
257     default:
258       LOG(ERROR) << "failure case";
259       return tfrt::MakeStringError("unsupported scalar attribute type");
260   }
261 }
262 
263 }  // namespace
264 
ParseTfDataType(absl::string_view dtype,DataType * data_type)265 Status ParseTfDataType(absl::string_view dtype, DataType* data_type) {
266   if (dtype == "DT_INT8") {
267     *data_type = DataType::DT_INT8;
268     return OkStatus();
269   } else if (dtype == "DT_INT32") {
270     *data_type = DataType::DT_INT32;
271     return OkStatus();
272   } else if (dtype == "DT_INT64") {
273     *data_type = DataType::DT_INT64;
274     return OkStatus();
275   } else if (dtype == "DT_HALF") {
276     *data_type = DataType::DT_HALF;
277     return OkStatus();
278   } else if (dtype == "DT_FLOAT") {
279     *data_type = DataType::DT_FLOAT;
280     return OkStatus();
281   } else if (dtype == "DT_DOUBLE") {
282     *data_type = DataType::DT_DOUBLE;
283     return OkStatus();
284   } else {
285     return errors::InvalidArgument("Unsupported dtype, ", std::string(dtype),
286                                    " in ParseTfDataType.");
287   }
288 }
289 
ConvertToTfDataType(tfrt::OpAttrType op_attr_type)290 DataType ConvertToTfDataType(tfrt::OpAttrType op_attr_type) {
291   switch (op_attr_type) {
292 #define OP_ATTR_TYPE(TFRT_ENUM, DT_ENUM) \
293   case tfrt::OpAttrType::TFRT_ENUM:      \
294     return DataType::DT_ENUM;
295 #include "tensorflow/core/runtime_fallback/util/attr_type.def"  // NOLINT
296     default:
297       TFRT_DLOG(ERROR) << "unsupported dtype" << static_cast<int>(op_attr_type)
298                        << " in TFRT fallback kernel.";
299       abort();
300   }
301 }
302 
ConvertFromTfDataType(DataType data_type)303 tfrt::OpAttrType ConvertFromTfDataType(DataType data_type) {
304   switch (data_type) {
305 #define OP_ATTR_TYPE(TFRT_ENUM, DT_ENUM) \
306   case DataType::DT_ENUM:                \
307     return tfrt::OpAttrType::TFRT_ENUM;
308 #include "tensorflow/core/runtime_fallback/util/attr_type.def"  // NOLINT
309     default:
310       TFRT_DLOG(ERROR) << "unsupported dtype " << static_cast<int>(data_type)
311                        << "in TFRT fallback kernel.";
312       abort();
313   }
314 }
315 
ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type)316 DataType ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type) {
317   switch (attr_type) {
318     case tfrt::DType::I1:
319       return DataType::DT_BOOL;
320     case tfrt::DType::I8:
321       return DataType::DT_INT8;
322     case tfrt::DType::I16:
323       return DataType::DT_INT16;
324     case tfrt::DType::I32:
325       return DataType::DT_INT32;
326     case tfrt::DType::I64:
327       return DataType::DT_INT64;
328     case tfrt::DType::UI8:
329       return DataType::DT_UINT8;
330     case tfrt::DType::UI16:
331       return DataType::DT_UINT16;
332     case tfrt::DType::UI32:
333       return DataType::DT_UINT32;
334     case tfrt::DType::UI64:
335       return DataType::DT_UINT64;
336     case tfrt::DType::F16:
337       return DataType::DT_HALF;
338     case tfrt::DType::BF16:
339       return DataType::DT_BFLOAT16;
340     case tfrt::DType::F32:
341       return DataType::DT_FLOAT;
342     case tfrt::DType::F64:
343       return DataType::DT_DOUBLE;
344     case tfrt::DType::Complex64:
345       return DataType::DT_COMPLEX64;
346     case tfrt::DType::Complex128:
347       return DataType::DT_COMPLEX128;
348     case tfrt::DType::String:
349       return DataType::DT_STRING;
350     case tfrt::DType::Resource:
351       return DataType::DT_RESOURCE;
352     case tfrt::DType::Variant:
353       return DataType::DT_VARIANT;
354     case tfrt::DType::QUI8:
355       return DataType::DT_QUINT8;
356     case tfrt::DType::QUI16:
357       return DataType::DT_QUINT16;
358     case tfrt::DType::QI8:
359       return DataType::DT_QINT8;
360     case tfrt::DType::QI16:
361       return DataType::DT_QINT16;
362     case tfrt::DType::QI32:
363       return DataType::DT_QINT32;
364     default:
365       TFRT_DLOG(ERROR) << "unsupported tfrt::DType"
366                        << static_cast<int>(attr_type)
367                        << " in TFRT fallback kernel.";
368       abort();
369   }
370 }
371 
ConvertTfDataTypeToBefAttrType(DataType data_type)372 tfrt::DType ConvertTfDataTypeToBefAttrType(DataType data_type) {
373   switch (data_type) {
374     case DataType::DT_UINT8:
375       return tfrt::DType::UI8;
376     case DataType::DT_UINT16:
377       return tfrt::DType::UI16;
378     case DataType::DT_UINT32:
379       return tfrt::DType::UI32;
380     case DataType::DT_UINT64:
381       return tfrt::DType::UI64;
382     case DataType::DT_BOOL:
383       return tfrt::DType::I1;
384     case DataType::DT_INT8:
385       return tfrt::DType::I8;
386     case DataType::DT_INT16:
387       return tfrt::DType::I16;
388     case DataType::DT_INT32:
389       return tfrt::DType::I32;
390     case DataType::DT_INT64:
391       return tfrt::DType::I64;
392     case DataType::DT_HALF:
393       return tfrt::DType::F16;
394     case DataType::DT_BFLOAT16:
395       return tfrt::DType::BF16;
396     case DataType::DT_FLOAT:
397       return tfrt::DType::F32;
398     case DataType::DT_DOUBLE:
399       return tfrt::DType::F64;
400     case DataType::DT_COMPLEX64:
401       return tfrt::DType::Complex64;
402     case DataType::DT_COMPLEX128:
403       return tfrt::DType::Complex128;
404     case DataType::DT_STRING:
405       return tfrt::DType::String;
406     case DataType::DT_RESOURCE:
407       return tfrt::DType::Resource;
408     case DataType::DT_VARIANT:
409       return tfrt::DType::Variant;
410     case DataType::DT_QUINT8:
411       return tfrt::DType::QUI8;
412     case DataType::DT_QUINT16:
413       return tfrt::DType::QUI16;
414     case DataType::DT_QINT8:
415       return tfrt::DType::QI8;
416     case DataType::DT_QINT16:
417       return tfrt::DType::QI16;
418     case DataType::DT_QINT32:
419       return tfrt::DType::QI32;
420     default:
421       TFRT_DLOG(ERROR) << "unsupported DataType " << static_cast<int>(data_type)
422                        << " in TFRT fallback kernel.";
423       abort();
424   }
425 }
426 
ParseBoolAttrValue(absl::string_view attr_value,bool * bool_val)427 Status ParseBoolAttrValue(absl::string_view attr_value, bool* bool_val) {
428   if (attr_value == "false") {
429     *bool_val = false;
430     return OkStatus();
431   } else if (attr_value == "true") {
432     *bool_val = true;
433     return OkStatus();
434   } else {
435     return errors::InvalidArgument("Could not parse bool from \"", attr_value,
436                                    "\"");
437   }
438 }
439 
ParseIntAttrValue(absl::string_view attr_value,int64_t * int_val)440 Status ParseIntAttrValue(absl::string_view attr_value, int64_t* int_val) {
441   bool success = absl::SimpleAtoi(attr_value, int_val);
442   if (!success) {
443     return errors::InvalidArgument("Could not parse int from \"", attr_value,
444                                    "\"");
445   }
446   return OkStatus();
447 }
448 
ParseTensorAttrValue(absl::string_view attr_value,tensorflow::Tensor * tensor)449 Status ParseTensorAttrValue(absl::string_view attr_value,
450                             tensorflow::Tensor* tensor) {
451   if (std::is_base_of<tensorflow::protobuf::Message,
452                       tensorflow::TensorProto>()) {
453     tensorflow::TensorProto tensor_proto;
454     // We use reinterpret_cast here to make sure ParseFromString call
455     // below compiles if TensorProto is not a subclass of Message.
456     // At run time, we should never get to this point if TensorProto
457     // is not a subclass of message due to if-condition above.
458     auto* message = reinterpret_cast<protobuf::Message*>(&tensor_proto);
459     if (protobuf::TextFormat::ParseFromString(
460             static_cast<std::string>(attr_value), message) &&
461         tensor->FromProto(tensor_proto)) {
462       return OkStatus();
463     } else {
464       return errors::InvalidArgument("Could not parse tensor value from \"",
465                                      attr_value, "\"");
466     }
467   } else {
468     // TextFormat does not work with portable proto implementations.
469     return errors::InvalidArgument(
470         "Tensor attributes are not supported on mobile.");
471   }
472 }
473 
ParseTensorShapeAttrValue(absl::string_view attr_value,std::vector<int64_t> * shape_val)474 Status ParseTensorShapeAttrValue(absl::string_view attr_value,
475                                  std::vector<int64_t>* shape_val) {
476   if (attr_value.size() < 2 || attr_value[0] != '[' ||
477       attr_value[attr_value.size() - 1] != ']') {
478     return errors::InvalidArgument(
479         "Tensor shape attribute must be a string of the form [1,2...], instead "
480         "got \"",
481         attr_value, "\"");
482   }
483   absl::string_view attr_value_trunc =
484       attr_value.substr(1, attr_value.size() - 2);
485   // `container` is an absl::strings_internal::Splitter, which is a
486   // lazy-splitting iterable. So we cannot get its size to reserve `dims`.
487   auto container = absl::StrSplit(attr_value_trunc, ',');
488   for (auto it = container.begin(); it != container.end(); ++it) {
489     int64_t int_val;
490     if (!ParseIntAttrValue(*it, &int_val).ok()) {
491       return errors::InvalidArgument("Failed to parse an integer value from ",
492                                      *it, " while parsing shape.");
493     }
494     shape_val->push_back(int_val);
495   }
496   return OkStatus();
497 }
498 
IsUnusedAttribute(absl::string_view attr_name)499 bool IsUnusedAttribute(absl::string_view attr_name) {
500   // These are extra attributes added by TF MLIR dialect, and not needed by
501   // current TF runtime.
502   //
503   // TODO(chky): Consider removing this attribute in tf-to-tfrt
504   // lowering.
505   return absl::StrContains(attr_name, "result_segment_sizes") ||
506          absl::StrContains(attr_name, "operand_segment_sizes") ||
507          absl::EndsWith(attr_name, "_tf_data_function");
508 }
509 
FillAttrValueMap(const tfrt::OpAttrsRef & attrs,tfrt::HostContext * host,tensorflow::AttrValueMap * attr_value_map)510 llvm::Error FillAttrValueMap(const tfrt::OpAttrsRef& attrs,
511                              tfrt::HostContext* host,
512                              tensorflow::AttrValueMap* attr_value_map) {
513   AttrValue attr_tmp;
514   llvm::Error error = llvm::Error::success();
515   attrs.IterateEntries([&error, attr_value_map, &attr_tmp, host,
516                         &attrs](const OpAttrsRawEntry& entry) {
517     // TFE does not expect a device attribute.
518     assert(strcmp(entry.name, "device") != 0);
519     if (IsUnusedAttribute(entry.name)) {
520       return;
521     } else if (entry.IsArray()) {
522       error = FillAttrValueMapUsingArray(entry, attr_tmp, attrs);
523     } else {
524       error = FillAttrValueMapUsingScalar(entry, attr_tmp, host, attrs);
525     }
526     if (error) return;
527     attr_value_map->insert(AttrValueMap::value_type(entry.name, attr_tmp));
528   });
529   return error;
530 }
531 
532 namespace {
533 
CreateTfTensorFromDenseAttr(tfrt::DenseAttr attr)534 tensorflow::Tensor CreateTfTensorFromDenseAttr(tfrt::DenseAttr attr) {
535   tensorflow::TensorShape shape(absl::InlinedVector<int64_t, 4>(
536       attr.shape().begin(), attr.shape().end()));
537   tensorflow::DataType dtype = ConvertBefAttrTypeToTfDataType(attr.dtype());
538 
539   tensorflow::Tensor tensor(dtype, shape);
540 
541   std::memcpy(tensor.data(), attr.GetElements(), tensor.TotalBytes());
542 
543   return tensor;
544 }
545 
SetUpScalarAttr(tfrt::TypedAttrBase bef_attr,tensorflow::AttrValue * tf_attr)546 Status SetUpScalarAttr(tfrt::TypedAttrBase bef_attr,
547                        tensorflow::AttrValue* tf_attr) {
548   if (auto shape_attr = bef_attr.dyn_cast<tfrt::ShapeAttr>()) {
549     if (shape_attr.HasRank()) {
550       tensorflow::PartialTensorShape tf_shape(shape_attr.GetShape());
551       tf_shape.AsProto(tf_attr->mutable_shape());
552     } else {
553       tensorflow::PartialTensorShape unranked_shape;
554       unranked_shape.AsProto(tf_attr->mutable_shape());
555     }
556   } else if (auto dense_attr = bef_attr.dyn_cast<tfrt::DenseAttr>()) {
557     auto tf_tensor = CreateTfTensorFromDenseAttr(dense_attr);
558     tf_tensor.AsProtoTensorContent(tf_attr->mutable_tensor());
559   } else if (auto type_attr = bef_attr.dyn_cast<tfrt::TypeAttr>()) {
560     tf_attr->set_type(ConvertBefAttrTypeToTfDataType(type_attr.GetValue()));
561   } else if (auto i1_attr = bef_attr.dyn_cast<tfrt::I1Attr>()) {
562     tf_attr->set_b(i1_attr.GetValue());
563   } else if (auto f32_attr = bef_attr.dyn_cast<tfrt::F32Attr>()) {
564     tf_attr->set_f(f32_attr.GetValue());
565   } else if (auto i64_attr = bef_attr.dyn_cast<tfrt::I64Attr>()) {
566     tf_attr->set_i(i64_attr.GetValue());
567   } else if (auto string_attr = bef_attr.dyn_cast<tfrt::StringAttr>()) {
568     tf_attr->set_s(string_attr.GetValue().data(),
569                    string_attr.GetValue().size());
570   } else {
571     return tensorflow::errors::Internal("Failed to set up attribute.");
572   }
573 
574   return OkStatus();
575 }
576 
SetUpScalarFunctionAttr(tfrt::StringAttr func_attr,tensorflow::AttrValue & tf_attr)577 Status SetUpScalarFunctionAttr(tfrt::StringAttr func_attr,
578                                tensorflow::AttrValue& tf_attr) {
579   tfrt::string_view func_name = func_attr.GetValue();
580   tf_attr.mutable_func()->set_name(func_name.data(), func_name.size());
581   return OkStatus();
582 }
583 
AddShapeToAttrList(tfrt::ShapeAttr shape,tensorflow::AttrValue::ListValue * list)584 void AddShapeToAttrList(tfrt::ShapeAttr shape,
585                         tensorflow::AttrValue::ListValue* list) {
586   if (shape.HasRank()) {
587     tensorflow::PartialTensorShape tf_shape(shape.GetShape());
588     tf_shape.AsProto(list->add_shape());
589     return;
590   }
591 
592   tensorflow::PartialTensorShape unranked_shape;
593   unranked_shape.AsProto(list->add_shape());
594 }
AddTensorToAttrList(tfrt::DenseAttr dense_attr,tensorflow::AttrValue::ListValue * list)595 void AddTensorToAttrList(tfrt::DenseAttr dense_attr,
596                          tensorflow::AttrValue::ListValue* list) {
597   auto tf_tensor = CreateTfTensorFromDenseAttr(dense_attr);
598   tf_tensor.AsProtoTensorContent(list->add_tensor());
599 }
600 
SetUpListAttr(tfrt::AggregateAttr aggregate_attr,tensorflow::AttrValue * tf_attr)601 Status SetUpListAttr(tfrt::AggregateAttr aggregate_attr,
602                      tensorflow::AttrValue* tf_attr) {
603   auto* list = tf_attr->mutable_list();
604   for (int i = 0; i < aggregate_attr.GetNumElements(); ++i) {
605     auto base = aggregate_attr.GetAttribute(i);
606     if (auto shape_attr = base.dyn_cast<tfrt::ShapeAttr>()) {
607       AddShapeToAttrList(shape_attr, list);
608     } else if (auto dense_attr = base.dyn_cast<tfrt::DenseAttr>()) {
609       AddTensorToAttrList(dense_attr, list);
610     } else if (auto string_attr = base.dyn_cast<tfrt::StringAttr>()) {
611       list->add_s(string_attr.GetValue().data(), string_attr.GetValue().size());
612     } else {
613       return tensorflow::errors::Internal("Failed to set up list attr.");
614     }
615   }
616   return OkStatus();
617 }
618 
SetUpListAttr(tfrt::ArrayAttr array_attr,tensorflow::AttrValue * tf_attr)619 Status SetUpListAttr(tfrt::ArrayAttr array_attr,
620                      tensorflow::AttrValue* tf_attr) {
621   auto* list = tf_attr->mutable_list();
622 
623   // Handle an empty array case.
624   if (array_attr.GetNumElements() == 0) {
625     return OkStatus();
626   }
627 
628   tfrt::BEFAttributeType element_type = array_attr.GetElementType();
629   if (tfrt::IsDataTypeAttribute(element_type)) {
630     tfrt::DType dtype = GetDataType(element_type);
631     switch (dtype) {
632       case tfrt::DType::I1: {
633         for (auto value : array_attr.GetValue<bool>()) {
634           list->add_b(value);
635         }
636         return OkStatus();
637       }
638       case tfrt::DType::I64: {
639         for (auto value : array_attr.GetValue<int64_t>()) {
640           list->add_i(value);
641         }
642         return OkStatus();
643       }
644       case tfrt::DType::F32: {
645         for (auto value : array_attr.GetValue<float>()) {
646           list->add_f(value);
647         }
648         return OkStatus();
649       }
650       default:
651         return tensorflow::errors::Internal(
652             StrCat("Failed to set up list attr: unsupported dtype: ",
653                    tfrt::DType(dtype)));
654     }
655   } else if (element_type == tfrt::BEFAttributeType::kType) {
656     for (auto value : array_attr.GetValue<tfrt::DType>()) {
657       list->add_type(ConvertBefAttrTypeToTfDataType(value));
658     }
659     return OkStatus();
660   }
661 
662   return tensorflow::errors::Internal("Failed to set up list attr.");
663 }
664 
665 }  // namespace
666 
SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array,tfrt::AggregateAttr op_func_attr_array,tensorflow::AttrValueMap * attr_value_map)667 Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array,
668                          tfrt::AggregateAttr op_func_attr_array,
669                          tensorflow::AttrValueMap* attr_value_map) {
670   auto obtain_name_attr_pair =
671       [](tfrt::AggregateAttr attr_array,
672          int i) -> std::pair<std::string, tfrt::TypedAttrBase> {
673     auto pair = attr_array.GetAttributeOfType<tfrt::AggregateAttr>(i);
674     assert(pair.GetNumElements() == 2);
675     return {pair.GetAttributeOfType<tfrt::StringAttr>(0).GetValue().str(),
676             pair.GetAttribute(1)};
677   };
678 
679   for (size_t i = 0, e = op_attr_array.GetNumElements(); i != e; ++i) {
680     auto name_attr_pair = obtain_name_attr_pair(op_attr_array, i);
681     if (IsUnusedAttribute(name_attr_pair.first)) continue;
682 
683     AttrValue& tf_attr = (*attr_value_map)[name_attr_pair.first];
684     tfrt::TypedAttrBase attr_value = name_attr_pair.second;
685     if (auto aggregate_attr = attr_value.dyn_cast<tfrt::AggregateAttr>()) {
686       TF_RETURN_IF_ERROR(SetUpListAttr(aggregate_attr, &tf_attr));
687     } else if (auto array_attr = attr_value.dyn_cast<tfrt::ArrayAttr>()) {
688       TF_RETURN_IF_ERROR(SetUpListAttr(array_attr, &tf_attr));
689     } else {
690       TF_RETURN_IF_ERROR(SetUpScalarAttr(attr_value, &tf_attr));
691     }
692   }
693 
694   for (size_t i = 0, e = op_func_attr_array.GetNumElements(); i != e; ++i) {
695     auto name_attr_pair = obtain_name_attr_pair(op_func_attr_array, i);
696     if (IsUnusedAttribute(name_attr_pair.first)) continue;
697 
698     AttrValue& tf_attr = (*attr_value_map)[name_attr_pair.first];
699     auto attr_value = name_attr_pair.second.dyn_cast<tfrt::StringAttr>();
700     TF_RETURN_IF_ERROR(SetUpScalarFunctionAttr(attr_value, tf_attr));
701   }
702 
703   return OkStatus();
704 }
705 
706 }  // namespace tfd
707 }  // namespace tensorflow
708