xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/type.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Dict.h>
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/dynamic_type.h>
4 #include <ATen/core/enum_type.h>
5 #include <ATen/core/function.h>
6 #include <ATen/core/function_schema.h>
7 #include <ATen/core/grad_mode.h>
8 #include <ATen/core/jit_type.h>
9 #include <c10/macros/Macros.h>
10 #include <c10/util/flat_hash_map.h>
11 #include <c10/util/irange.h>
12 #include <array>
13 #include <iostream>
14 #include <utility>
15 
16 namespace std {
17 template<>
18 struct hash<std::tuple<std::string, c10::TypePtr, c10::TypePtr>> {
operator ()std::hash19   size_t operator()(std::tuple<std::string, c10::TypePtr, c10::TypePtr> const& t) const {
20     // This hashing is all hidden behind a static initializer so it
21     // doesn't have to be optimal
22     auto hash = std::hash<std::string>()(std::get<0>(t));
23     hash = at::hash_combine(hash, std::hash<c10::TypePtr>()(std::get<1>(t)));
24     hash = at::hash_combine(hash, std::hash<c10::TypePtr>()(std::get<2>(t)));
25     return hash;
26   }
27 };
28 template<>
29 struct hash<std::tuple<std::string, c10::TypePtr>> {
operator ()std::hash30   size_t operator()(std::tuple<std::string, c10::TypePtr> const& t) const {
31     auto hash = std::hash<std::string>()(std::get<0>(t));
32     hash = at::hash_combine(hash, std::hash<c10::TypePtr>()(std::get<1>(t)));
33     return hash;
34   }
35 };
36 } // namespace std
37 
38 namespace c10 {
39 
40 static_assert(
41     sizeof(SingletonOrSharedTypePtr<void>) == sizeof(std::shared_ptr<void>) && sizeof(std::shared_ptr<void>) == 2 * sizeof(void*),
42     "std::shared_ptr has an unexpected representation on this platform!");
43 static_assert(
44     std::is_same_v<decltype(getTypePtr<std::tuple<int64_t, int64_t>>()), const TupleTypePtr&>,
45     "getTypePtr<std::tuple<int64_t, int64_t>> not returning const ref!");
46 
type_verbosity()47 TypeVerbosity type_verbosity() {
48   static const char* c_verbosity = std::getenv("PYTORCH_JIT_TYPE_VERBOSITY");
49   static TypeVerbosity verbosity = c_verbosity ?
50     static_cast<TypeVerbosity>(std::stoi(c_verbosity)) : TypeVerbosity::Default;
51   return verbosity;
52 }
53 
operator <<(std::ostream & out,const Type & t)54 std::ostream& operator<<(std::ostream & out, const Type & t) {
55   if (auto value = t.cast<TensorType>()) {
56     if  (value->scalarType().has_value()) {
57       out << toString(*value->scalarType());
58       if (!value->sizes().size().has_value()) {
59         out << "Tensor";
60       }
61     } else {
62       out << "Tensor";
63     }
64     if (auto ndim = value->sizes().size()) {
65       bool has_valid_strides_info = *ndim > 0 &&
66           value->strides().isComplete() && value->strides().size() == ndim;
67 
68       out << "(";
69       size_t i = 0;
70       bool symbolic = type_verbosity() == TypeVerbosity::Symbolic;
71       for (i = 0; i < *ndim; ++i) {
72         if (i > 0) {
73           out << ", ";
74         }
75         if (auto s = value->sizes()[i]) {
76           out << *s;
77         } else if (symbolic) {
78           out << value->symbolic_sizes().at(i);
79         } else {
80           out << "*";
81         }
82       }
83       if (has_valid_strides_info &&
84           type_verbosity() >= TypeVerbosity::TypeAndStride) {
85         out << ", strides=[";
86         for (size_t i = 0; i < *ndim; ++i) {
87           if (i > 0) {
88             out << ", ";
89           }
90           out << *value->strides()[i];
91         }
92         out << "]";
93       }
94       if (type_verbosity() >= TypeVerbosity::Full) {
95         if (value->requiresGrad()) {
96           if (i++ > 0) {
97             out << ", ";
98           }
99           out << "requires_grad=" << *value->requiresGrad();
100         }
101         if (value->device()) {
102           if (i++ > 0) {
103             out << ", ";
104           }
105           out << "device=" << *value->device();
106         }
107       }
108       out << ")";
109     } else {
110       if (type_verbosity() >= TypeVerbosity::Full) {
111         size_t i = 0;
112         if (value->requiresGrad()) {
113           out << "("
114               << "requires_grad=" << *value->requiresGrad();
115           i++;
116         }
117         if (value->device()) {
118           out << ((i++ > 0) ? ", " : "(") << "device=" << *value->device();
119         }
120         if (i > 0) {
121           out << ")";
122         }
123       }
124     }
125 
126     if (value->undefined() && *value->undefined()) {
127       out << "[Undefined]";
128     }
129   } else if(t.kind() == TypeKind::ListType) {
130     auto prim = t.castRaw<ListType>()->getElementType();
131     out << *prim << "[]";
132   } else if (t.kind() == TypeKind::OptionalType) {
133     auto prim = t.castRaw<OptionalType>()->getElementType();
134     out << *prim << "?";
135   } else if(t.kind() == TypeKind::FutureType) {
136     auto elem = t.castRaw<FutureType>()->getElementType();
137     out << "Future[" << *elem << "]";
138   } else if(t.kind() == TypeKind::RRefType) {
139     auto elem = t.castRaw<RRefType>()->getElementType();
140     out << "RRef[" << *elem << "]";
141   } else if(auto tup = t.cast<TupleType>()) {
142     if (tup->schema()) {
143       out << "NamedTuple";
144     }
145     out << "(";
146     for(size_t i = 0; i < tup->elements().size(); ++i) {
147       if(i > 0)
148         out << ", ";
149       if (tup->schema()) {
150         auto arg = tup->schema()->arguments()[i];
151         out << arg.name() << " : ";
152         out << *(tup->elements()[i]);
153         if (arg.default_value()) {
154           out << " = " << *arg.default_value();
155         }
156       }
157       else {
158         out << *(tup->elements()[i]);
159       }
160     }
161     out << ")";
162   } else if (t.kind() == TypeKind::FunctionType) {
163     out << "Function";
164   } else {
165      out << t.str();
166   }
167   return out;
168 }
169 
get()170 AnyTypePtr AnyType::get() {
171   static AnyTypePtr value(new AnyType());
172   return value;
173 }
174 
get()175 NumberTypePtr NumberType::get() {
176   static NumberTypePtr value(new NumberType());
177   return value;
178 }
get()179 IntTypePtr IntType::get() {
180   static IntTypePtr value(new IntType());
181   return value;
182 }
get()183 FloatTypePtr FloatType::get() {
184   static FloatTypePtr value(new FloatType());
185   return value;
186 }
get()187 ComplexTypePtr ComplexType::get() {
188   static ComplexTypePtr value(new ComplexType());
189   return value;
190 }
get()191 BoolTypePtr BoolType::get() {
192   static BoolTypePtr value(new BoolType());
193   return value;
194 }
get()195 StorageTypePtr StorageType::get() {
196   static StorageTypePtr value(new StorageType());
197   return value;
198 }
get()199 NoneTypePtr NoneType::get() {
200   static NoneTypePtr value(new NoneType());
201   return value;
202 }
get()203 GeneratorTypePtr GeneratorType::get() {
204   static GeneratorTypePtr value(new GeneratorType());
205   return value;
206 }
get()207 QuantizerTypePtr QuantizerType::get() {
208   static QuantizerTypePtr value(new QuantizerType());
209   return value;
210 }
get()211 QSchemeTypePtr QSchemeType::get() {
212   static QSchemeTypePtr value(new QSchemeType());
213   return value;
214 }
get()215 StringTypePtr StringType::get() {
216   static StringTypePtr value(new StringType());
217   return value;
218 }
get()219 DeviceObjTypePtr DeviceObjType::get() {
220   static DeviceObjTypePtr value(new DeviceObjType());
221   return value;
222 }
get()223 StreamObjTypePtr StreamObjType::get() {
224   static StreamObjTypePtr value(new StreamObjType());
225   return value;
226 }
get()227 ScalarTypeTypePtr ScalarTypeType::get() {
228 static ScalarTypeTypePtr value(new ScalarTypeType());
229 return value;
230 }
get()231 LayoutTypePtr LayoutType::get() {
232 static LayoutTypePtr value(new LayoutType());
233 return value;
234 }
get()235 MemoryFormatTypePtr MemoryFormatType::get() {
236 static MemoryFormatTypePtr value(new MemoryFormatType());
237 return value;
238 }
get()239 PyObjectTypePtr PyObjectType::get() {
240   static PyObjectTypePtr value(new PyObjectType());
241   return value;
242 }
get()243 CapsuleTypePtr CapsuleType::get() {
244   static CapsuleTypePtr value(new CapsuleType());
245   return value;
246 }
ofInts()247 ListTypePtr ListType::ofInts() {
248   static auto value = ListType::create(IntType::get());
249   return value;
250 }
ofSymInts()251 ListTypePtr ListType::ofSymInts() {
252   static auto value = ListType::create(SymIntType::get());
253   return value;
254 }
ofComplexDoubles()255 ListTypePtr ListType::ofComplexDoubles() {
256   static auto value = ListType::create(ComplexType::get());
257   return value;
258 }
ofFloats()259 ListTypePtr ListType::ofFloats() {
260   static auto value = ListType::create(FloatType::get());
261   return value;
262 }
ofBools()263 ListTypePtr ListType::ofBools() {
264   static auto value = ListType::create(BoolType::get());
265   return value;
266 }
ofStrings()267 ListTypePtr ListType::ofStrings() {
268   static auto value = ListType::create(StringType::get());
269   return value;
270 }
ofNumbers()271 ListTypePtr ListType::ofNumbers() {
272   static auto value = ListType::create(NumberType::get());
273   return value;
274 }
275 
get(TypePtr inner)276 TypePtr OptionalType::get(TypePtr inner) {
277   static ska::flat_hash_map<TypePtr, TypePtr> containerTypePtrs;
278   static std::mutex mutex;
279   // Perf from the lock is ok because this function is guarded behind
280   // a static initializer; it should only be called once per type.
281   std::lock_guard<std::mutex> lock(mutex);
282   if (containerTypePtrs.find(inner) == containerTypePtrs.end()) {
283     TypePtr t = TypeFactory::create<OptionalType>(inner);
284     containerTypePtrs.emplace(inner, std::move(t));
285   }
286   return containerTypePtrs[inner];
287 }
288 
get(const std::string & identifier,TypePtr inner)289 TypePtr ListType::get(const std::string& identifier, TypePtr inner) {
290   static ska::flat_hash_map<std::tuple<std::string, TypePtr>, TypePtr> containerTypePtrs;
291   static std::mutex mutex;
292   // Perf from the lock is ok because this function is guarded behind
293   // a static initializer; it should only be called once per type.
294   auto key = std::make_tuple(identifier, inner);
295   std::lock_guard<std::mutex> lock(mutex);
296   if (containerTypePtrs.find(key) == containerTypePtrs.end()) {
297     TypePtr t = ListType::create(inner);
298     containerTypePtrs.emplace(key, std::move(t));
299   }
300   return containerTypePtrs[key];
301 }
302 
get(const std::string & identifier,TypePtr key,TypePtr value)303 TypePtr DictType::get(const std::string& identifier, TypePtr key, TypePtr value) {
304   static ska::flat_hash_map<std::tuple<std::string, TypePtr, TypePtr>, TypePtr> containerTypePtrs;
305   static std::mutex mutex;
306   // Perf from the lock is ok because this function is guarded behind
307   // a static initializer; it should only be called once per type.
308   auto map_key = std::make_tuple(identifier, key, value);
309   std::lock_guard<std::mutex> lock(mutex);
310   if (containerTypePtrs.find(map_key) == containerTypePtrs.end()) {
311     TypePtr t = DictType::create(std::move(key), std::move(value));
312     containerTypePtrs.emplace(map_key, std::move(t));
313   }
314   return containerTypePtrs[map_key];
315 }
316 
annotation_str_impl(const TypePrinter & printer) const317 std::string DictType::annotation_str_impl(const TypePrinter& printer) const {
318   auto keyAnnotation = getKeyType()->annotation_str(printer);
319   auto valueAnnotation = getValueType()->annotation_str(printer);
320 
321   std::string result;
322   result.reserve(5 /* "Dict[" */ + keyAnnotation.size() + 2 /* ", " */ + valueAnnotation.size() + 1 /* "]" */);
323   result = "Dict[";
324   result += keyAnnotation;
325   result.push_back(',');
326   result.push_back(' ');
327   result += valueAnnotation;
328   result.push_back(']');
329   return result;
330 }
331 
get()332 AnyListTypePtr AnyListType::get() {
333   static AnyListTypePtr value(new AnyListType());
334   return value;
335 }
336 
get()337 AnyTupleTypePtr AnyTupleType::get() {
338   static AnyTupleTypePtr value(new AnyTupleType());
339   return value;
340 }
341 
get()342 AnyClassTypePtr AnyClassType::get() {
343   static AnyClassTypePtr value(new AnyClassType());
344   return value;
345 }
346 
get()347 AnyEnumTypePtr AnyEnumType::get() {
348   static AnyEnumTypePtr value(new AnyEnumType());
349   return value;
350 }
351 
get()352 SymIntTypePtr SymIntType::get() {
353   static SymIntTypePtr value(new SymIntType());
354   return value;
355 }
356 
get()357 SymFloatTypePtr SymFloatType::get() {
358   static SymFloatTypePtr value(new SymFloatType());
359   return value;
360 }
361 
get()362 SymBoolTypePtr SymBoolType::get() {
363   static SymBoolTypePtr value(new SymBoolType());
364   return value;
365 }
366 
unifyTypesImpl(const TypePtr & t1,const TypePtr & t2,bool default_to_union=false,const TypePtr & type_hint=nullptr)367 static std::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_union=false, const TypePtr& type_hint=nullptr) {
368   // check direct subtyping relation
369   if (t1->isSubtypeOf(*t2)) {
370     return t2;
371   } else if (t2->isSubtypeOf(*t1)) {
372     return t1;
373   }
374 
375   // Handle non-container types which do not subtype each other and unify
376   if (t1->kind() == TensorType::Kind && t2->kind() == TensorType::Kind) {
377     return t1->expectRef<TensorType>().merge(t2->expectRef<TensorType>());
378   }
379 
380   if (t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get())) {
381     return OptionalType::create(t2);
382   } else if (t2->isSubtypeOf(*NoneType::get()) && !t1->isSubtypeOf(*NoneType::get())) {
383     return OptionalType::create(t1);
384   }
385 
386   // NB: we do not return NumberType because there is not currently enough
387   // operator support for it
388 
389   // Attempt to unify Complete Tensor Types for immutable type containers
390 
391   // unify(Optional[t1], t2) => Optional[unify(t1, t2)]
392   if (auto opt_t1 = t1->cast<OptionalType>()) {
393     if (auto elem = unifyTypes(opt_t1->getElementType(), t2)) {
394       return OptionalType::create(*std::move(elem));
395     }
396   } else if (auto opt_t2 = t2->cast<OptionalType>()) {
397     if (auto elem = unifyTypes(opt_t2->getElementType(), t1)) {
398       return OptionalType::create(*std::move(elem));
399     }
400   }
401 
402   if (t1->castRaw<TupleType>() && t2->castRaw<TupleType>()) {
403     auto tuple1 = t1->castRaw<TupleType>();
404     auto tuple2 = t2->castRaw<TupleType>();
405     if (tuple1->elements().size() != tuple2->elements().size()) {
406       return std::nullopt;
407     }
408     std::vector<TypePtr> elements;
409     for (size_t i = 0; i < tuple1->elements().size(); i++) {
410       if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_union)) {
411         elements.push_back(*std::move(elem));
412       } else {
413         return std::nullopt;
414       }
415     }
416     return static_cast<TypePtr>(TupleType::create(std::move(elements)));
417   }
418 
419   if (t1->castRaw<FutureType>() && t2->castRaw<FutureType>()) {
420     if (auto elem = unifyTypes(
421             t1->castRaw<FutureType>()->getElementType(),
422             t2->castRaw<FutureType>()->getElementType())) {
423       return FutureType::create(*elem);
424     }
425   }
426 
427   // Check direct subtyping relations again with Unshaped Types,
428   // to handle unification of mutable container types which might contain two different
429   // specialized tensors (ListType / DictType)
430   auto t1_unshaped = unshapedType(t1);
431   auto t2_unshaped = unshapedType(t2);
432 
433   if (t1_unshaped->isSubtypeOf(*t2_unshaped)) {
434     return t2_unshaped;
435   } else if (t2_unshaped->isSubtypeOf(*t1_unshaped)) {
436     return t1_unshaped;
437   }
438 
439   // Check whether or not `type_hint` is a common parent. This case
440   // could occur if we had two class types that had been annotated with
441   // a common interface
442   if (type_hint && t1->isSubtypeOf(*type_hint) && t2->isSubtypeOf(*type_hint)) {
443     return type_hint;
444   }
445 
446   return std::nullopt;
447 }
448 
unifyTypes(const TypePtr & t1,const TypePtr & t2,bool default_to_union,const TypePtr & type_hint)449 std::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_union, const TypePtr& type_hint) {
450   auto unified = unifyTypesImpl(t1, t2, default_to_union, type_hint);
451 
452   if (default_to_union && !unified) {
453     return UnionType::create({t1, t2});
454   }
455 
456   return unified;
457 }
458 
unifyTypeList(at::ArrayRef<TypePtr> elements,std::ostream & why_not,bool default_to_union,const TypePtr & type_hint)459 std::optional<TypePtr> unifyTypeList(
460     at::ArrayRef<TypePtr> elements,
461     std::ostream& why_not,
462     bool default_to_union,
463     const TypePtr& type_hint) {
464   if (elements.empty()) {
465     why_not << "Cannot get unified type from empty list";
466     return std::nullopt;
467   }
468 
469   TypePtr ret_type = elements.at(0);
470   for (size_t i = 1; i < elements.size() && ret_type; ++i) {
471     std::optional<TypePtr> maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_union, type_hint);
472     if (!maybe_unified) {
473       why_not << "Could not unify type list since element " << i << " of type "
474               << elements.at(i)->repr_str()
475               << " did not match the types before it ("
476               << ret_type->repr_str() << ")";
477       return std::nullopt;
478     }
479     ret_type = *maybe_unified;
480   }
481 
482   return ret_type;
483 }
484 
485 // NOTE: This function actually does need to take const TypePtr&
486 // because it sometimes calls unifyTypes, which needs const TypePtr&.
matchTypeVariables(const TypePtr & formal,const TypePtr & actual,TypeEnv & type_env)487 MatchTypeReturn matchTypeVariables(
488     const TypePtr& formal,
489     const TypePtr& actual,
490     TypeEnv& type_env) {
491   if (!formal->hasFreeVariables()) {
492     if (auto dyn = formal->castRaw<c10::DynamicType>()) {
493       return matchTypeVariables(dyn->fallback(), actual, type_env);
494     }
495     return MatchTypeReturn::Success();
496   }
497 
498   if (auto vt = formal->castRaw<VarType>()) {
499     auto it = type_env.find(vt->name());
500     if (it == type_env.end()) {
501       type_env[vt->name()] = actual;
502       return MatchTypeReturn::Success();
503     } else if (unifyTypes(it->second, actual)) {
504       // note: unifyTypes allows subtyping in either direction, so actual
505       // may be a supertype of the current binding. we're not responsible
506       // for reporting the error, only for keeping type_env stable
507       return MatchTypeReturn::Success();
508     }
509     std::stringstream ss;
510     ss << "Type variable '" << vt->name() << "' previously matched to type "
511        << it->second->repr_str() << " is matched to type "
512        << actual->repr_str();
513     return ss.str();
514   } else if (auto lt_formal = formal->castRaw<ListType>()) {
515     if (auto lt_actual = actual->castRaw<ListType>()) {
516       auto innerMatch = matchTypeVariables(
517           lt_formal->getElementType(), lt_actual->getElementType(), type_env);
518       if (!innerMatch.success()) {
519         // propagate the errMsg onward
520         return innerMatch;
521       }
522       return MatchTypeReturn::Success();
523     } else if (auto tup_type = actual->castRaw<TupleType>()) {
524       std::stringstream ss;
525       auto maybe_tuple_unified = unifyTypeList(tup_type->elements(), ss);
526       if (maybe_tuple_unified) {
527         return matchTypeVariables(
528             lt_formal->getElementType(), *maybe_tuple_unified, type_env);
529       }
530     }
531 
532     std::stringstream ss;
533     ss << "Cannot match " << lt_formal->repr_str() << " to "
534        << actual->repr_str();
535     return ss.str();
536   } else if (auto tp_formal = formal->castRaw<TupleType>()) {
537     if (auto tp_actual = actual->castRaw<TupleType>()) {
538       if (tp_formal->elements().size() != tp_actual->elements().size()) {
539         return MatchTypeReturn("Cannot match tuples of mismatched size");
540       }
541       for (size_t i = 0; i < tp_formal->elements().size(); ++i) {
542         auto result = matchTypeVariables(
543             tp_formal->elements()[i], tp_actual->elements()[i], type_env);
544         if (!result.success()) {
545           return result;
546         }
547       }
548       return MatchTypeReturn::Success();
549     } else {
550       std::stringstream ss;
551       ss << "Cannot match a tuple to " << actual->repr_str();
552       return MatchTypeReturn(ss.str());
553     }
554   } else if (auto lt_formal = formal->castRaw<FutureType>()) {
555     if (auto lt_actual = actual->castRaw<FutureType>()) {
556       auto innerMatch = matchTypeVariables(
557           lt_formal->getElementType(), lt_actual->getElementType(), type_env);
558       if (!innerMatch.success()) {
559         return innerMatch;
560       }
561       return MatchTypeReturn::Success();
562     } else {
563       std::stringstream ss;
564       ss << "Cannot match a future to " << actual->repr_str();
565       return ss.str();
566     }
567   } else if (auto lt_formal = formal->castRaw<AwaitType>()) {
568     if (auto lt_actual = actual->castRaw<AwaitType>()) {
569       auto innerMatch = matchTypeVariables(
570           lt_formal->getElementType(), lt_actual->getElementType(), type_env);
571       if (!innerMatch.success()) {
572         return innerMatch;
573       }
574       return MatchTypeReturn::Success();
575     } else {
576       std::stringstream ss;
577       ss << "Cannot match an await to " << actual->repr_str();
578       return ss.str();
579     }
580   } else if (auto lt_formal = formal->castRaw<RRefType>()) {
581     if (auto lt_actual = actual->castRaw<RRefType>()) {
582       auto innerMatch = matchTypeVariables(
583           lt_formal->getElementType(), lt_actual->getElementType(), type_env);
584       if (!innerMatch.success()) {
585         return innerMatch;
586       }
587       return MatchTypeReturn::Success();
588     } else {
589       std::stringstream ss;
590       ss << "Cannot match a rref to " << actual->repr_str();
591       return ss.str();
592     }
593   } else if (auto opt_formal = formal->castRaw<OptionalType>()) {
594     if (auto opt_actual = actual->castRaw<OptionalType>()) {
595       auto optionedMatch = matchTypeVariables(
596           opt_formal->getElementType(), opt_actual->getElementType(), type_env);
597       if (!optionedMatch.success()) {
598         return optionedMatch;
599       }
600     } else if (!actual->isSubtypeOf(*NoneType::get())) {
601       // If the actual type is a non-optional, allow matching to the formal if
602       // its element type matches the actual.
603       // Don't match None because it is already an optional (but one of
604       // unknown type).
605       return matchTypeVariables(opt_formal->getElementType(), actual, type_env);
606     }
607     // note: if actual was None here we potentially did not fill in the type
608     // variables contained in the formal. It is still a valid match because None
609     // matches Optional[T] later error checking on tryEvalTypeVariables will
610     // report the problem if we never match variables in type T
611     return MatchTypeReturn::Success();
612   } else if (auto dict_formal = formal->castRaw<DictType>()) {
613     if (auto dict_actual = actual->castRaw<DictType>()) {
614       auto key_match = matchTypeVariables(
615           dict_formal->getKeyType(), dict_actual->getKeyType(), type_env);
616       if (!key_match.success()) {
617         return key_match;
618       }
619       auto value_match = matchTypeVariables(
620           dict_formal->getValueType(), dict_actual->getValueType(), type_env);
621       if (!value_match.success()) {
622         return value_match;
623       }
624       return MatchTypeReturn::Success();
625     } else {
626       std::stringstream ss;
627       ss << "Cannot match a dict to " << actual->repr_str();
628       return ss.str();
629     }
630   }
631 
632   AT_ERROR("Unhandled free variable container: ", formal->repr_str());
633 }
634 
635 // change return types like List[List[t]] into List[List[int]]
tryEvalTypeVariables(const TypePtr & type,std::unordered_map<std::string,TypePtr> & type_env)636 TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, std::unordered_map<std::string, TypePtr>& type_env) {
637   if (!type->hasFreeVariables()) {
638     if (auto dyn = type->castRaw<c10::DynamicType>()) {
639       return tryEvalTypeVariables(dyn->fallback(), type_env);
640     }
641     return type;
642   }
643 
644   if (auto vt = type->castRaw<VarType>()) {
645     auto it = type_env.find(vt->name());
646     if (it == type_env.end()) {
647       return nullptr;
648     }
649     return it->second;
650   } else {
651     at::ArrayRef<TypePtr> contained = type->containedTypes();
652     if (contained.empty()) {
653       return type;
654     }
655     std::vector<TypePtr> new_contained;
656     new_contained.reserve(contained.size());
657     for (const TypePtr& t : contained) {
658       TypePtr r = tryEvalTypeVariables(t, type_env);
659       if (!r) {
660         return nullptr;
661       }
662       new_contained.push_back(std::move(r));
663     }
664     return type->withContained(std::move(new_contained));
665   }
666 }
667 
elementTypeCanBeInferredFromMembers(const TypePtr & elem_type)668 TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) {
669   if (elem_type->kind() == UnionType::Kind
670       || elem_type->kind() == OptionalType::Kind
671       || elem_type->kind() == NumberType::Kind) {
672     // Builtin Union types
673     return false;
674   }
675   if (elem_type->kind() == InterfaceType::Kind) {
676     // since classes can be members of multiple interfaces, we cannot
677     // construct which interface the list holds from the members alone
678     return false;
679   }
680   if (elem_type->kind() == AnyType::Kind) {
681     // List of Any can contains heterogenous types
682     return false;
683   }
684   return true;
685 }
686 
typeKindToString(TypeKind kind)687 const char * typeKindToString(TypeKind kind) {
688 #define CASE_TYPE(T) case TypeKind::T: return #T;
689   switch(kind) {
690     C10_FORALL_TYPES(CASE_TYPE)
691   }
692 #undef CASE_TYPE
693   return "";
694 }
695 
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const696 bool Type::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
697   if (rhs.kind() == TypeKind::AnyType || *this == rhs) {
698     return true;
699   }
700   if (auto opt_rhs = rhs.castRaw<OptionalType>()) {
701     return this->isSubtypeOfExt(*opt_rhs->getElementType(), why_not);
702   }
703   if (auto union_rhs = rhs.castRaw<UnionType>()) {
704     // Check if `this` is a subtype of any of the types within the Union
705     return std::any_of(union_rhs->containedTypes().begin(),
706                        union_rhs->containedTypes().end(),
707                        [&](const TypePtr& inner) {
708                          return this->isSubtypeOfExt(*inner, why_not);
709                        });
710   }
711   if (auto dyn = rhs.castRaw<DynamicType>()) {
712     return DynamicType::create(*this)->isSubtypeOf(*dyn);
713   }
714   return false;
715 }
716 
is_module() const717 bool Type::is_module() const {
718   return false;
719 }
720 
createNamed(const std::optional<c10::QualifiedName> & qualName,const std::vector<std::string> & field_names,const std::vector<TypePtr> & field_types)721 TupleTypePtr TupleType::createNamed(
722     const std::optional<c10::QualifiedName>& qualName,
723     const std::vector<std::string>& field_names,
724     const std::vector<TypePtr>& field_types) {
725   std::vector<IValue> empty_defaults;
726   return TupleType::createNamed(qualName, field_names, field_types, empty_defaults);
727 }
728 
createNamed(const std::optional<c10::QualifiedName> & qualName,const std::vector<c10::string_view> & field_names,const std::vector<TypePtr> & field_types)729 TupleTypePtr TupleType::createNamed(
730     const std::optional<c10::QualifiedName>& qualName,
731     const std::vector<c10::string_view>& field_names,
732     const std::vector<TypePtr>& field_types) {
733   std::vector<IValue> empty_defaults;
734   return createWithSpec(qualName, field_names, field_types, empty_defaults);
735 }
736 
createNamed(const std::optional<c10::QualifiedName> & qualName,const std::vector<std::string> & field_names,const std::vector<TypePtr> & field_types,std::vector<IValue> & field_defaults)737 TupleTypePtr TupleType::createNamed(
738     const std::optional<c10::QualifiedName>& qualName,
739     const std::vector<std::string>& field_names,
740     const std::vector<TypePtr>& field_types,
741     std::vector<IValue>& field_defaults) {
742   return createWithSpec(qualName, field_names, field_types, field_defaults);
743 }
744 
745 template <typename S>
createWithSpec(const std::optional<c10::QualifiedName> & qualName,const std::vector<S> & field_names,const std::vector<TypePtr> & field_types,std::vector<IValue> & field_defaults)746 TupleTypePtr TupleType::createWithSpec(const std::optional<c10::QualifiedName>& qualName,
747     const std::vector<S>& field_names,
748     const std::vector<TypePtr>& field_types,
749     std::vector<IValue>& field_defaults) {
750   TORCH_INTERNAL_ASSERT(field_names.size() == field_types.size());
751 
752   std::vector<Argument> arguments;
753   arguments.reserve(field_names.size());
754   auto min_default_idx = field_names.size() - field_defaults.size();
755   for (size_t i = 0; i < field_names.size(); ++i) {
756     if (i < min_default_idx) {
757       Argument arg{
758           /*name=*/std::string{field_names[i]},
759           /*type=*/field_types[i],
760           /*N=*/i};
761       arguments.emplace_back(std::move(arg));
762     }
763     else {
764       size_t j = i - min_default_idx;
765       TORCH_CHECK(field_defaults[j].tagKind() != "Tensor", "Tensors are "
766                   "not supported as default NamedTuple fields. Their "
767                   "mutability could lead to potential memory aliasing "
768                   "problems");
769       Argument arg{
770           /*name=*/std::string{field_names[i]},
771           /*type=*/field_types[i],
772           /*N=*/i,
773           /*default_value=*/field_defaults[j]};
774       arguments.emplace_back(std::move(arg));
775     }
776   }
777 
778   auto schema = std::make_shared<FunctionSchema>(
779       /*name=*/qualName.value_or(c10::QualifiedName()).name(),
780       /*overload_name=*/std::string(""),
781       /*arguments=*/std::move(arguments),
782       /*returns=*/std::vector<Argument>{});
783   return std::shared_ptr<TupleType>(new TupleType(
784       field_types, qualName, std::move(schema))); // NOLINT(modernize-make-shared)
785 }
786 
names() const787 std::optional<std::vector<c10::string_view>> TupleType::names() const {
788   if (!schema_) {
789     return {};
790   }
791   std::vector<c10::string_view> ret;
792   for (const auto& arg : schema_->arguments()) {
793     ret.emplace_back(arg.name());
794   }
795   return ret;
796 }
797 
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const798 bool NoneType::isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const {
799   if (rhs.kind() == OptionalType::Kind) {
800     return true;
801   }
802   return Type::isSubtypeOfExt(rhs, why_not);
803 }
804 
equals(const Type & rhs) const805 bool NumberType::equals(const Type& rhs) const {
806   if (auto union_type = rhs.cast<UnionType>()) {
807     return union_type->containedTypes().size() == 3 && union_type->canHoldType(*NumberType::get());
808   } else {
809     return rhs.kind() == this->kind();
810   }
811 }
812 
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const813 bool NumberType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
814   if (auto union_type = rhs.cast<UnionType>()) {
815     return union_type->canHoldType(*NumberType::get());
816   } else {
817     return Type::isSubtypeOfExt(rhs, why_not);
818   }
819 }
820 
TupleType(std::vector<TypePtr> elements,std::optional<c10::QualifiedName> name,std::shared_ptr<FunctionSchema> schema)821 TupleType::TupleType(
822     std::vector<TypePtr> elements,
823     std::optional<c10::QualifiedName> name,
824     std::shared_ptr<FunctionSchema> schema)
825     : NamedType(TypeKind::TupleType, std::move(name)),
826       elements_(std::move(elements)),
827       has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
828         if (!v) {
829           throw std::runtime_error("Can not create tuple with None type");
830         }
831         return v->hasFreeVariables();
832       })), schema_(std::move(schema)) {
833 
834   if (schema_) {
835     for (const Argument& arg : schema_->arguments()) {
836       checkNoAny(*this, "attribute", arg.name(), arg.type());
837     }
838   }
839 }
840 
isSubtypeOfExt(const Type & rhs_,std::ostream * why_not) const841 bool TupleType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const {
842   if (Type::isSubtypeOfExt(rhs_, why_not)) {
843     return true;
844   }
845   if (rhs_.kind() == AnyTupleType::Kind) {
846     return true;
847   }
848   auto rhs = rhs_.cast<TupleType>();
849   if (!rhs)
850     return false;
851   // unnamed tuple is not a subtype of nametuple
852   if (!schema() && rhs->schema())
853     return false;
854   // namedtuple may be a subtype of unnamed tuple
855   auto test_names_match = [&](const std::shared_ptr<FunctionSchema>& lhs, const std::shared_ptr<FunctionSchema>& rhs) {
856     const auto& args_lhs = lhs->arguments();
857     const auto& args_rhs = rhs->arguments();
858     if (args_lhs.size() != args_rhs.size()) {
859       return false;
860     }
861 
862     for (size_t i = 0; i < args_lhs.size(); ++i) {
863       if (args_lhs[i].name() != args_rhs[i].name()) {
864         return false;
865       }
866     }
867     return true;
868   };
869   bool names_match = !rhs->schema() || test_names_match(schema(), rhs->schema());
870   // co-variant rules for tuples
871   return names_match && compare(*rhs, [&](const Type& a, const Type& b) {
872     return a.isSubtypeOfExt(b, why_not);
873   });
874 }
875 
isSubtypeOfExt(const Type & rhs_,std::ostream * why_not) const876 bool ListType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const {
877   if (Type::isSubtypeOfExt(rhs_, why_not)) {
878     return true;
879   }
880   if (rhs_.kind() == AnyListType::Kind) {
881     return true;
882   }
883   return false;
884 }
885 
equals(const Type & rhs) const886  bool TupleType::equals(const Type& rhs) const {
887    bool typesSame =
888        compare(rhs, [](const Type& a, const Type& b) { return a == b; });
889    if (!typesSame) {
890      return false;
891   }
892 
893   // `compare` guarantees that rhs is always a TupleType.
894   auto rhsTuple = rhs.expect<TupleType>();
895   if (schema_ == nullptr && rhsTuple->schema_ == nullptr) {
896     return typesSame;
897   }
898   if (schema_ == nullptr || rhsTuple->schema_ == nullptr) {
899     return false;
900   }
901   return *schema_ == *rhsTuple->schema_;
902 }
903 
str() const904 std::string TupleType::str() const {
905   std::stringstream ss;
906   if (schema_ && name()) {
907     ss << name()->qualifiedName();
908   } else {
909     ss << "(";
910     for(size_t i = 0; i < elements().size(); ++i) {
911       if(i > 0)
912         ss << ", ";
913       ss << elements()[i]->str();
914     }
915     ss << ")";
916   }
917   return ss.str();
918 }
annotation_str_impl(const TypePrinter & printer) const919 std::string TupleType::annotation_str_impl(const TypePrinter& printer) const {
920   if (schema_ && name()) {
921     return name()->qualifiedName();
922   }
923 
924   if (elements().empty()) {
925     // `typing.Tuple` special-cases the annotation syntax for empty tuple
926     // with `typing.Tuple[()]`. See
927     // https://docs.python.org/3/library/typing.html#typing.Tuple
928     return "Tuple[()]";
929   }
930 
931   // Fast path for expectedly-small Tuples.
932   const auto elts = elements();
933   if (elts.size() <= 3) {
934     std::array<std::string, 3> elements_strs;
935     size_t total_length = 0;
936     int idx = 0;
937     for (const auto& element: elts) {
938       elements_strs[idx] = element->annotation_str(printer);
939       total_length += elements_strs[idx].size();
940       idx++;
941     }
942     std::string result;
943     result.reserve(strlen("Tuple[") + strlen(", ") * (elts.size() - 1) + total_length + 1);
944     result.append("Tuple[");
945     for (const auto ii : c10::irange(elts.size())) {
946       if (ii > 0) {
947         result.push_back(',');
948         result.push_back(' ');
949       }
950       result.append(elements_strs[ii]);
951     }
952     result.push_back(']');
953     return result;
954   }
955 
956   std::ostringstream ss;
957   ss << "Tuple[";
958   size_t i = 0;
959   for (const auto& element: elts) {
960     if (i > 0) {
961       ss << ", ";
962     }
963     ss << element->annotation_str(printer);
964     i++;
965   }
966   ss << ']';
967   return std::move(ss).str();
968 }
969 
create(QualifiedName qualifiedName,bool is_module)970 InterfaceTypePtr InterfaceType::create(QualifiedName qualifiedName, bool is_module) {
971   return InterfaceTypePtr(
972       new InterfaceType(std::move(qualifiedName), is_module));
973 }
974 
FunctionType(torch::jit::Function * function)975 FunctionType::FunctionType(torch::jit::Function* function)
976   : NamedType(TypeKind::FunctionType, function->qualname()),
977     function_(function) {}
978 
isSubTypeImpl(const InterfaceType & lhs,const InterfaceType & rhs,std::ostream * why_not)979 bool InterfaceType::isSubTypeImpl(
980     const InterfaceType& lhs,
981     const InterfaceType& rhs,
982     std::ostream* why_not) {
983   if (!lhs.is_module() && rhs.is_module()) {
984     if (why_not) {
985       *why_not << "Interface '" << lhs.repr_str() << "' is not a subtype of "
986                << "the module interface '" << rhs.repr_str() << "'.\n";
987     }
988     return false;
989   }
990     for (const FunctionSchema& schema : *rhs.methods_) {
991       auto self_schema = lhs.getMethod(schema.name());
992       if (!self_schema) {
993         if (why_not) {
994           *why_not << "Interface '" << lhs.repr_str()
995                    << "' does not have method '" << schema.name() << "' but interface '"
996                    << rhs.repr_str() << "' does.\n";
997         }
998         return false;
999       }
1000       // NOLINTNEXTLINE(bugprone-argument-comment)
1001       if (!self_schema->isSubtypeOf(schema, /*is_method=*/true, why_not)) {
1002         if (why_not) {
1003           *why_not << "Method on interface '" << lhs.repr_str()
1004                    << "' (1) is not compatible with interface '"
1005                    << rhs.repr_str() << "' (2)\n"
1006                    << "  (1) " << *self_schema << "\n"
1007                    << "  (2) " << schema << "\n";
1008           return false;
1009         }
1010         return false;
1011       }
1012     }
1013     return true;
1014 }
1015 
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const1016 bool InterfaceType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
1017   // to improve performance this check can be cached
1018   if (auto iface = rhs.castRaw<InterfaceType>()) {
1019     return isSubTypeImpl(*this, *iface, why_not);
1020   }
1021   return Type::isSubtypeOfExt(rhs, why_not);
1022 }
1023 
getMethod(const std::string & name) const1024 const FunctionSchema* InterfaceType::getMethod(const std::string& name) const {
1025   for (const FunctionSchema& method : *methods_) {
1026     if (method.name() == name) {
1027       return &method;
1028     }
1029   }
1030   return nullptr;
1031 }
addMethod(FunctionSchema schema)1032 void InterfaceType::addMethod(FunctionSchema schema) {
1033   methods_->emplace_back(std::move(schema));
1034 }
InterfaceType(QualifiedName name,bool is_module)1035 InterfaceType::InterfaceType(QualifiedName name, bool is_module)
1036     : NamedType(InterfaceType::Kind, std::move(name)),
1037       methods_(std::make_shared<std::vector<FunctionSchema>>()),
1038       is_module_(is_module) {}
1039 
1040 InterfaceType::~InterfaceType() = default;
1041 
containsAnyType(const TypePtr & type)1042 bool containsAnyType(const TypePtr& type) {
1043   std::vector<TypePtr> to_scan = { type };
1044   while (!to_scan.empty()) {
1045     const auto typ = to_scan.back();
1046     to_scan.pop_back();
1047     if (typ->kind() == AnyType::Kind) {
1048       return true;
1049     }
1050     for (const TypePtr& sub : typ->containedTypes()) {
1051       to_scan.emplace_back(sub);
1052     }
1053   }
1054   return false;
1055 }
1056 
checkNoAny(const Type & base,const char * what,const std::string & attrname,const TypePtr & attrtype)1057 void checkNoAny(const Type& base, const char* what, const std::string& attrname, const TypePtr& attrtype) {
1058   TORCH_CHECK(
1059       !containsAnyType(attrtype),
1060       "attempting to add ",
1061       what,
1062       " '",
1063       attrname,
1064       "' of type ",
1065       attrtype->repr_str(),
1066       " to '",
1067       base.repr_str(),
1068       "' but it contains an Any type. Any types cannot be members of modules, classes, or named tuples.");
1069 }
1070 
// Merges two symbolic shapes dimension by dimension.
//
// If either shape has no dims, or the two have a different number of dims,
// a default-constructed SymbolicShape is returned. Otherwise each pair of
// corresponding dimensions is combined with merge_primitive.
SymbolicShape SymbolicShape::merge(const SymbolicShape& other) const {
  if (!dims_ || !other.dims_ || dims_->size() != other.dims_->size()) {
    return SymbolicShape();
  }
  const size_t rank = dims_->size();
  std::vector<ShapeSymbol> dims;
  // The result has exactly `rank` entries; allocate once.
  dims.reserve(rank);
  for (size_t i = 0; i < rank; ++i) {
    dims.push_back(merge_primitive((*dims_)[i], (*other.dims_)[i]));
  }
  return SymbolicShape(std::move(dims));
}
1081 
dump() const1082 void SymbolicShape::dump() const {
1083   std::cout << *this << "\n";
1084 }
1085 
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const1086 bool EnumType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
1087   return rhs.kind() == TypeKind::AnyType ||
1088       rhs.kind() == TypeKind::AnyEnumType ||
1089       *this == rhs ||
1090       Type::isSubtypeOfExt(rhs, why_not);
1091 }
1092 
1093 } // namespace c10
1094