#include <ATen/core/Dict.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/enum_type.h>
#include <ATen/core/function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/grad_mode.h>
#include <ATen/core/jit_type.h>
#include <c10/macros/Macros.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <array>
#include <iostream>
#include <stdexcept>
#include <utility>
15
16 namespace std {
17 template<>
18 struct hash<std::tuple<std::string, c10::TypePtr, c10::TypePtr>> {
operator ()std::hash19 size_t operator()(std::tuple<std::string, c10::TypePtr, c10::TypePtr> const& t) const {
20 // This hashing is all hidden behind a static initializer so it
21 // doesn't have to be optimal
22 auto hash = std::hash<std::string>()(std::get<0>(t));
23 hash = at::hash_combine(hash, std::hash<c10::TypePtr>()(std::get<1>(t)));
24 hash = at::hash_combine(hash, std::hash<c10::TypePtr>()(std::get<2>(t)));
25 return hash;
26 }
27 };
28 template<>
29 struct hash<std::tuple<std::string, c10::TypePtr>> {
operator ()std::hash30 size_t operator()(std::tuple<std::string, c10::TypePtr> const& t) const {
31 auto hash = std::hash<std::string>()(std::get<0>(t));
32 hash = at::hash_combine(hash, std::hash<c10::TypePtr>()(std::get<1>(t)));
33 return hash;
34 }
35 };
36 } // namespace std
37
38 namespace c10 {
39
40 static_assert(
41 sizeof(SingletonOrSharedTypePtr<void>) == sizeof(std::shared_ptr<void>) && sizeof(std::shared_ptr<void>) == 2 * sizeof(void*),
42 "std::shared_ptr has an unexpected representation on this platform!");
43 static_assert(
44 std::is_same_v<decltype(getTypePtr<std::tuple<int64_t, int64_t>>()), const TupleTypePtr&>,
45 "getTypePtr<std::tuple<int64_t, int64_t>> not returning const ref!");
46
type_verbosity()47 TypeVerbosity type_verbosity() {
48 static const char* c_verbosity = std::getenv("PYTORCH_JIT_TYPE_VERBOSITY");
49 static TypeVerbosity verbosity = c_verbosity ?
50 static_cast<TypeVerbosity>(std::stoi(c_verbosity)) : TypeVerbosity::Default;
51 return verbosity;
52 }
53
operator <<(std::ostream & out,const Type & t)54 std::ostream& operator<<(std::ostream & out, const Type & t) {
55 if (auto value = t.cast<TensorType>()) {
56 if (value->scalarType().has_value()) {
57 out << toString(*value->scalarType());
58 if (!value->sizes().size().has_value()) {
59 out << "Tensor";
60 }
61 } else {
62 out << "Tensor";
63 }
64 if (auto ndim = value->sizes().size()) {
65 bool has_valid_strides_info = *ndim > 0 &&
66 value->strides().isComplete() && value->strides().size() == ndim;
67
68 out << "(";
69 size_t i = 0;
70 bool symbolic = type_verbosity() == TypeVerbosity::Symbolic;
71 for (i = 0; i < *ndim; ++i) {
72 if (i > 0) {
73 out << ", ";
74 }
75 if (auto s = value->sizes()[i]) {
76 out << *s;
77 } else if (symbolic) {
78 out << value->symbolic_sizes().at(i);
79 } else {
80 out << "*";
81 }
82 }
83 if (has_valid_strides_info &&
84 type_verbosity() >= TypeVerbosity::TypeAndStride) {
85 out << ", strides=[";
86 for (size_t i = 0; i < *ndim; ++i) {
87 if (i > 0) {
88 out << ", ";
89 }
90 out << *value->strides()[i];
91 }
92 out << "]";
93 }
94 if (type_verbosity() >= TypeVerbosity::Full) {
95 if (value->requiresGrad()) {
96 if (i++ > 0) {
97 out << ", ";
98 }
99 out << "requires_grad=" << *value->requiresGrad();
100 }
101 if (value->device()) {
102 if (i++ > 0) {
103 out << ", ";
104 }
105 out << "device=" << *value->device();
106 }
107 }
108 out << ")";
109 } else {
110 if (type_verbosity() >= TypeVerbosity::Full) {
111 size_t i = 0;
112 if (value->requiresGrad()) {
113 out << "("
114 << "requires_grad=" << *value->requiresGrad();
115 i++;
116 }
117 if (value->device()) {
118 out << ((i++ > 0) ? ", " : "(") << "device=" << *value->device();
119 }
120 if (i > 0) {
121 out << ")";
122 }
123 }
124 }
125
126 if (value->undefined() && *value->undefined()) {
127 out << "[Undefined]";
128 }
129 } else if(t.kind() == TypeKind::ListType) {
130 auto prim = t.castRaw<ListType>()->getElementType();
131 out << *prim << "[]";
132 } else if (t.kind() == TypeKind::OptionalType) {
133 auto prim = t.castRaw<OptionalType>()->getElementType();
134 out << *prim << "?";
135 } else if(t.kind() == TypeKind::FutureType) {
136 auto elem = t.castRaw<FutureType>()->getElementType();
137 out << "Future[" << *elem << "]";
138 } else if(t.kind() == TypeKind::RRefType) {
139 auto elem = t.castRaw<RRefType>()->getElementType();
140 out << "RRef[" << *elem << "]";
141 } else if(auto tup = t.cast<TupleType>()) {
142 if (tup->schema()) {
143 out << "NamedTuple";
144 }
145 out << "(";
146 for(size_t i = 0; i < tup->elements().size(); ++i) {
147 if(i > 0)
148 out << ", ";
149 if (tup->schema()) {
150 auto arg = tup->schema()->arguments()[i];
151 out << arg.name() << " : ";
152 out << *(tup->elements()[i]);
153 if (arg.default_value()) {
154 out << " = " << *arg.default_value();
155 }
156 }
157 else {
158 out << *(tup->elements()[i]);
159 }
160 }
161 out << ")";
162 } else if (t.kind() == TypeKind::FunctionType) {
163 out << "Function";
164 } else {
165 out << t.str();
166 }
167 return out;
168 }
169
get()170 AnyTypePtr AnyType::get() {
171 static AnyTypePtr value(new AnyType());
172 return value;
173 }
174
get()175 NumberTypePtr NumberType::get() {
176 static NumberTypePtr value(new NumberType());
177 return value;
178 }
get()179 IntTypePtr IntType::get() {
180 static IntTypePtr value(new IntType());
181 return value;
182 }
get()183 FloatTypePtr FloatType::get() {
184 static FloatTypePtr value(new FloatType());
185 return value;
186 }
get()187 ComplexTypePtr ComplexType::get() {
188 static ComplexTypePtr value(new ComplexType());
189 return value;
190 }
get()191 BoolTypePtr BoolType::get() {
192 static BoolTypePtr value(new BoolType());
193 return value;
194 }
get()195 StorageTypePtr StorageType::get() {
196 static StorageTypePtr value(new StorageType());
197 return value;
198 }
get()199 NoneTypePtr NoneType::get() {
200 static NoneTypePtr value(new NoneType());
201 return value;
202 }
get()203 GeneratorTypePtr GeneratorType::get() {
204 static GeneratorTypePtr value(new GeneratorType());
205 return value;
206 }
get()207 QuantizerTypePtr QuantizerType::get() {
208 static QuantizerTypePtr value(new QuantizerType());
209 return value;
210 }
get()211 QSchemeTypePtr QSchemeType::get() {
212 static QSchemeTypePtr value(new QSchemeType());
213 return value;
214 }
get()215 StringTypePtr StringType::get() {
216 static StringTypePtr value(new StringType());
217 return value;
218 }
get()219 DeviceObjTypePtr DeviceObjType::get() {
220 static DeviceObjTypePtr value(new DeviceObjType());
221 return value;
222 }
get()223 StreamObjTypePtr StreamObjType::get() {
224 static StreamObjTypePtr value(new StreamObjType());
225 return value;
226 }
get()227 ScalarTypeTypePtr ScalarTypeType::get() {
228 static ScalarTypeTypePtr value(new ScalarTypeType());
229 return value;
230 }
get()231 LayoutTypePtr LayoutType::get() {
232 static LayoutTypePtr value(new LayoutType());
233 return value;
234 }
get()235 MemoryFormatTypePtr MemoryFormatType::get() {
236 static MemoryFormatTypePtr value(new MemoryFormatType());
237 return value;
238 }
get()239 PyObjectTypePtr PyObjectType::get() {
240 static PyObjectTypePtr value(new PyObjectType());
241 return value;
242 }
get()243 CapsuleTypePtr CapsuleType::get() {
244 static CapsuleTypePtr value(new CapsuleType());
245 return value;
246 }
ofInts()247 ListTypePtr ListType::ofInts() {
248 static auto value = ListType::create(IntType::get());
249 return value;
250 }
ofSymInts()251 ListTypePtr ListType::ofSymInts() {
252 static auto value = ListType::create(SymIntType::get());
253 return value;
254 }
ofComplexDoubles()255 ListTypePtr ListType::ofComplexDoubles() {
256 static auto value = ListType::create(ComplexType::get());
257 return value;
258 }
ofFloats()259 ListTypePtr ListType::ofFloats() {
260 static auto value = ListType::create(FloatType::get());
261 return value;
262 }
ofBools()263 ListTypePtr ListType::ofBools() {
264 static auto value = ListType::create(BoolType::get());
265 return value;
266 }
ofStrings()267 ListTypePtr ListType::ofStrings() {
268 static auto value = ListType::create(StringType::get());
269 return value;
270 }
ofNumbers()271 ListTypePtr ListType::ofNumbers() {
272 static auto value = ListType::create(NumberType::get());
273 return value;
274 }
275
get(TypePtr inner)276 TypePtr OptionalType::get(TypePtr inner) {
277 static ska::flat_hash_map<TypePtr, TypePtr> containerTypePtrs;
278 static std::mutex mutex;
279 // Perf from the lock is ok because this function is guarded behind
280 // a static initializer; it should only be called once per type.
281 std::lock_guard<std::mutex> lock(mutex);
282 if (containerTypePtrs.find(inner) == containerTypePtrs.end()) {
283 TypePtr t = TypeFactory::create<OptionalType>(inner);
284 containerTypePtrs.emplace(inner, std::move(t));
285 }
286 return containerTypePtrs[inner];
287 }
288
get(const std::string & identifier,TypePtr inner)289 TypePtr ListType::get(const std::string& identifier, TypePtr inner) {
290 static ska::flat_hash_map<std::tuple<std::string, TypePtr>, TypePtr> containerTypePtrs;
291 static std::mutex mutex;
292 // Perf from the lock is ok because this function is guarded behind
293 // a static initializer; it should only be called once per type.
294 auto key = std::make_tuple(identifier, inner);
295 std::lock_guard<std::mutex> lock(mutex);
296 if (containerTypePtrs.find(key) == containerTypePtrs.end()) {
297 TypePtr t = ListType::create(inner);
298 containerTypePtrs.emplace(key, std::move(t));
299 }
300 return containerTypePtrs[key];
301 }
302
get(const std::string & identifier,TypePtr key,TypePtr value)303 TypePtr DictType::get(const std::string& identifier, TypePtr key, TypePtr value) {
304 static ska::flat_hash_map<std::tuple<std::string, TypePtr, TypePtr>, TypePtr> containerTypePtrs;
305 static std::mutex mutex;
306 // Perf from the lock is ok because this function is guarded behind
307 // a static initializer; it should only be called once per type.
308 auto map_key = std::make_tuple(identifier, key, value);
309 std::lock_guard<std::mutex> lock(mutex);
310 if (containerTypePtrs.find(map_key) == containerTypePtrs.end()) {
311 TypePtr t = DictType::create(std::move(key), std::move(value));
312 containerTypePtrs.emplace(map_key, std::move(t));
313 }
314 return containerTypePtrs[map_key];
315 }
316
annotation_str_impl(const TypePrinter & printer) const317 std::string DictType::annotation_str_impl(const TypePrinter& printer) const {
318 auto keyAnnotation = getKeyType()->annotation_str(printer);
319 auto valueAnnotation = getValueType()->annotation_str(printer);
320
321 std::string result;
322 result.reserve(5 /* "Dict[" */ + keyAnnotation.size() + 2 /* ", " */ + valueAnnotation.size() + 1 /* "]" */);
323 result = "Dict[";
324 result += keyAnnotation;
325 result.push_back(',');
326 result.push_back(' ');
327 result += valueAnnotation;
328 result.push_back(']');
329 return result;
330 }
331
get()332 AnyListTypePtr AnyListType::get() {
333 static AnyListTypePtr value(new AnyListType());
334 return value;
335 }
336
get()337 AnyTupleTypePtr AnyTupleType::get() {
338 static AnyTupleTypePtr value(new AnyTupleType());
339 return value;
340 }
341
get()342 AnyClassTypePtr AnyClassType::get() {
343 static AnyClassTypePtr value(new AnyClassType());
344 return value;
345 }
346
get()347 AnyEnumTypePtr AnyEnumType::get() {
348 static AnyEnumTypePtr value(new AnyEnumType());
349 return value;
350 }
351
get()352 SymIntTypePtr SymIntType::get() {
353 static SymIntTypePtr value(new SymIntType());
354 return value;
355 }
356
get()357 SymFloatTypePtr SymFloatType::get() {
358 static SymFloatTypePtr value(new SymFloatType());
359 return value;
360 }
361
get()362 SymBoolTypePtr SymBoolType::get() {
363 static SymBoolTypePtr value(new SymBoolType());
364 return value;
365 }
366
unifyTypesImpl(const TypePtr & t1,const TypePtr & t2,bool default_to_union=false,const TypePtr & type_hint=nullptr)367 static std::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_union=false, const TypePtr& type_hint=nullptr) {
368 // check direct subtyping relation
369 if (t1->isSubtypeOf(*t2)) {
370 return t2;
371 } else if (t2->isSubtypeOf(*t1)) {
372 return t1;
373 }
374
375 // Handle non-container types which do not subtype each other and unify
376 if (t1->kind() == TensorType::Kind && t2->kind() == TensorType::Kind) {
377 return t1->expectRef<TensorType>().merge(t2->expectRef<TensorType>());
378 }
379
380 if (t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get())) {
381 return OptionalType::create(t2);
382 } else if (t2->isSubtypeOf(*NoneType::get()) && !t1->isSubtypeOf(*NoneType::get())) {
383 return OptionalType::create(t1);
384 }
385
386 // NB: we do not return NumberType because there is not currently enough
387 // operator support for it
388
389 // Attempt to unify Complete Tensor Types for immutable type containers
390
391 // unify(Optional[t1], t2) => Optional[unify(t1, t2)]
392 if (auto opt_t1 = t1->cast<OptionalType>()) {
393 if (auto elem = unifyTypes(opt_t1->getElementType(), t2)) {
394 return OptionalType::create(*std::move(elem));
395 }
396 } else if (auto opt_t2 = t2->cast<OptionalType>()) {
397 if (auto elem = unifyTypes(opt_t2->getElementType(), t1)) {
398 return OptionalType::create(*std::move(elem));
399 }
400 }
401
402 if (t1->castRaw<TupleType>() && t2->castRaw<TupleType>()) {
403 auto tuple1 = t1->castRaw<TupleType>();
404 auto tuple2 = t2->castRaw<TupleType>();
405 if (tuple1->elements().size() != tuple2->elements().size()) {
406 return std::nullopt;
407 }
408 std::vector<TypePtr> elements;
409 for (size_t i = 0; i < tuple1->elements().size(); i++) {
410 if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_union)) {
411 elements.push_back(*std::move(elem));
412 } else {
413 return std::nullopt;
414 }
415 }
416 return static_cast<TypePtr>(TupleType::create(std::move(elements)));
417 }
418
419 if (t1->castRaw<FutureType>() && t2->castRaw<FutureType>()) {
420 if (auto elem = unifyTypes(
421 t1->castRaw<FutureType>()->getElementType(),
422 t2->castRaw<FutureType>()->getElementType())) {
423 return FutureType::create(*elem);
424 }
425 }
426
427 // Check direct subtyping relations again with Unshaped Types,
428 // to handle unification of mutable container types which might contain two different
429 // specialized tensors (ListType / DictType)
430 auto t1_unshaped = unshapedType(t1);
431 auto t2_unshaped = unshapedType(t2);
432
433 if (t1_unshaped->isSubtypeOf(*t2_unshaped)) {
434 return t2_unshaped;
435 } else if (t2_unshaped->isSubtypeOf(*t1_unshaped)) {
436 return t1_unshaped;
437 }
438
439 // Check whether or not `type_hint` is a common parent. This case
440 // could occur if we had two class types that had been annotated with
441 // a common interface
442 if (type_hint && t1->isSubtypeOf(*type_hint) && t2->isSubtypeOf(*type_hint)) {
443 return type_hint;
444 }
445
446 return std::nullopt;
447 }
448
unifyTypes(const TypePtr & t1,const TypePtr & t2,bool default_to_union,const TypePtr & type_hint)449 std::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_union, const TypePtr& type_hint) {
450 auto unified = unifyTypesImpl(t1, t2, default_to_union, type_hint);
451
452 if (default_to_union && !unified) {
453 return UnionType::create({t1, t2});
454 }
455
456 return unified;
457 }
458
unifyTypeList(at::ArrayRef<TypePtr> elements,std::ostream & why_not,bool default_to_union,const TypePtr & type_hint)459 std::optional<TypePtr> unifyTypeList(
460 at::ArrayRef<TypePtr> elements,
461 std::ostream& why_not,
462 bool default_to_union,
463 const TypePtr& type_hint) {
464 if (elements.empty()) {
465 why_not << "Cannot get unified type from empty list";
466 return std::nullopt;
467 }
468
469 TypePtr ret_type = elements.at(0);
470 for (size_t i = 1; i < elements.size() && ret_type; ++i) {
471 std::optional<TypePtr> maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_union, type_hint);
472 if (!maybe_unified) {
473 why_not << "Could not unify type list since element " << i << " of type "
474 << elements.at(i)->repr_str()
475 << " did not match the types before it ("
476 << ret_type->repr_str() << ")";
477 return std::nullopt;
478 }
479 ret_type = *maybe_unified;
480 }
481
482 return ret_type;
483 }
484
485 // NOTE: This function actually does need to take const TypePtr&
486 // because it sometimes calls unifyTypes, which needs const TypePtr&.
matchTypeVariables(const TypePtr & formal,const TypePtr & actual,TypeEnv & type_env)487 MatchTypeReturn matchTypeVariables(
488 const TypePtr& formal,
489 const TypePtr& actual,
490 TypeEnv& type_env) {
491 if (!formal->hasFreeVariables()) {
492 if (auto dyn = formal->castRaw<c10::DynamicType>()) {
493 return matchTypeVariables(dyn->fallback(), actual, type_env);
494 }
495 return MatchTypeReturn::Success();
496 }
497
498 if (auto vt = formal->castRaw<VarType>()) {
499 auto it = type_env.find(vt->name());
500 if (it == type_env.end()) {
501 type_env[vt->name()] = actual;
502 return MatchTypeReturn::Success();
503 } else if (unifyTypes(it->second, actual)) {
504 // note: unifyTypes allows subtyping in either direction, so actual
505 // may be a supertype of the current binding. we're not responsible
506 // for reporting the error, only for keeping type_env stable
507 return MatchTypeReturn::Success();
508 }
509 std::stringstream ss;
510 ss << "Type variable '" << vt->name() << "' previously matched to type "
511 << it->second->repr_str() << " is matched to type "
512 << actual->repr_str();
513 return ss.str();
514 } else if (auto lt_formal = formal->castRaw<ListType>()) {
515 if (auto lt_actual = actual->castRaw<ListType>()) {
516 auto innerMatch = matchTypeVariables(
517 lt_formal->getElementType(), lt_actual->getElementType(), type_env);
518 if (!innerMatch.success()) {
519 // propagate the errMsg onward
520 return innerMatch;
521 }
522 return MatchTypeReturn::Success();
523 } else if (auto tup_type = actual->castRaw<TupleType>()) {
524 std::stringstream ss;
525 auto maybe_tuple_unified = unifyTypeList(tup_type->elements(), ss);
526 if (maybe_tuple_unified) {
527 return matchTypeVariables(
528 lt_formal->getElementType(), *maybe_tuple_unified, type_env);
529 }
530 }
531
532 std::stringstream ss;
533 ss << "Cannot match " << lt_formal->repr_str() << " to "
534 << actual->repr_str();
535 return ss.str();
536 } else if (auto tp_formal = formal->castRaw<TupleType>()) {
537 if (auto tp_actual = actual->castRaw<TupleType>()) {
538 if (tp_formal->elements().size() != tp_actual->elements().size()) {
539 return MatchTypeReturn("Cannot match tuples of mismatched size");
540 }
541 for (size_t i = 0; i < tp_formal->elements().size(); ++i) {
542 auto result = matchTypeVariables(
543 tp_formal->elements()[i], tp_actual->elements()[i], type_env);
544 if (!result.success()) {
545 return result;
546 }
547 }
548 return MatchTypeReturn::Success();
549 } else {
550 std::stringstream ss;
551 ss << "Cannot match a tuple to " << actual->repr_str();
552 return MatchTypeReturn(ss.str());
553 }
554 } else if (auto lt_formal = formal->castRaw<FutureType>()) {
555 if (auto lt_actual = actual->castRaw<FutureType>()) {
556 auto innerMatch = matchTypeVariables(
557 lt_formal->getElementType(), lt_actual->getElementType(), type_env);
558 if (!innerMatch.success()) {
559 return innerMatch;
560 }
561 return MatchTypeReturn::Success();
562 } else {
563 std::stringstream ss;
564 ss << "Cannot match a future to " << actual->repr_str();
565 return ss.str();
566 }
567 } else if (auto lt_formal = formal->castRaw<AwaitType>()) {
568 if (auto lt_actual = actual->castRaw<AwaitType>()) {
569 auto innerMatch = matchTypeVariables(
570 lt_formal->getElementType(), lt_actual->getElementType(), type_env);
571 if (!innerMatch.success()) {
572 return innerMatch;
573 }
574 return MatchTypeReturn::Success();
575 } else {
576 std::stringstream ss;
577 ss << "Cannot match an await to " << actual->repr_str();
578 return ss.str();
579 }
580 } else if (auto lt_formal = formal->castRaw<RRefType>()) {
581 if (auto lt_actual = actual->castRaw<RRefType>()) {
582 auto innerMatch = matchTypeVariables(
583 lt_formal->getElementType(), lt_actual->getElementType(), type_env);
584 if (!innerMatch.success()) {
585 return innerMatch;
586 }
587 return MatchTypeReturn::Success();
588 } else {
589 std::stringstream ss;
590 ss << "Cannot match a rref to " << actual->repr_str();
591 return ss.str();
592 }
593 } else if (auto opt_formal = formal->castRaw<OptionalType>()) {
594 if (auto opt_actual = actual->castRaw<OptionalType>()) {
595 auto optionedMatch = matchTypeVariables(
596 opt_formal->getElementType(), opt_actual->getElementType(), type_env);
597 if (!optionedMatch.success()) {
598 return optionedMatch;
599 }
600 } else if (!actual->isSubtypeOf(*NoneType::get())) {
601 // If the actual type is a non-optional, allow matching to the formal if
602 // its element type matches the actual.
603 // Don't match None because it is already an optional (but one of
604 // unknown type).
605 return matchTypeVariables(opt_formal->getElementType(), actual, type_env);
606 }
607 // note: if actual was None here we potentially did not fill in the type
608 // variables contained in the formal. It is still a valid match because None
609 // matches Optional[T] later error checking on tryEvalTypeVariables will
610 // report the problem if we never match variables in type T
611 return MatchTypeReturn::Success();
612 } else if (auto dict_formal = formal->castRaw<DictType>()) {
613 if (auto dict_actual = actual->castRaw<DictType>()) {
614 auto key_match = matchTypeVariables(
615 dict_formal->getKeyType(), dict_actual->getKeyType(), type_env);
616 if (!key_match.success()) {
617 return key_match;
618 }
619 auto value_match = matchTypeVariables(
620 dict_formal->getValueType(), dict_actual->getValueType(), type_env);
621 if (!value_match.success()) {
622 return value_match;
623 }
624 return MatchTypeReturn::Success();
625 } else {
626 std::stringstream ss;
627 ss << "Cannot match a dict to " << actual->repr_str();
628 return ss.str();
629 }
630 }
631
632 AT_ERROR("Unhandled free variable container: ", formal->repr_str());
633 }
634
635 // change return types like List[List[t]] into List[List[int]]
tryEvalTypeVariables(const TypePtr & type,std::unordered_map<std::string,TypePtr> & type_env)636 TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, std::unordered_map<std::string, TypePtr>& type_env) {
637 if (!type->hasFreeVariables()) {
638 if (auto dyn = type->castRaw<c10::DynamicType>()) {
639 return tryEvalTypeVariables(dyn->fallback(), type_env);
640 }
641 return type;
642 }
643
644 if (auto vt = type->castRaw<VarType>()) {
645 auto it = type_env.find(vt->name());
646 if (it == type_env.end()) {
647 return nullptr;
648 }
649 return it->second;
650 } else {
651 at::ArrayRef<TypePtr> contained = type->containedTypes();
652 if (contained.empty()) {
653 return type;
654 }
655 std::vector<TypePtr> new_contained;
656 new_contained.reserve(contained.size());
657 for (const TypePtr& t : contained) {
658 TypePtr r = tryEvalTypeVariables(t, type_env);
659 if (!r) {
660 return nullptr;
661 }
662 new_contained.push_back(std::move(r));
663 }
664 return type->withContained(std::move(new_contained));
665 }
666 }
667
elementTypeCanBeInferredFromMembers(const TypePtr & elem_type)668 TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) {
669 if (elem_type->kind() == UnionType::Kind
670 || elem_type->kind() == OptionalType::Kind
671 || elem_type->kind() == NumberType::Kind) {
672 // Builtin Union types
673 return false;
674 }
675 if (elem_type->kind() == InterfaceType::Kind) {
676 // since classes can be members of multiple interfaces, we cannot
677 // construct which interface the list holds from the members alone
678 return false;
679 }
680 if (elem_type->kind() == AnyType::Kind) {
681 // List of Any can contains heterogenous types
682 return false;
683 }
684 return true;
685 }
686
typeKindToString(TypeKind kind)687 const char * typeKindToString(TypeKind kind) {
688 #define CASE_TYPE(T) case TypeKind::T: return #T;
689 switch(kind) {
690 C10_FORALL_TYPES(CASE_TYPE)
691 }
692 #undef CASE_TYPE
693 return "";
694 }
695
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const696 bool Type::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
697 if (rhs.kind() == TypeKind::AnyType || *this == rhs) {
698 return true;
699 }
700 if (auto opt_rhs = rhs.castRaw<OptionalType>()) {
701 return this->isSubtypeOfExt(*opt_rhs->getElementType(), why_not);
702 }
703 if (auto union_rhs = rhs.castRaw<UnionType>()) {
704 // Check if `this` is a subtype of any of the types within the Union
705 return std::any_of(union_rhs->containedTypes().begin(),
706 union_rhs->containedTypes().end(),
707 [&](const TypePtr& inner) {
708 return this->isSubtypeOfExt(*inner, why_not);
709 });
710 }
711 if (auto dyn = rhs.castRaw<DynamicType>()) {
712 return DynamicType::create(*this)->isSubtypeOf(*dyn);
713 }
714 return false;
715 }
716
is_module() const717 bool Type::is_module() const {
718 return false;
719 }
720
createNamed(const std::optional<c10::QualifiedName> & qualName,const std::vector<std::string> & field_names,const std::vector<TypePtr> & field_types)721 TupleTypePtr TupleType::createNamed(
722 const std::optional<c10::QualifiedName>& qualName,
723 const std::vector<std::string>& field_names,
724 const std::vector<TypePtr>& field_types) {
725 std::vector<IValue> empty_defaults;
726 return TupleType::createNamed(qualName, field_names, field_types, empty_defaults);
727 }
728
createNamed(const std::optional<c10::QualifiedName> & qualName,const std::vector<c10::string_view> & field_names,const std::vector<TypePtr> & field_types)729 TupleTypePtr TupleType::createNamed(
730 const std::optional<c10::QualifiedName>& qualName,
731 const std::vector<c10::string_view>& field_names,
732 const std::vector<TypePtr>& field_types) {
733 std::vector<IValue> empty_defaults;
734 return createWithSpec(qualName, field_names, field_types, empty_defaults);
735 }
736
createNamed(const std::optional<c10::QualifiedName> & qualName,const std::vector<std::string> & field_names,const std::vector<TypePtr> & field_types,std::vector<IValue> & field_defaults)737 TupleTypePtr TupleType::createNamed(
738 const std::optional<c10::QualifiedName>& qualName,
739 const std::vector<std::string>& field_names,
740 const std::vector<TypePtr>& field_types,
741 std::vector<IValue>& field_defaults) {
742 return createWithSpec(qualName, field_names, field_types, field_defaults);
743 }
744
745 template <typename S>
createWithSpec(const std::optional<c10::QualifiedName> & qualName,const std::vector<S> & field_names,const std::vector<TypePtr> & field_types,std::vector<IValue> & field_defaults)746 TupleTypePtr TupleType::createWithSpec(const std::optional<c10::QualifiedName>& qualName,
747 const std::vector<S>& field_names,
748 const std::vector<TypePtr>& field_types,
749 std::vector<IValue>& field_defaults) {
750 TORCH_INTERNAL_ASSERT(field_names.size() == field_types.size());
751
752 std::vector<Argument> arguments;
753 arguments.reserve(field_names.size());
754 auto min_default_idx = field_names.size() - field_defaults.size();
755 for (size_t i = 0; i < field_names.size(); ++i) {
756 if (i < min_default_idx) {
757 Argument arg{
758 /*name=*/std::string{field_names[i]},
759 /*type=*/field_types[i],
760 /*N=*/i};
761 arguments.emplace_back(std::move(arg));
762 }
763 else {
764 size_t j = i - min_default_idx;
765 TORCH_CHECK(field_defaults[j].tagKind() != "Tensor", "Tensors are "
766 "not supported as default NamedTuple fields. Their "
767 "mutability could lead to potential memory aliasing "
768 "problems");
769 Argument arg{
770 /*name=*/std::string{field_names[i]},
771 /*type=*/field_types[i],
772 /*N=*/i,
773 /*default_value=*/field_defaults[j]};
774 arguments.emplace_back(std::move(arg));
775 }
776 }
777
778 auto schema = std::make_shared<FunctionSchema>(
779 /*name=*/qualName.value_or(c10::QualifiedName()).name(),
780 /*overload_name=*/std::string(""),
781 /*arguments=*/std::move(arguments),
782 /*returns=*/std::vector<Argument>{});
783 return std::shared_ptr<TupleType>(new TupleType(
784 field_types, qualName, std::move(schema))); // NOLINT(modernize-make-shared)
785 }
786
names() const787 std::optional<std::vector<c10::string_view>> TupleType::names() const {
788 if (!schema_) {
789 return {};
790 }
791 std::vector<c10::string_view> ret;
792 for (const auto& arg : schema_->arguments()) {
793 ret.emplace_back(arg.name());
794 }
795 return ret;
796 }
797
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const798 bool NoneType::isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const {
799 if (rhs.kind() == OptionalType::Kind) {
800 return true;
801 }
802 return Type::isSubtypeOfExt(rhs, why_not);
803 }
804
equals(const Type & rhs) const805 bool NumberType::equals(const Type& rhs) const {
806 if (auto union_type = rhs.cast<UnionType>()) {
807 return union_type->containedTypes().size() == 3 && union_type->canHoldType(*NumberType::get());
808 } else {
809 return rhs.kind() == this->kind();
810 }
811 }
812
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const813 bool NumberType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
814 if (auto union_type = rhs.cast<UnionType>()) {
815 return union_type->canHoldType(*NumberType::get());
816 } else {
817 return Type::isSubtypeOfExt(rhs, why_not);
818 }
819 }
820
TupleType(std::vector<TypePtr> elements,std::optional<c10::QualifiedName> name,std::shared_ptr<FunctionSchema> schema)821 TupleType::TupleType(
822 std::vector<TypePtr> elements,
823 std::optional<c10::QualifiedName> name,
824 std::shared_ptr<FunctionSchema> schema)
825 : NamedType(TypeKind::TupleType, std::move(name)),
826 elements_(std::move(elements)),
827 has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
828 if (!v) {
829 throw std::runtime_error("Can not create tuple with None type");
830 }
831 return v->hasFreeVariables();
832 })), schema_(std::move(schema)) {
833
834 if (schema_) {
835 for (const Argument& arg : schema_->arguments()) {
836 checkNoAny(*this, "attribute", arg.name(), arg.type());
837 }
838 }
839 }
840
isSubtypeOfExt(const Type & rhs_,std::ostream * why_not) const841 bool TupleType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const {
842 if (Type::isSubtypeOfExt(rhs_, why_not)) {
843 return true;
844 }
845 if (rhs_.kind() == AnyTupleType::Kind) {
846 return true;
847 }
848 auto rhs = rhs_.cast<TupleType>();
849 if (!rhs)
850 return false;
851 // unnamed tuple is not a subtype of nametuple
852 if (!schema() && rhs->schema())
853 return false;
854 // namedtuple may be a subtype of unnamed tuple
855 auto test_names_match = [&](const std::shared_ptr<FunctionSchema>& lhs, const std::shared_ptr<FunctionSchema>& rhs) {
856 const auto& args_lhs = lhs->arguments();
857 const auto& args_rhs = rhs->arguments();
858 if (args_lhs.size() != args_rhs.size()) {
859 return false;
860 }
861
862 for (size_t i = 0; i < args_lhs.size(); ++i) {
863 if (args_lhs[i].name() != args_rhs[i].name()) {
864 return false;
865 }
866 }
867 return true;
868 };
869 bool names_match = !rhs->schema() || test_names_match(schema(), rhs->schema());
870 // co-variant rules for tuples
871 return names_match && compare(*rhs, [&](const Type& a, const Type& b) {
872 return a.isSubtypeOfExt(b, why_not);
873 });
874 }
875
isSubtypeOfExt(const Type & rhs_,std::ostream * why_not) const876 bool ListType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const {
877 if (Type::isSubtypeOfExt(rhs_, why_not)) {
878 return true;
879 }
880 if (rhs_.kind() == AnyListType::Kind) {
881 return true;
882 }
883 return false;
884 }
885
equals(const Type & rhs) const886 bool TupleType::equals(const Type& rhs) const {
887 bool typesSame =
888 compare(rhs, [](const Type& a, const Type& b) { return a == b; });
889 if (!typesSame) {
890 return false;
891 }
892
893 // `compare` guarantees that rhs is always a TupleType.
894 auto rhsTuple = rhs.expect<TupleType>();
895 if (schema_ == nullptr && rhsTuple->schema_ == nullptr) {
896 return typesSame;
897 }
898 if (schema_ == nullptr || rhsTuple->schema_ == nullptr) {
899 return false;
900 }
901 return *schema_ == *rhsTuple->schema_;
902 }
903
str() const904 std::string TupleType::str() const {
905 std::stringstream ss;
906 if (schema_ && name()) {
907 ss << name()->qualifiedName();
908 } else {
909 ss << "(";
910 for(size_t i = 0; i < elements().size(); ++i) {
911 if(i > 0)
912 ss << ", ";
913 ss << elements()[i]->str();
914 }
915 ss << ")";
916 }
917 return ss.str();
918 }
annotation_str_impl(const TypePrinter & printer) const919 std::string TupleType::annotation_str_impl(const TypePrinter& printer) const {
920 if (schema_ && name()) {
921 return name()->qualifiedName();
922 }
923
924 if (elements().empty()) {
925 // `typing.Tuple` special-cases the annotation syntax for empty tuple
926 // with `typing.Tuple[()]`. See
927 // https://docs.python.org/3/library/typing.html#typing.Tuple
928 return "Tuple[()]";
929 }
930
931 // Fast path for expectedly-small Tuples.
932 const auto elts = elements();
933 if (elts.size() <= 3) {
934 std::array<std::string, 3> elements_strs;
935 size_t total_length = 0;
936 int idx = 0;
937 for (const auto& element: elts) {
938 elements_strs[idx] = element->annotation_str(printer);
939 total_length += elements_strs[idx].size();
940 idx++;
941 }
942 std::string result;
943 result.reserve(strlen("Tuple[") + strlen(", ") * (elts.size() - 1) + total_length + 1);
944 result.append("Tuple[");
945 for (const auto ii : c10::irange(elts.size())) {
946 if (ii > 0) {
947 result.push_back(',');
948 result.push_back(' ');
949 }
950 result.append(elements_strs[ii]);
951 }
952 result.push_back(']');
953 return result;
954 }
955
956 std::ostringstream ss;
957 ss << "Tuple[";
958 size_t i = 0;
959 for (const auto& element: elts) {
960 if (i > 0) {
961 ss << ", ";
962 }
963 ss << element->annotation_str(printer);
964 i++;
965 }
966 ss << ']';
967 return std::move(ss).str();
968 }
969
create(QualifiedName qualifiedName,bool is_module)970 InterfaceTypePtr InterfaceType::create(QualifiedName qualifiedName, bool is_module) {
971 return InterfaceTypePtr(
972 new InterfaceType(std::move(qualifiedName), is_module));
973 }
974
FunctionType(torch::jit::Function * function)975 FunctionType::FunctionType(torch::jit::Function* function)
976 : NamedType(TypeKind::FunctionType, function->qualname()),
977 function_(function) {}
978
isSubTypeImpl(const InterfaceType & lhs,const InterfaceType & rhs,std::ostream * why_not)979 bool InterfaceType::isSubTypeImpl(
980 const InterfaceType& lhs,
981 const InterfaceType& rhs,
982 std::ostream* why_not) {
983 if (!lhs.is_module() && rhs.is_module()) {
984 if (why_not) {
985 *why_not << "Interface '" << lhs.repr_str() << "' is not a subtype of "
986 << "the module interface '" << rhs.repr_str() << "'.\n";
987 }
988 return false;
989 }
990 for (const FunctionSchema& schema : *rhs.methods_) {
991 auto self_schema = lhs.getMethod(schema.name());
992 if (!self_schema) {
993 if (why_not) {
994 *why_not << "Interface '" << lhs.repr_str()
995 << "' does not have method '" << schema.name() << "' but interface '"
996 << rhs.repr_str() << "' does.\n";
997 }
998 return false;
999 }
1000 // NOLINTNEXTLINE(bugprone-argument-comment)
1001 if (!self_schema->isSubtypeOf(schema, /*is_method=*/true, why_not)) {
1002 if (why_not) {
1003 *why_not << "Method on interface '" << lhs.repr_str()
1004 << "' (1) is not compatible with interface '"
1005 << rhs.repr_str() << "' (2)\n"
1006 << " (1) " << *self_schema << "\n"
1007 << " (2) " << schema << "\n";
1008 return false;
1009 }
1010 return false;
1011 }
1012 }
1013 return true;
1014 }
1015
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const1016 bool InterfaceType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
1017 // to improve performance this check can be cached
1018 if (auto iface = rhs.castRaw<InterfaceType>()) {
1019 return isSubTypeImpl(*this, *iface, why_not);
1020 }
1021 return Type::isSubtypeOfExt(rhs, why_not);
1022 }
1023
getMethod(const std::string & name) const1024 const FunctionSchema* InterfaceType::getMethod(const std::string& name) const {
1025 for (const FunctionSchema& method : *methods_) {
1026 if (method.name() == name) {
1027 return &method;
1028 }
1029 }
1030 return nullptr;
1031 }
addMethod(FunctionSchema schema)1032 void InterfaceType::addMethod(FunctionSchema schema) {
1033 methods_->emplace_back(std::move(schema));
1034 }
InterfaceType(QualifiedName name,bool is_module)1035 InterfaceType::InterfaceType(QualifiedName name, bool is_module)
1036 : NamedType(InterfaceType::Kind, std::move(name)),
1037 methods_(std::make_shared<std::vector<FunctionSchema>>()),
1038 is_module_(is_module) {}
1039
1040 InterfaceType::~InterfaceType() = default;
1041
containsAnyType(const TypePtr & type)1042 bool containsAnyType(const TypePtr& type) {
1043 std::vector<TypePtr> to_scan = { type };
1044 while (!to_scan.empty()) {
1045 const auto typ = to_scan.back();
1046 to_scan.pop_back();
1047 if (typ->kind() == AnyType::Kind) {
1048 return true;
1049 }
1050 for (const TypePtr& sub : typ->containedTypes()) {
1051 to_scan.emplace_back(sub);
1052 }
1053 }
1054 return false;
1055 }
1056
checkNoAny(const Type & base,const char * what,const std::string & attrname,const TypePtr & attrtype)1057 void checkNoAny(const Type& base, const char* what, const std::string& attrname, const TypePtr& attrtype) {
1058 TORCH_CHECK(
1059 !containsAnyType(attrtype),
1060 "attempting to add ",
1061 what,
1062 " '",
1063 attrname,
1064 "' of type ",
1065 attrtype->repr_str(),
1066 " to '",
1067 base.repr_str(),
1068 "' but it contains an Any type. Any types cannot be members of modules, classes, or named tuples.");
1069 }
1070
merge(const SymbolicShape & other) const1071 SymbolicShape SymbolicShape::merge(const SymbolicShape& other) const {
1072 if (!dims_ || !other.dims_ || dims_->size() != other.dims_->size()) {
1073 return SymbolicShape();
1074 }
1075 std::vector<ShapeSymbol> dims;
1076 for (size_t i = 0, n = dims_->size(); i < n; i++) {
1077 dims.push_back(merge_primitive((*dims_)[i], (*other.dims_)[i]));
1078 }
1079 return SymbolicShape(std::move(dims));
1080 }
1081
dump() const1082 void SymbolicShape::dump() const {
1083 std::cout << *this << "\n";
1084 }
1085
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const1086 bool EnumType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
1087 return rhs.kind() == TypeKind::AnyType ||
1088 rhs.kind() == TypeKind::AnyEnumType ||
1089 *this == rhs ||
1090 Type::isSubtypeOfExt(rhs, why_not);
1091 }
1092
1093 } // namespace c10
1094