/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/ir/importexport/convert_attributes.h"

#include <string>

#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/importexport/convert_tensor.h"
#include "tensorflow/core/ir/importexport/convert_types.h"
#include "tensorflow/core/ir/importexport/mangling.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"

using tensorflow::AttrValue;
using tensorflow::AttrValueMap;
using tensorflow::DataType;
using tensorflow::NodeDef;
using tensorflow::Status;
using tensorflow::StatusOr;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
using tensorflow::errors::InvalidArgument;
using tensorflow::errors::Unimplemented;

namespace mlir {
namespace tfg {

namespace {
// Converts a location to the debug information for the node def.
Status ConvertLocation(Location inst_loc,
                       NodeDef::ExperimentalDebugInfo* debug_info) {
  if (auto call_site = inst_loc.dyn_cast<CallSiteLoc>()) {
    if (auto name_loc = call_site.getCallee().dyn_cast<NameLoc>()) {
      debug_info->add_original_node_names(name_loc.getName().data());
    }
  } else if (auto fused = inst_loc.dyn_cast<FusedLoc>()) {
    auto locations = fused.getLocations();
    if (locations.size() <= 1)
      return InvalidArgument("Expected experimental debug info.");
    // skip the first one, which is the name of the node_def.
    for (int i = 0, end = locations.size() - 1; i < end; ++i) {
      TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
    }
  }
  return ::tensorflow::OkStatus();
}

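// Each ConvertAttribute overload below converts one MLIR attribute kind into
// the corresponding field of an AttrValue proto.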
Status ConvertAttribute(BoolAttr attr, AttrValue* value) {
  value->set_b(attr.getValue());
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(IntegerAttr attr, AttrValue* value) {
  value->set_i(attr.getInt());
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(FloatAttr attr, AttrValue* value) {
  value->set_f(attr.getValueAsDouble());
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(ElementsAttr attr, AttrValue* value) {
  return ConvertToTensorProto(attr, value->mutable_tensor());
}

Status ConvertAttribute(PlaceholderAttr attr, AttrValue* value) {
  value->set_placeholder(attr.getValue().str());
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(ShapeAttr attr, AttrValue* value) {
  SetTensorShapeProto(attr, value->mutable_shape());
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(FlatSymbolRefAttr attr, AttrValue* value) {
  value->mutable_func()->set_name(attr.getValue().str());
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(FuncAttr attr, bool remove_ref_type, AttrValue* value) {
  TF_RETURN_IF_ERROR(
      ConvertAttribute(attr.getName().cast<FlatSymbolRefAttr>(), value));
  TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(),
                                       /*attrs_to_ignore=*/{}, remove_ref_type,
                                       value->mutable_func()->mutable_attr()));
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(StringAttr attr, AttrValue* value) {
  value->set_s(attr.str());
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(Type type, bool remove_ref_type, AttrValue* value) {
  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
  if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
  value->set_type(dtype);
  return ::tensorflow::OkStatus();
}

Status ConvertAttribute(const TypeAttr& type, bool remove_ref_type,
                        AttrValue* value) {
  return ConvertAttribute(type.getValue(), remove_ref_type, value);
}

Status ConvertAttribute(const UnitAttr& attr, AttrValue* value) {
  value->clear_value();
  return ::tensorflow::OkStatus();
}

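// Converts an MLIR ArrayAttr into a list-valued AttrValue, dispatching on the
// kind of each element.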
Status ConvertAttribute(const ArrayAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  auto* list = value->mutable_list();
  for (Attribute a : attr.getValue()) {
    if (auto attr = a.dyn_cast<BoolAttr>()) {
      list->add_b(attr.getValue());
    } else if (auto attr = a.dyn_cast<IntegerAttr>()) {
      list->add_i(attr.getInt());
    } else if (auto attr = a.dyn_cast<FloatAttr>()) {
      list->add_f(attr.getValueAsDouble());
    } else if (auto attr = a.dyn_cast<StringAttr>()) {
      AttrValue nested_value;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
      switch (nested_value.value_case()) {
        case AttrValue::kS:
          list->add_s(nested_value.s());
          break;
        case AttrValue::kType:
          list->add_type(nested_value.type());
          break;
        case AttrValue::kShape:
          *list->add_shape() = nested_value.shape();
          break;
        default:
          return Unimplemented("Unhandled nested attribute!");
      }
    } else if (auto attr = a.dyn_cast<ElementsAttr>()) {
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
      *list->add_tensor() = tensor;
    } else if (auto attr = a.dyn_cast<FlatSymbolRefAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_func() = attr_val.func();
    } else if (auto attr = a.dyn_cast<FuncAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, remove_ref_type, &attr_val));
      *list->add_func() = attr_val.func();
    } else if (auto attr = a.dyn_cast<TypeAttr>()) {
      AttrValue attr_val;
      // For type attributes, we only propagate the element type.
      Type elt_type = attr.getValue();
      if (auto shaped_type = elt_type.dyn_cast<ShapedType>()) {
        elt_type = shaped_type.getElementType();
      }
      TF_RETURN_IF_ERROR(
          ConvertAttribute(elt_type, remove_ref_type, &attr_val));
      list->add_type(attr_val.type());
    } else if (auto attr = a.dyn_cast<ShapeAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_shape() = attr_val.shape();
    } else {
      return Unimplemented("Unhandled MLIR attribute in export to graph:",
                           debugString(a));
    }
  }
  return ::tensorflow::OkStatus();
}
}  // namespace

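// Converts a single MLIR attribute to its AttrValue proto equivalent. Symbol
// references and FuncAttrs are handled up front; all other kinds dispatch
// through a TypeSwitch to the overloads above.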
StatusOr<AttrValue> ConvertAttribute(Attribute attr) {
  AttrValue value;
  if (auto symbol_ref = attr.dyn_cast<SymbolRefAttr>()) {
    TF_RETURN_IF_ERROR(
        ConvertAttribute(symbol_ref.cast<FlatSymbolRefAttr>(), &value));
    return value;
  }
  if (auto func_attr = attr.dyn_cast<FuncAttr>()) {
    TF_RETURN_IF_ERROR(
        ConvertAttribute(func_attr, /*remove_ref_type=*/false, &value));
    return value;
  }
  if (attr.isa<AffineMapAttr>())
    return Unimplemented("AffineMap attribute unimplemented");
  TF_RETURN_IF_ERROR(
      llvm::TypeSwitch<Attribute, Status>(attr)
          .Case<BoolAttr, IntegerAttr, FloatAttr, StringAttr, ElementsAttr,
                UnitAttr, ShapeAttr, PlaceholderAttr>([&](auto derived_attr) {
            return ConvertAttribute(derived_attr, &value);
          })
          .Case<ArrayAttr, TypeAttr>([&](auto derived_attr) {
            return ConvertAttribute(derived_attr,
                                    /*remove_ref_type=*/false, &value);
          })
          .Default([&](Attribute attr) {
            return Unimplemented("Unhandled attribute kind for attribute: ",
                                 debugString(attr));
          }));
  return value;
}

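// Converts a list of named MLIR attributes into an AttrValueMap, skipping the
// names in `attrs_to_ignore` and demangling names carrying the TF dialect
// mangling prefix. Attributes that describe function calls are merged last.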
Status ConvertAttributes(ArrayRef<NamedAttribute> attrs,
                         ArrayRef<StringRef> attrs_to_ignore,
                         bool remove_ref_type, AttrValueMap* values) {
  StringSet<> ignored_attrs;
  ignored_attrs.insert(attrs_to_ignore.begin(), attrs_to_ignore.end());
  AttrValueMap func_call_attrs;
  for (const NamedAttribute& named_attr : attrs) {
    std::string name_str = named_attr.getName().str();
    auto attr = named_attr.getValue();
    absl::string_view name = name_str;
    if (ignored_attrs.contains(name_str)) {
      // The name and device spec of a TF op or function are not stored as
      // AttrValue inside NodeDef, but we model them as attributes in MLIR.
      // So we need to ignore them when going back to AttrValue here.
      continue;
    }
    if (mangling_util::IsMangledAttributeName(name)) {
      // In MLIR, attributes for functions require a dialect prefix. We need to
      // remove the TF dialect prefix before converting to AttrValue.
      name = mangling_util::DemangleAttributeName(name);
    }
    TF_ASSIGN_OR_RETURN(AttrValue value, ConvertAttribute(attr));
    if (attr.isa<SymbolRefAttr>()) {
      func_call_attrs[std::string(name)] = value;
      continue;
    }
    if (attr.isa<FuncAttr>()) {
      func_call_attrs[std::string(name)] = value;
      continue;
    }
    // According to the NodeDef proto definition, an attribute name from the
    // input TensorFlow GraphDef shouldn't contain '.'. If a '.' does appear in
    // an attribute coming from MLIR, the attribute is treated as belonging to
    // a function call.
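    // For example, an attribute named "body.T" is attached to the function
    // attribute "body" as its nested attribute "T".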
    std::vector<std::string> name_tokens =
        absl::StrSplit(name, '.', absl::SkipEmpty());
    TF_RET_CHECK(name_tokens.size() <= 2);
    auto it = func_call_attrs.find(name_tokens[0]);
    if (it == func_call_attrs.end())
      (*values)[std::string(name)] = value;
    else
      (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
  }
  for (const auto& it : func_call_attrs) {
    (*values)[it.first] = it.second;
  }
  return ::tensorflow::OkStatus();
}

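// Adds a shape attribute named `name` to `values`; if an attribute with that
// name is already present, verifies that the existing shape matches.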
Status SetShapeAttribute(absl::string_view name, ShapedType shaped_type,
                         AttrValueMap* values) {
  AttrValue value;
  SetTensorShapeProto(shaped_type, value.mutable_shape());

  auto result = values->insert({std::string(name), value});
  if (!result.second) {
    // This should be extremely rare, as it means we are adding the same
    // attribute multiple times or have some redundancy in how this attribute
    // is represented.
    TensorShapeProto actual_shape = result.first->second.shape();
    // Just compare the string representations: we shouldn't get here, and if
    // we do the shapes should be trivially identical; otherwise fail.
    std::string new_shape_string = value.shape().ShortDebugString();
    if (actual_shape.ShortDebugString() != new_shape_string) {
      return InvalidArgument("Expected ", new_shape_string, " '", name,
                             "' attribute but found ",
                             actual_shape.ShortDebugString());
    }
  }
  return ::tensorflow::OkStatus();
}

// Converts a non-func AttrValue proto into an MLIR attribute. Func attributes
// are excluded from this function because the function might be renamed when
// the function definition is imported.
StatusOr<Attribute> ConvertNonFuncAttributeValue(const AttrValue& value,
                                                 Builder& builder) {
  switch (value.value_case()) {
    case AttrValue::kI:
      return builder.getI64IntegerAttr(value.i());
    case AttrValue::kS:
      return builder.getStringAttr(value.s());
    case AttrValue::kF:
      return builder.getFloatAttr(builder.getF32Type(), value.f());
    case AttrValue::kB:
      return builder.getBoolAttr(value.b());
    case AttrValue::kType: {
      Type type;
      TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder, &type));
      return TypeAttr::get(type);
    }
    case AttrValue::kShape:
      return ConvertTensorShapeProto(value.shape(), builder.getContext());
    case AttrValue::kTensor:
      return ConvertTensorProto(value.tensor(), builder);
    case AttrValue::kList: {
      absl::InlinedVector<Attribute, 8> attrs;
      for (const auto& item : value.list().i())
        attrs.push_back(builder.getI64IntegerAttr(item));
      for (const auto& item : value.list().s())
        attrs.push_back(builder.getStringAttr(item));
      for (const auto& item : value.list().f())
        attrs.push_back(builder.getFloatAttr(builder.getF32Type(), item));
      for (const auto& item : value.list().b())
        attrs.push_back(builder.getBoolAttr(item));
      for (const auto& item : value.list().type()) {
        Type type;
        TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), builder, &type));
        attrs.push_back(TypeAttr::get(type));
      }
      for (const auto& item : value.list().shape()) {
        TF_ASSIGN_OR_RETURN(
            auto attr, ConvertTensorShapeProto(item, builder.getContext()));
        attrs.push_back(attr);
      }
      for (const auto& item : value.list().tensor()) {
        TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item, builder));
        attrs.push_back(attr);
      }
      for (const auto& func_attr : value.list().func()) {
        NamedAttrList subattrs;
        for (const auto& subattr : func_attr.attr()) {
          TF_ASSIGN_OR_RETURN(auto attr,
                              ConvertAttributeValue(subattr.second, builder));
          if (subattr.first.empty())
            return InvalidArgument("empty func_attr name");
          subattrs.push_back(builder.getNamedAttr(subattr.first, attr));
        }
        attrs.push_back(FuncAttr::get(builder.getContext(), func_attr.name(),
                                      builder.getDictionaryAttr(subattrs)));
      }
      return builder.getArrayAttr(
          llvm::makeArrayRef(attrs.begin(), attrs.end()));
    }
    case AttrValue::VALUE_NOT_SET:
      return builder.getUnitAttr();
    case AttrValue::kPlaceholder:
      return PlaceholderAttr::get(builder.getContext(), value.placeholder());
    default:
      return tensorflow::errors::Unimplemented(
          absl::StrCat("Attribute ", value.DebugString()));
  }
}

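// Converts an AttrValue proto, including func attributes, into an MLIR
// attribute; non-func cases are delegated to ConvertNonFuncAttributeValue.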
StatusOr<Attribute> ConvertAttributeValue(const AttrValue& value,
                                          Builder& builder) {
  switch (value.value_case()) {
    case AttrValue::kFunc: {
      NamedAttrList attrs;
      for (const auto& func_attr : value.func().attr()) {
        if (func_attr.first.empty()) return InvalidArgument("empty attr name");
        TF_ASSIGN_OR_RETURN(auto attr,
                            ConvertAttributeValue(func_attr.second, builder));
        attrs.push_back(builder.getNamedAttr(func_attr.first, attr));
      }
      auto func_attrs = builder.getDictionaryAttr(attrs);
      return FuncAttr::get(builder.getContext(), value.func().name(),
                           func_attrs);
    }
    default:
      return ConvertNonFuncAttributeValue(value, builder);
  }
}

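// Recursively converts a FullTypeDef proto into a tf_type::FullTypeAttr.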
StatusOr<tf_type::FullTypeAttr> ConvertAttribute(
    const tensorflow::FullTypeDef& full_type, Builder& builder) {
  using FullTypeAttr = ::mlir::tf_type::FullTypeAttr;

  SmallVector<FullTypeAttr> args;
  for (const tensorflow::FullTypeDef& it : full_type.args()) {
    TF_ASSIGN_OR_RETURN(FullTypeAttr arg, ConvertAttribute(it, builder));
    args.push_back(arg);
  }

  Attribute attr;
  switch (full_type.attr_case()) {
    case tensorflow::FullTypeDef::AttrCase::kS:
      attr = builder.getStringAttr(full_type.s());
      break;
    case tensorflow::FullTypeDef::AttrCase::kI:
      attr = builder.getI64IntegerAttr(full_type.i());
      break;
    case tensorflow::FullTypeDef::ATTR_NOT_SET:
      break;
    default:
      return InvalidArgument("Unsupported attr kind in FullType");
  }

  return FullTypeAttr::get(builder.getContext(), full_type.type_id(), args,
                           attr);
}

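// Converts a tf_type::FullTypeAttr back into its FullTypeDef proto form.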
StatusOr<tensorflow::FullTypeDef> ConvertAttribute(
    tf_type::FullTypeAttr full_type) {
  using FullTypeDef = tensorflow::FullTypeDef;

  FullTypeDef ret;
  for (tf_type::FullTypeAttr it : full_type.getArgs()) {
    TF_ASSIGN_OR_RETURN(*ret.add_args(), ConvertAttribute(it));
  }

  if (full_type.getAttr()) {
    bool converted = llvm::TypeSwitch<Attribute, bool>(full_type.getAttr())
                         .Case<StringAttr>([&](StringAttr sattr) {
                           ret.set_s(sattr.str());
                           return true;
                         })
                         .Case<IntegerAttr>([&](IntegerAttr iattr) {
                           ret.set_i(iattr.getInt());
                           return true;
                         })
                         .Default([&](Attribute attr) { return false; });
    if (!converted)
      return InvalidArgument("Unsupported attr kind in FullType:",
                             mlir::debugString(full_type.getAttr()));
  }

  ret.set_type_id(static_cast<tensorflow::FullTypeId>(full_type.getTypeId()));

  return ret;
}

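// Converts the dtype/shape pairs of resource handle_data into an ArrayAttr of
// tensor types.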
StatusOr<ArrayAttr> ConvertHandleData(
    Builder builder,
    const tensorflow::protobuf::RepeatedPtrField<
        tensorflow::ResourceHandleProto_DtypeAndShape>& handle_data) {
  SmallVector<Attribute> dtype_and_shape;
  for (const auto& handle : handle_data) {
    if (handle.dtype() == tensorflow::DT_INVALID)
      return InvalidArgument("Invalid dtype for handle_data");
    Type dtype;
    TF_RETURN_IF_ERROR(ConvertDataType(handle.dtype(), builder, &dtype));
    TF_ASSIGN_OR_RETURN(
        ShapeAttr shape,
        ConvertTensorShapeProto(handle.shape(), builder.getContext()));
    TensorType handle_type;
    if (shape.hasRank()) {
      handle_type = RankedTensorType::get(shape.getShape(), dtype);
    } else {
      handle_type = UnrankedTensorType::get(dtype);
    }
    dtype_and_shape.push_back(TypeAttr::get(handle_type));
  }
  return builder.getArrayAttr(dtype_and_shape);
}

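// Exports an ArrayAttr of tensor types back into the handle_data field of an
// OpDef ArgDef.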
Status ConvertHandleData(ArrayAttr handle_data_arr,
                         tensorflow::OpDef::ArgDef* arg) {
  if (!handle_data_arr) return {};
  for (auto handle_data_attr : handle_data_arr.getAsRange<TypeAttr>()) {
    TensorType handle_type = handle_data_attr.getValue().dyn_cast<TensorType>();
    if (!handle_type) {
      return InvalidArgument("Expected an array of tensor types, but got ",
                             debugString(handle_data_arr));
    }
    auto* handle_data = arg->add_handle_data();
    if (handle_type.hasRank()) {
      ConvertToTensorShapeProto(handle_type.getShape(),
                                handle_data->mutable_shape());
    } else {
      handle_data->mutable_shape()->set_unknown_rank(true);
    }
    DataType dtype;
    TF_RETURN_IF_ERROR(ConvertToDataType(handle_type.getElementType(), &dtype));
    handle_data->set_dtype(dtype);
  }
  return {};
}

}  // namespace tfg
}  // namespace mlir