xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/full_type_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/full_type_util.h"
17 
18 #include <algorithm>
19 #include <string>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/full_type.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/op_def.pb.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/hash.h"
30 #include "tensorflow/core/platform/statusor.h"
31 #include "tensorflow/core/protobuf/error_codes.pb.h"
32 
33 namespace tensorflow {
34 
35 namespace full_type {
36 
NoOp()37 OpTypeConstructor NoOp() {
38   return nullptr;
39 }
40 
NoOutputs()41 OpTypeConstructor NoOutputs() {
42   return [](OpDef* op_def) {
43     op_def->mutable_output_arg();
44     return OkStatus();
45   };
46 }
47 
Nullary(FullTypeId t)48 OpTypeConstructor Nullary(FullTypeId t) {
49   return [t](OpDef* op_def) {
50     FullTypeDef* tdef =
51         op_def->mutable_output_arg(0)->mutable_experimental_full_type();
52     tdef->set_type_id(t);
53     return OkStatus();
54   };
55 }
56 
Unary(FullTypeId t,const string & var_name)57 OpTypeConstructor Unary(FullTypeId t, const string& var_name) {
58   return [t, var_name](OpDef* op_def) {
59     FullTypeDef* tdef =
60         op_def->mutable_output_arg(0)->mutable_experimental_full_type();
61     tdef->set_type_id(t);
62 
63     FullTypeDef* arg = tdef->add_args();
64     arg->set_type_id(TFT_VAR);
65     arg->set_s(var_name);
66 
67     return OkStatus();
68   };
69 }
70 
UnaryGeneric(FullTypeId t)71 OpTypeConstructor UnaryGeneric(FullTypeId t) {
72   return [t](OpDef* op_def) {
73     FullTypeDef* tdef =
74         op_def->mutable_output_arg(0)->mutable_experimental_full_type();
75     tdef->set_type_id(t);
76 
77     FullTypeDef* arg = tdef->add_args();
78     arg->set_type_id(TFT_ANY);
79 
80     return OkStatus();
81   };
82 }
83 
UnaryTensorContainer(FullTypeId t,FullTypeId dtype)84 OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype) {
85   return [t, dtype](OpDef* op_def) {
86     FullTypeDef* tdef =
87         op_def->mutable_output_arg(0)->mutable_experimental_full_type();
88     tdef->set_type_id(t);
89 
90     FullTypeDef* arg = tdef->add_args();
91     arg->set_type_id(TFT_TENSOR);
92     FullTypeDef* targ = arg->add_args();
93     targ->set_type_id(dtype);
94 
95     return OkStatus();
96   };
97 }
98 
UnaryTensorContainer(FullTypeId t,const string & var_name)99 OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name) {
100   return [t, var_name](OpDef* op_def) {
101     FullTypeDef* tdef =
102         op_def->mutable_output_arg(0)->mutable_experimental_full_type();
103     tdef->set_type_id(t);
104 
105     FullTypeDef* targ = tdef->add_args();
106     targ->set_type_id(TFT_TENSOR);
107     FullTypeDef* varg = targ->add_args();
108     varg->set_type_id(TFT_VAR);
109     varg->set_s(var_name);
110 
111     return OkStatus();
112   };
113 }
114 
VariadicTensorContainer(FullTypeId t,const string & var_name)115 OpTypeConstructor VariadicTensorContainer(FullTypeId t,
116                                           const string& var_name) {
117   return [t, var_name](OpDef* op_def) {
118     FullTypeDef* tdef =
119         op_def->mutable_output_arg(0)->mutable_experimental_full_type();
120     tdef->set_type_id(t);
121 
122     FullTypeDef* for_each = tdef->add_args();
123     for_each->set_type_id(TFT_FOR_EACH);
124     for_each->add_args()->set_type_id(TFT_PRODUCT);
125 
126     FullTypeDef* tpl = for_each->add_args();
127     tpl->set_type_id(TFT_TENSOR);
128     FullTypeDef* targ = tpl->add_args();
129     targ->set_type_id(TFT_VAR);
130     targ->set_s(var_name);
131 
132     FullTypeDef* tvar = for_each->add_args();
133     tvar->set_type_id(TFT_VAR);
134     tvar->set_s(var_name);
135 
136     return OkStatus();
137   };
138 }
139 
140 namespace {
141 
142 typedef absl::flat_hash_map<StringPiece, const AttrValue*> AttrMap;
143 
144 inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t);
145 
SubstituteVar(AttrMap & attrs,FullTypeDef & t)146 Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) {
147   DCHECK_EQ(t.args_size(), 0);
148 
149   StringPiece var_name = t.s();
150   if (!attrs.contains(var_name)) {
151     return Status(
152         error::INVALID_ARGUMENT,
153         absl::StrCat("could not find an attribute for key '", var_name, "'"));
154   }
155   const AttrValue* attr = attrs.at(var_name);
156 
157   const auto attr_type = attr->value_case();
158   if (attr_type == AttrValue::kType) {
159     map_dtype_to_tensor(attr->type(), t);
160   } else if (attr_type == AttrValue::kList) {
161     const auto& attr_list = attr->list();
162     if (attr_list.type_size() != 1) {
163       return Status(error::UNIMPLEMENTED,
164                     absl::StrCat("lists or other than one type element\n",
165                                  attr_list.DebugString(), "\nkey=", var_name));
166     }
167     map_dtype_to_tensor(attr_list.type(0), t);
168   } else {
169     return Status(error::UNIMPLEMENTED,
170                   absl::StrCat("unsupported attribute type ",
171                                attr->DebugString(), " for name ", var_name));
172   }
173   t.clear_s();
174   return OkStatus();
175 }
176 
SubstituteForEach(AttrMap & attrs,FullTypeDef & t)177 Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) {
178   if (t.args_size() != 3) {
179     return Status(error::INVALID_ARGUMENT,
180                   absl::StrCat("illegal FOR_EACH type, expected 3 args, got ",
181                                t.args_size()));
182   }
183 
184   const auto& cont = t.args(0);
185   const auto& tmpl = t.args(1);
186   const auto& t_var = t.args(2);
187 
188   StringPiece var_name = t_var.s();
189   if (!attrs.contains(var_name)) {
190     return Status(
191         error::INVALID_ARGUMENT,
192         absl::StrCat("could not find an attribute for key '", var_name, "'"));
193   }
194   const AttrValue* attr = attrs.at(var_name);
195 
196   FullTypeDef result;
197   result.set_type_id(cont.type_id());
198 
199   const auto attr_type = attr->value_case();
200   if (attr_type == AttrValue::kType) {
201     FullTypeDef* target = result.add_args();
202     *target = tmpl;
203     TF_RETURN_WITH_CONTEXT_IF_ERROR(
204         SubstituteFromAttrs(attrs, *target), "while substituting '", var_name,
205         "' from\n", attr->DebugString(), "\ninto ", target->DebugString());
206 
207   } else if (attr_type == AttrValue::kList) {
208     const auto& attr_list = attr->list();
209     int tsize = attr_list.type_size();
210     if (tsize == 0) {
211       return Status(error::UNIMPLEMENTED,
212                     absl::StrCat("unsupported list attribute type\n",
213                                  attr_list.DebugString(), "\nkey=", var_name));
214     }
215     AttrValue replacement;
216     attrs[var_name] = &replacement;
217     for (int i = 0; i < tsize; i++) {
218       replacement.set_type(attr_list.type(i));
219       FullTypeDef* target = result.add_args();
220       *target = tmpl;
221       TF_RETURN_WITH_CONTEXT_IF_ERROR(SubstituteFromAttrs(attrs, *target),
222                                       "while substituting '", var_name,
223                                       "' from\n", attr->DebugString(), "\n[", i,
224                                       "] into\n", target->DebugString());
225     }
226     // In case of error, it's ok for the attributes map to remain in an invalid
227     // state.
228     attrs[var_name] = attr;
229 
230   } else {
231     return Status(error::UNIMPLEMENTED,
232                   absl::StrCat("unsupported attribute type\n",
233                                attr->DebugString(), "\nfor name ", var_name));
234   }
235   t = result;
236   return OkStatus();
237 }
238 
SubstituteGeneric(AttrMap & attrs,FullTypeDef & t)239 Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) {
240   int nargs = t.args_size();
241   for (int j = 0; j < nargs; j++) {
242     FullTypeDef* arg_t = t.mutable_args(j);
243     TF_RETURN_WITH_CONTEXT_IF_ERROR(SubstituteFromAttrs(attrs, *arg_t),
244                                     "while substituting arg ", j, ": ",
245                                     arg_t->DebugString());
246 
247     // Special case for DT_VARIANT tensors. We leave those unset to avoid even
248     // more special casing downstream.
249     if (arg_t->type_id() == TFT_TENSOR && arg_t->args_size() &&
250         arg_t->args(0).type_id() == TFT_LEGACY_VARIANT) {
251       t.clear_args();
252       break;
253     }
254   }
255   return OkStatus();
256 }
257 
SubstituteFromAttrs(AttrMap & attrs,FullTypeDef & t)258 inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) {
259   // Resolve dependent types. The convention for op registrations is to use
260   // attributes as type variables.
261   // See https://www.tensorflow.org/guide/create_op#type_polymorphism.
262   // Once the op signature can be defined entirely in FullType, this
263   // convention can be deprecated.
264   //
265   // Note: While this code performs some basic verifications, it generally
266   // assumes consistent op defs and attributes. If more complete
267   // verifications are needed, they should be done by separately, and in a
268   // way that can be reused for type inference.
269   switch (t.type_id()) {
270     case TFT_VAR:
271       return SubstituteVar(attrs, t);
272 
273     case TFT_FOR_EACH:
274       return SubstituteForEach(attrs, t);
275 
276     default:
277       return SubstituteGeneric(attrs, t);
278   }
279   return OkStatus();
280 }
281 
282 }  // namespace
283 
SpecializeType(const AttrSlice & attrs,const OpDef & op_def,FullTypeDef & target)284 Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def,
285                       FullTypeDef& target) {
286   target.Clear();
287   target.set_type_id(TFT_PRODUCT);
288 
289   AttrMap map;
290   for (const auto& attr : attrs) {
291     map.emplace(attr.first, &attr.second);
292   }
293 
294   int nargs = op_def.output_arg_size();
295   for (int i = 0; i < nargs; i++) {
296     auto& t = *(target.add_args());
297     t = op_def.output_arg(i).experimental_full_type();
298     TF_RETURN_WITH_CONTEXT_IF_ERROR(
299         SubstituteFromAttrs(map, t), "while expanding vars of\n",
300         t.DebugString(), "\nfrom\n", attrs.SummarizeNode());
301   }
302 
303   return OkStatus();
304 }
305 
// Returns the i-th type argument of `t`. When `t` has no such argument
// (including a negative `i`), returns a default-constructed FullTypeDef,
// whose type_id is TFT_UNSET.
const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i) {
  // Heap-allocated and never freed, so the returned reference stays valid for
  // the remainder of the program.
  static const FullTypeDef* const unset_type = new FullTypeDef();

  // Guard against negative indices too; t.args(i) does not range-check them.
  if (i >= 0 && i < t.args_size()) {
    return t.args(i);
  }
  return *unset_type;
}
317 
// Returns the i-th type argument of `t`, reading TFT_UNSET as TFT_ANY. When
// `t` has no such argument (including a negative `i`), or the argument is
// TFT_UNSET, returns a shared FullTypeDef whose type_id is TFT_ANY.
const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i) {
  // Heap-allocated and never freed, so the returned reference stays valid for
  // the remainder of the program.
  static const FullTypeDef* const any_type = [] {
    auto* any = new FullTypeDef();
    any->set_type_id(TFT_ANY);
    return any;
  }();

  // Guard against negative indices too; t.args(i) does not range-check them.
  if (i < 0 || i >= t.args_size()) {
    return *any_type;
  }
  const FullTypeDef& arg = t.args(i);
  return arg.type_id() == TFT_UNSET ? *any_type : arg;
}
334 
IsEqual(const FullTypeDef & lhs,const FullTypeDef & rhs)335 bool IsEqual(const FullTypeDef& lhs, const FullTypeDef& rhs) {
336   if (lhs.type_id() != rhs.type_id()) {
337     return false;
338   }
339   const auto& lhs_s = lhs.s();
340   const auto& rhs_s = rhs.s();
341   if (lhs_s.empty()) {
342     if (!rhs_s.empty()) {
343       return false;
344     }
345   } else if (rhs_s != lhs_s) {
346     return false;
347   }
348   for (int i = 0; i < std::max(lhs.args_size(), rhs.args_size()); i++) {
349     const FullTypeDef& lhs_arg = GetArgDefaultAny(lhs, i);
350     const FullTypeDef& rhs_arg = GetArgDefaultAny(rhs, i);
351 
352     if (!IsEqual(lhs_arg, rhs_arg)) {
353       return false;
354     }
355   }
356   return true;
357 }
358 
Hash(const FullTypeDef & arg)359 uint64_t Hash(const FullTypeDef& arg) {
360   // Following style of IsEqual above and walking across FullTypeDef.
361   uint64_t val = Hash64Combine(arg.type_id(), 0);
362 
363   const auto& arg_s = arg.s();
364   val = Hash64Combine(val, Hash64(arg_s));
365   for (int i = 0, e = arg.args_size(); i < e; ++i) {
366     const FullTypeDef& arg_arg = GetArgDefaultAny(arg, i);
367     val = Hash64Combine(val, Hash(arg_arg));
368   }
369 
370   return val;
371 }
372 
IsSubtype(const FullTypeDef & lhs,const FullTypeDef & rhs,bool covariant)373 bool IsSubtype(const FullTypeDef& lhs, const FullTypeDef& rhs, bool covariant) {
374   // Rule: ANY is a supertype of all types.
375   if (rhs.type_id() == TFT_ANY) {
376     return true;
377   }
378   // Compatibility rule: UNSET is treated as ANY for the purpose of subtyping.
379   if (rhs.type_id() == TFT_UNSET) {
380     return true;
381   }
382   // Compatibility rule: TENSOR[LEGACY_VARIANT] is treated as ANY for the
383   // purpose of subtyping.
384   if ((rhs.type_id() == TFT_TENSOR) &&
385       (GetArgDefaultUnset(rhs, 0).type_id() == TFT_LEGACY_VARIANT)) {
386     return true;
387   }
388   // Rule: encodings are subtypes of the encoding type.
389   if (lhs.type_id() == TFT_ENCODED) {
390     return IsSubtype(GetArgDefaultAny(lhs, 1), rhs, true);
391   }
392 
393   // Default rule: type IDs must match.
394   if (lhs.type_id() != rhs.type_id()) {
395     return false;
396   }
397 
398   // Arguments must be subtypes of one another.
399   for (int i = 0; i < std::max(lhs.args_size(), rhs.args_size()); i++) {
400     const FullTypeDef& lhs_arg = GetArgDefaultAny(lhs, i);
401     const FullTypeDef& rhs_arg = GetArgDefaultAny(rhs, i);
402 
403     if (covariant) {
404       if (!IsSubtype(lhs_arg, rhs_arg)) {
405         return false;
406       }
407     } else {
408       if (!IsSubtype(rhs_arg, lhs_arg)) {
409         return false;
410       }
411     }
412   }
413 
414   // Invariant: type IDs are equal, and all args are subtype of one another.
415   return true;
416 }
417 
418 }  // namespace full_type
419 
420 }  // namespace tensorflow
421