1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/importexport/convert_attributes.h"
17
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/importexport/convert_tensor.h"
#include "tensorflow/core/ir/importexport/convert_types.h"
#include "tensorflow/core/ir/importexport/mangling.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
38
39 using tensorflow::AttrValue;
40 using tensorflow::AttrValueMap;
41 using tensorflow::DataType;
42 using tensorflow::NodeDef;
43 using tensorflow::Status;
44 using tensorflow::StatusOr;
45 using tensorflow::TensorProto;
46 using tensorflow::TensorShapeProto;
47 using tensorflow::errors::InvalidArgument;
48 using tensorflow::errors::Unimplemented;
49
50 namespace mlir {
51 namespace tfg {
52
53 namespace {
54 // Converts a location to the debug information for the node def.
ConvertLocation(Location inst_loc,NodeDef::ExperimentalDebugInfo * debug_info)55 Status ConvertLocation(Location inst_loc,
56 NodeDef::ExperimentalDebugInfo* debug_info) {
57 if (auto call_site = inst_loc.dyn_cast<CallSiteLoc>()) {
58 if (auto name_loc = call_site.getCallee().dyn_cast<NameLoc>()) {
59 debug_info->add_original_node_names(name_loc.getName().data());
60 }
61 } else if (auto fused = inst_loc.dyn_cast<FusedLoc>()) {
62 auto locations = fused.getLocations();
63 if (locations.size() <= 1)
64 return InvalidArgument("Expected experimental debug info.");
65 // skip the first one, which is the name of the node_def.
66 for (int i = 0, end = locations.size() - 1; i < end; ++i) {
67 TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
68 }
69 }
70 return ::tensorflow::OkStatus();
71 }
72
ConvertAttribute(BoolAttr attr,AttrValue * value)73 Status ConvertAttribute(BoolAttr attr, AttrValue* value) {
74 value->set_b(attr.getValue());
75 return ::tensorflow::OkStatus();
76 }
77
ConvertAttribute(IntegerAttr attr,AttrValue * value)78 Status ConvertAttribute(IntegerAttr attr, AttrValue* value) {
79 value->set_i(attr.getInt());
80 return ::tensorflow::OkStatus();
81 }
82
ConvertAttribute(FloatAttr attr,AttrValue * value)83 Status ConvertAttribute(FloatAttr attr, AttrValue* value) {
84 value->set_f(attr.getValueAsDouble());
85 return ::tensorflow::OkStatus();
86 }
87
ConvertAttribute(ElementsAttr attr,AttrValue * value)88 Status ConvertAttribute(ElementsAttr attr, AttrValue* value) {
89 return ConvertToTensorProto(attr, value->mutable_tensor());
90 }
91
ConvertAttribute(PlaceholderAttr attr,AttrValue * value)92 Status ConvertAttribute(PlaceholderAttr attr, AttrValue* value) {
93 value->set_placeholder(attr.getValue().str());
94 return ::tensorflow::OkStatus();
95 }
96
ConvertAttribute(ShapeAttr attr,AttrValue * value)97 Status ConvertAttribute(ShapeAttr attr, AttrValue* value) {
98 SetTensorShapeProto(attr, value->mutable_shape());
99 return ::tensorflow::OkStatus();
100 }
101
ConvertAttribute(FlatSymbolRefAttr attr,AttrValue * value)102 Status ConvertAttribute(FlatSymbolRefAttr attr, AttrValue* value) {
103 value->mutable_func()->set_name(attr.getValue().str());
104 return ::tensorflow::OkStatus();
105 }
106
ConvertAttribute(FuncAttr attr,bool remove_ref_type,AttrValue * value)107 Status ConvertAttribute(FuncAttr attr, bool remove_ref_type, AttrValue* value) {
108 TF_RETURN_IF_ERROR(
109 ConvertAttribute(attr.getName().cast<FlatSymbolRefAttr>(), value));
110 TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(),
111 /*attrs_to_ignore=*/{}, remove_ref_type,
112 value->mutable_func()->mutable_attr()));
113 return ::tensorflow::OkStatus();
114 }
115
ConvertAttribute(StringAttr attr,AttrValue * value)116 Status ConvertAttribute(StringAttr attr, AttrValue* value) {
117 value->set_s(attr.str());
118 return ::tensorflow::OkStatus();
119 }
120
ConvertAttribute(Type type,bool remove_ref_type,AttrValue * value)121 Status ConvertAttribute(Type type, bool remove_ref_type, AttrValue* value) {
122 DataType dtype;
123 TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
124 if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
125 value->set_type(dtype);
126 return ::tensorflow::OkStatus();
127 }
128
ConvertAttribute(const TypeAttr & type,bool remove_ref_type,AttrValue * value)129 Status ConvertAttribute(const TypeAttr& type, bool remove_ref_type,
130 AttrValue* value) {
131 return ConvertAttribute(type.getValue(), remove_ref_type, value);
132 }
133
ConvertAttribute(const UnitAttr & attr,AttrValue * value)134 Status ConvertAttribute(const UnitAttr& attr, AttrValue* value) {
135 value->clear_value();
136 return ::tensorflow::OkStatus();
137 }
138
ConvertAttribute(const ArrayAttr & attr,bool remove_ref_type,AttrValue * value)139 Status ConvertAttribute(const ArrayAttr& attr, bool remove_ref_type,
140 AttrValue* value) {
141 auto* list = value->mutable_list();
142 for (Attribute a : attr.getValue()) {
143 if (auto attr = a.dyn_cast<BoolAttr>()) {
144 list->add_b(attr.getValue());
145 } else if (auto attr = a.dyn_cast<IntegerAttr>()) {
146 list->add_i(attr.getInt());
147 } else if (auto attr = a.dyn_cast<FloatAttr>()) {
148 list->add_f(attr.getValueAsDouble());
149 } else if (auto attr = a.dyn_cast<StringAttr>()) {
150 AttrValue nested_value;
151 TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
152 switch (nested_value.value_case()) {
153 case AttrValue::kS:
154 list->add_s(nested_value.s());
155 break;
156 case AttrValue::kType:
157 list->add_type(nested_value.type());
158 break;
159 case AttrValue::kShape:
160 *list->add_shape() = nested_value.shape();
161 break;
162 default:
163 return Unimplemented("Unhandled nested attribute!");
164 }
165 } else if (auto attr = a.dyn_cast<ElementsAttr>()) {
166 TensorProto tensor;
167 TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
168 *list->add_tensor() = tensor;
169 } else if (auto attr = a.dyn_cast<FlatSymbolRefAttr>()) {
170 AttrValue attr_val;
171 TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
172 *list->add_func() = attr_val.func();
173 } else if (auto attr = a.dyn_cast<FuncAttr>()) {
174 AttrValue attr_val;
175 TF_RETURN_IF_ERROR(ConvertAttribute(attr, remove_ref_type, &attr_val));
176 *list->add_func() = attr_val.func();
177 } else if (auto attr = a.dyn_cast<TypeAttr>()) {
178 AttrValue attr_val;
179 // For type attributes, we only propagate the element type.
180 Type elt_type = attr.getValue();
181 if (auto shaped_type = elt_type.dyn_cast<ShapedType>()) {
182 elt_type = shaped_type.getElementType();
183 }
184 TF_RETURN_IF_ERROR(
185 ConvertAttribute(elt_type, remove_ref_type, &attr_val));
186 list->add_type(attr_val.type());
187 } else if (auto attr = a.dyn_cast<ShapeAttr>()) {
188 AttrValue attr_val;
189 TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
190 *list->add_shape() = attr_val.shape();
191 } else {
192 return Unimplemented("Unhandled MLIR attribute in export to graph:",
193 debugString(a));
194 }
195 }
196 return ::tensorflow::OkStatus();
197 }
198 } // namespace
199
ConvertAttribute(Attribute attr)200 StatusOr<AttrValue> ConvertAttribute(Attribute attr) {
201 AttrValue value;
202 if (auto symbol_ref = attr.dyn_cast<SymbolRefAttr>()) {
203 TF_RETURN_IF_ERROR(
204 ConvertAttribute(symbol_ref.cast<FlatSymbolRefAttr>(), &value));
205 return value;
206 }
207 if (auto func_attr = attr.dyn_cast<FuncAttr>()) {
208 TF_RETURN_IF_ERROR(
209 ConvertAttribute(func_attr, /*remove_ref_type=*/false, &value));
210 return value;
211 }
212 if (attr.isa<AffineMapAttr>())
213 return Unimplemented("AffineMap attribute unimplemented");
214 TF_RETURN_IF_ERROR(
215 llvm::TypeSwitch<Attribute, Status>(attr)
216 .Case<BoolAttr, IntegerAttr, FloatAttr, StringAttr, ElementsAttr,
217 UnitAttr, ShapeAttr, PlaceholderAttr>([&](auto derived_attr) {
218 return ConvertAttribute(derived_attr, &value);
219 })
220 .Case<ArrayAttr, TypeAttr>([&](auto derived_attr) {
221 return ConvertAttribute(derived_attr,
222 /*remove_ref_type=*/false, &value);
223 })
224 .Default([&](Attribute attr) {
225 return Unimplemented("Unhandled attribute kind for attribute: ",
226 debugString(attr));
227 }));
228 return value;
229 }
230
ConvertAttributes(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> attrs_to_ignore,bool remove_ref_type,AttrValueMap * values)231 Status ConvertAttributes(ArrayRef<NamedAttribute> attrs,
232 ArrayRef<StringRef> attrs_to_ignore,
233 bool remove_ref_type, AttrValueMap* values) {
234 StringSet<> ignored_attrs;
235 ignored_attrs.insert(attrs_to_ignore.begin(), attrs_to_ignore.end());
236 AttrValueMap func_call_attrs;
237 for (const NamedAttribute& named_attr : attrs) {
238 std::string name_str = named_attr.getName().str();
239 auto attr = named_attr.getValue();
240 absl::string_view name = name_str;
241 if (ignored_attrs.contains(name_str)) {
242 // The name, device spec of a TF op or function are not stored as
243 // AttrValue inside NodeDef, but we model them using attribute inside
244 // MLIR. So we need to ignore them when going back to AttrValue here.
245 continue;
246 }
247 if (mangling_util::IsMangledAttributeName(name)) {
248 // In MLIR, attributes for functions requires dialect prefix. We need to
249 // remove TF dialect prefix before converting to AttrValue.
250 name = mangling_util::DemangleAttributeName(name);
251 }
252 TF_ASSIGN_OR_RETURN(AttrValue value, ConvertAttribute(attr));
253 if (attr.isa<SymbolRefAttr>()) {
254 func_call_attrs[std::string(name)] = value;
255 continue;
256 }
257 if (attr.isa<FuncAttr>()) {
258 func_call_attrs[std::string(name)] = value;
259 continue;
260 }
261 // According to the NodeDef proto definition, an attribute name from the
262 // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
263 // the attribute from MLIR, it is treated as an attribute from function
264 // calls.
265 std::vector<std::string> name_tokens =
266 absl::StrSplit(name, '.', absl::SkipEmpty());
267 TF_RET_CHECK(name_tokens.size() <= 2);
268 auto it = func_call_attrs.find(name_tokens[0]);
269 if (it == func_call_attrs.end())
270 (*values)[std::string(name)] = value;
271 else
272 (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
273 }
274 for (const auto& it : func_call_attrs) {
275 (*values)[it.first] = it.second;
276 }
277 return ::tensorflow::OkStatus();
278 }
279
SetShapeAttribute(absl::string_view name,ShapedType shaped_type,AttrValueMap * values)280 Status SetShapeAttribute(absl::string_view name, ShapedType shaped_type,
281 AttrValueMap* values) {
282 AttrValue value;
283 SetTensorShapeProto(shaped_type, value.mutable_shape());
284
285 auto result = values->insert({std::string(name), value});
286 if (!result.second) {
287 // This should be extremely rare as it means we are adding the same
288 // attribute multiple times/have some redundancy in representing this
289 // attribute.
290 TensorShapeProto actual_shape = result.first->second.shape();
291 // Just check via string output as we shouldn't get here and if we do they
292 // should be trivially the same, else fail.
293 std::string new_shape_string = value.shape().ShortDebugString();
294 if (actual_shape.ShortDebugString() != new_shape_string) {
295 return InvalidArgument("Expected ", new_shape_string, " '", name,
296 "' attribute but found ",
297 actual_shape.ShortDebugString());
298 }
299 }
300 return ::tensorflow::OkStatus();
301 }
302
303 // Converts non func AttrValue proto into an MLIR attribute. Func attribute is
304 // exclused in this function because the function might be renamed when the
305 // function definition is imported.
ConvertNonFuncAttributeValue(const AttrValue & value,Builder & builder)306 StatusOr<Attribute> ConvertNonFuncAttributeValue(const AttrValue& value,
307 Builder& builder) {
308 switch (value.value_case()) {
309 case AttrValue::kI:
310 return builder.getI64IntegerAttr(value.i());
311 case AttrValue::kS:
312 return builder.getStringAttr(value.s());
313 case AttrValue::kF:
314 return builder.getFloatAttr(builder.getF32Type(), value.f());
315 case AttrValue::kB:
316 return builder.getBoolAttr(value.b());
317 case AttrValue::kType: {
318 Type type;
319 TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder, &type));
320 return TypeAttr::get(type);
321 }
322 case AttrValue::kShape:
323 return ConvertTensorShapeProto(value.shape(), builder.getContext());
324 case AttrValue::kTensor:
325 return ConvertTensorProto(value.tensor(), builder);
326 case AttrValue::kList: {
327 absl::InlinedVector<Attribute, 8> attrs;
328 for (const auto& item : value.list().i())
329 attrs.push_back(builder.getI64IntegerAttr(item));
330 for (const auto& item : value.list().s())
331 attrs.push_back(builder.getStringAttr(item));
332 for (const auto& item : value.list().f())
333 attrs.push_back(builder.getFloatAttr(builder.getF32Type(), item));
334 for (const auto& item : value.list().b())
335 attrs.push_back(builder.getBoolAttr(item));
336 for (const auto& item : value.list().type()) {
337 Type type;
338 TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), builder, &type));
339 attrs.push_back(TypeAttr::get(type));
340 }
341 for (const auto& item : value.list().shape()) {
342 TF_ASSIGN_OR_RETURN(
343 auto attr, ConvertTensorShapeProto(item, builder.getContext()));
344 attrs.push_back(attr);
345 }
346 for (const auto& item : value.list().tensor()) {
347 TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item, builder));
348 attrs.push_back(attr);
349 }
350 for (const auto& func_attr : value.list().func()) {
351 NamedAttrList subattrs;
352 for (const auto& subattr : func_attr.attr()) {
353 TF_ASSIGN_OR_RETURN(auto attr,
354 ConvertAttributeValue(subattr.second, builder));
355 if (subattr.first.empty())
356 return InvalidArgument("empty func_attr name");
357 subattrs.push_back(builder.getNamedAttr(subattr.first, attr));
358 }
359 attrs.push_back(FuncAttr::get(builder.getContext(), func_attr.name(),
360 builder.getDictionaryAttr(subattrs)));
361 }
362 return builder.getArrayAttr(
363 llvm::makeArrayRef(attrs.begin(), attrs.end()));
364 }
365 case AttrValue::VALUE_NOT_SET:
366 return builder.getUnitAttr();
367 case AttrValue::kPlaceholder:
368 return PlaceholderAttr::get(builder.getContext(), value.placeholder());
369 default:
370 return tensorflow::errors::Unimplemented(
371 absl::StrCat("Attribute ", value.DebugString()));
372 }
373 }
374
ConvertAttributeValue(const AttrValue & value,Builder & builder)375 StatusOr<Attribute> ConvertAttributeValue(const AttrValue& value,
376 Builder& builder) {
377 switch (value.value_case()) {
378 case AttrValue::kFunc: {
379 NamedAttrList attrs;
380 for (const auto& func_attr : value.func().attr()) {
381 if (func_attr.first.empty()) return InvalidArgument("empty attr name");
382 TF_ASSIGN_OR_RETURN(auto attr,
383 ConvertAttributeValue(func_attr.second, builder));
384 attrs.push_back(builder.getNamedAttr(func_attr.first, attr));
385 }
386 auto func_attrs = builder.getDictionaryAttr(attrs);
387 return FuncAttr::get(builder.getContext(), value.func().name(),
388 func_attrs);
389 }
390 default:
391 return ConvertNonFuncAttributeValue(value, builder);
392 }
393 }
394
ConvertAttribute(const tensorflow::FullTypeDef & full_type,Builder & builder)395 StatusOr<tf_type::FullTypeAttr> ConvertAttribute(
396 const tensorflow::FullTypeDef& full_type, Builder& builder) {
397 using FullTypeAttr = ::mlir::tf_type::FullTypeAttr;
398
399 SmallVector<FullTypeAttr> args;
400 for (const tensorflow::FullTypeDef& it : full_type.args()) {
401 TF_ASSIGN_OR_RETURN(FullTypeAttr arg, ConvertAttribute(it, builder));
402 args.push_back(arg);
403 }
404
405 Attribute attr;
406 switch (full_type.attr_case()) {
407 case tensorflow::FullTypeDef::AttrCase::kS:
408 attr = builder.getStringAttr(full_type.s());
409 break;
410 case tensorflow::FullTypeDef::AttrCase::kI:
411 attr = builder.getI64IntegerAttr(full_type.i());
412 break;
413 case tensorflow::FullTypeDef::ATTR_NOT_SET:
414 break;
415 default:
416 return InvalidArgument("Unsupported attr kind in FullType");
417 }
418
419 return FullTypeAttr::get(builder.getContext(), full_type.type_id(), args,
420 attr);
421 }
422
ConvertAttribute(tf_type::FullTypeAttr full_type)423 StatusOr<tensorflow::FullTypeDef> ConvertAttribute(
424 tf_type::FullTypeAttr full_type) {
425 using FullTypeDef = tensorflow::FullTypeDef;
426
427 FullTypeDef ret;
428 for (tf_type::FullTypeAttr it : full_type.getArgs()) {
429 TF_ASSIGN_OR_RETURN(*ret.add_args(), ConvertAttribute(it));
430 }
431
432 if (full_type.getAttr()) {
433 bool converted = llvm::TypeSwitch<Attribute, bool>(full_type.getAttr())
434 .Case<StringAttr>([&](StringAttr sattr) {
435 ret.set_s(sattr.str());
436 return true;
437 })
438 .Case<IntegerAttr>([&](IntegerAttr iattr) {
439 ret.set_i(iattr.getInt());
440 return true;
441 })
442 .Default([&](Attribute attr) { return false; });
443 if (!converted)
444 return InvalidArgument("Unsupported attr kind in FullType:",
445 mlir::debugString(full_type.getAttr()));
446 }
447
448 ret.set_type_id(static_cast<tensorflow::FullTypeId>(full_type.getTypeId()));
449
450 return ret;
451 }
452
ConvertHandleData(Builder builder,const tensorflow::protobuf::RepeatedPtrField<tensorflow::ResourceHandleProto_DtypeAndShape> & handle_data)453 StatusOr<ArrayAttr> ConvertHandleData(
454 Builder builder,
455 const tensorflow::protobuf::RepeatedPtrField<
456 tensorflow::ResourceHandleProto_DtypeAndShape>& handle_data) {
457 SmallVector<Attribute> dtype_and_shape;
458 for (const auto& handle : handle_data) {
459 if (handle.dtype() == tensorflow::DT_INVALID)
460 return InvalidArgument("Invalid dtype for handle_data");
461 Type dtype;
462 TF_RETURN_IF_ERROR(ConvertDataType(handle.dtype(), builder, &dtype));
463 TF_ASSIGN_OR_RETURN(
464 ShapeAttr shape,
465 ConvertTensorShapeProto(handle.shape(), builder.getContext()));
466 TensorType handle_type;
467 if (shape.hasRank()) {
468 handle_type = RankedTensorType::get(shape.getShape(), dtype);
469 } else {
470 handle_type = UnrankedTensorType::get(dtype);
471 }
472 dtype_and_shape.push_back(TypeAttr::get(handle_type));
473 }
474 return builder.getArrayAttr(dtype_and_shape);
475 }
476
ConvertHandleData(ArrayAttr handle_data_arr,tensorflow::OpDef::ArgDef * arg)477 Status ConvertHandleData(ArrayAttr handle_data_arr,
478 tensorflow::OpDef::ArgDef* arg) {
479 if (!handle_data_arr) return {};
480 for (auto handle_data_attr : handle_data_arr.getAsRange<TypeAttr>()) {
481 TensorType handle_type = handle_data_attr.getValue().dyn_cast<TensorType>();
482 if (!handle_type) {
483 return InvalidArgument("Expected an array of tensor types, but got ",
484 debugString(handle_data_arr));
485 }
486 auto* handle_data = arg->add_handle_data();
487 if (handle_type.hasRank()) {
488 ConvertToTensorShapeProto(handle_type.getShape(),
489 handle_data->mutable_shape());
490 } else {
491 handle_data->mutable_shape()->set_unknown_rank(true);
492 }
493 DataType dtype;
494 TF_RETURN_IF_ERROR(ConvertToDataType(handle_type.getElementType(), &dtype));
495 handle_data->set_dtype(dtype);
496 }
497 return {};
498 }
499
500 } // namespace tfg
501 } // namespace mlir
502