1 #pragma once
2
#include <ATen/core/custom_class.h>
#include <ATen/core/jit_type_base.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymBool.h>
#include <c10/core/Device.h>

#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
21
22
// Forward declaration of torch::jit::Function so it can be named here without
// pulling in its full definition.
namespace torch {
namespace jit {
struct Function;
} // namespace jit
} // namespace torch
26
27
28 namespace c10 {
29
30 template<class Key, class Value>
31 class Dict;
32 struct IValue;
33 struct FunctionSchema;
34 struct NamedType;
35 using OptNameList = std::optional<std::vector<std::string>>;
36
37 void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
38 void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
39
is_contiguous_strides(const IntArrayRef sizes,const IntArrayRef strides)40 inline bool is_contiguous_strides(
41 const IntArrayRef sizes,
42 const IntArrayRef strides) {
43 int n_dim = static_cast<int>(sizes.size());
44 if (n_dim == 0) {
45 return true;
46 }
47
48 if (strides[n_dim - 1] != 1) {
49 return false;
50 }
51
52 for (int i = n_dim - 2; i >= 0; i--) {
53 if (strides[i] != strides[i + 1] * sizes[i + 1]) {
54 return false;
55 }
56 }
57 return true;
58 }
59
60 struct AnyType;
61 using AnyTypePtr = SingletonTypePtr<AnyType>;
62 // Any is the top of the type hierarchy, all other types are subtypes
63 // T <: Any, forall T
64 struct TORCH_API AnyType : public Type {
equalsAnyType65 bool equals(const Type& rhs) const override {
66 return rhs.kind() == kind();
67 }
strAnyType68 std::string str() const override {
69 return "Any";
70 }
71 static const TypeKind Kind = TypeKind::AnyType;
72 // global singleton
73 static AnyTypePtr get();
74
75 private:
AnyTypeAnyType76 AnyType() : Type(TypeKind::AnyType) {}
77 };
78
toString(const Type & type)79 inline std::string toString(const Type& type) {
80 return type.str();
81 }
82
83 // Shim for compatibility with code that uses TypePtr.
toString(const TypePtr & typePtr)84 inline std::string toString(const TypePtr& typePtr) {
85 return toString(*typePtr);
86 }
87
88 inline bool operator!=(const Type& lhs, const Type& rhs) {
89 return !(lhs == rhs);
90 }
91
92 // common base for all types that have a single sub element
93 // e.g. Future[T], Optional[T], List[T]
94 template <TypeKind K, typename T>
95 struct SingleElementType : public SharedType {
96 static const TypeKind Kind = K;
97
getElementTypeSingleElementType98 const TypePtr& getElementType() const {
99 return elem;
100 }
101
hasFreeVariablesSingleElementType102 bool hasFreeVariables() const override {
103 return getElementType()->hasFreeVariables();
104 }
105
containedTypesSingleElementType106 at::ArrayRef<TypePtr> containedTypes() const override {
107 return elem;
108 }
109
equalsSingleElementType110 bool equals(const Type& rhs) const override {
111 if (auto rhs_ = rhs.cast<T>()) {
112 return *getElementType() == *rhs_->getElementType();
113 }
114 return false;
115 }
116
117 protected:
SingleElementTypeSingleElementType118 SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
119 if (!this->elem) {
120 throw std::runtime_error(c10::str(
121 "Can not create ", typeKindToString(Kind), " with None type"));
122 }
123 }
124
125 private:
126 TypePtr elem;
127 };
128
129 struct UnionType;
130 using UnionTypePtr = std::shared_ptr<UnionType>;
131 struct TORCH_API UnionType : public SharedType {
132 friend struct Type;
133
134 static const TypeKind Kind = TypeKind::UnionType;
135
136 bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;
137
138 std::string str() const override;
139
140 static UnionTypePtr create(std::vector<TypePtr> reference);
141
142 bool equals(const Type& rhs) const override;
143
isUnionTypeUnionType144 bool isUnionType() const override {
145 return true;
146 }
147
containedTypesUnionType148 at::ArrayRef<TypePtr> containedTypes() const override {
149 return types_;
150 }
151
152 // For testing purposes only
getTypesUnionType153 at::ArrayRef<TypePtr> getTypes() const {
154 return types_;
155 }
156
createWithContainedUnionType157 TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
158 return create(std::move(contained_types));
159 }
160
161 bool canHoldType(const Type& type) const;
162
hasFreeVariablesUnionType163 bool hasFreeVariables() const override {
164 return has_free_variables_;
165 }
166
167 std::optional<TypePtr> toOptional() const;
168
169 std::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;
170
171 protected:
172 explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
173 std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override;
174 std::string unionStr(
175 const TypePrinter& printer = nullptr,
176 bool is_annotation_str = false) const;
177 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
178 bool has_free_variables_;
179 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
180 std::vector<TypePtr> types_;
181 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
182 bool can_hold_none_;
183
184 };
185
186 struct OptionalType;
187 using OptionalTypePtr = std::shared_ptr<OptionalType>;
188 // This type represents an optional type. There is one `Optional` for
189 // each element type. `Optional[T]` can accept both `T` and
190 // `None`(`std::nullopt` in C++)
191 // Subtype hierarchy for Optional:
192 // - Optional[T] <: Optional[R] iff T <: R
193 // - T <: Optional[R] if T <: R
194 // - None <: Optional[T] for all T
195 // - Optional[T] == Union[T, None] for all T
196 struct TORCH_API OptionalType : public UnionType {
197 static OptionalTypePtr create(const TypePtr& contained);
198
199 static const TypeKind Kind = TypeKind::OptionalType;
200
201 friend struct Type;
202
203 bool equals(const Type& rhs) const override;
204
getElementTypeOptionalType205 const TypePtr& getElementType() const {
206 return contained_;
207 }
208
containedTypesOptionalType209 at::ArrayRef<TypePtr> containedTypes() const override {
210 return contained_;
211 }
212
strOptionalType213 std::string str() const override {
214 std::stringstream ss;
215 ss << getElementType()->str() << "?";
216 return ss.str();
217 }
218
createWithContainedOptionalType219 TypePtr createWithContained(
220 std::vector<TypePtr> contained_types) const override {
221 AT_ASSERT(contained_types.size() == 1);
222 return create(contained_types[0]);
223 }
224
225 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
226
isUnionTypeOptionalType227 bool isUnionType() const override {
228 return true;
229 }
230
231 // common cast Optional[Tensor] for undefined tensor type
232 static TypePtr ofTensor();
233 //
234 // global singleton
235 static TypePtr get(TypePtr inner);
236
237 private:
238 explicit OptionalType(const TypePtr& contained);
239
240 TypePtr contained_;
241
242 std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
243 std::stringstream ss;
244 ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
245 return ss.str();
246 }
247 };
248
// Keeps a value only when both sides agree. Returns `a` if both optionals are
// engaged and compare equal (via T's operator==); otherwise returns an empty
// optional.
template <typename T>
inline std::optional<T> merge_primitive(
    const std::optional<T>& a,
    const std::optional<T>& b) {
  const bool both_known = a.has_value() && b.has_value();
  if (!both_known || !(*a == *b)) {
    return std::nullopt;
  }
  return a;
}
258
// If we see `a + b + c` and know that a, b, and c are the same size and have
// two dimensions (WxH), then we can generate a fused kernel for them. That
// fused kernel would likely have indexing math to handle both the W and H
// dimensions. However, if we knew the WxH dimensions were contiguous, we could
// pretend we only have a single dimension, simplifying the indexing logic.
// This can be performed even if the dimensions are transposed,
// as long as a, b, and c are transposed in the same way.
// We'd like the compiler to be able to do this dimensionality reduction,
// but simply knowing sizes is not enough.
// We can extend profiling to also record stride information.
// Rather than recording specific strides, we can simply order the strides
// from smallest to largest with `stride_indices`. A contiguity marker on the
// smallest stride (c0) indicates that the stride is precisely 1; otherwise a
// contiguity marker means that
//   $\mathrm{stride}_n = \mathrm{size}_{n-1} \cdot \mathrm{stride}_{n-1}$.
274 struct TORCH_API Stride {
275 Stride() = default;
StrideStride276 Stride(
277 const std::optional<size_t>& stride_index,
278 std::optional<bool> contiguous,
279 const std::optional<size_t>& stride)
280 : stride_index_(stride_index), contiguous_(contiguous), stride_(stride) {}
281
282 bool operator==(const Stride& b) const {
283 return stride_index_ == b.stride_index_ && contiguous_ == b.contiguous_ &&
284 stride_ == b.stride_;
285 }
286
isCompleteStride287 bool isComplete() const {
288 return stride_index_ && contiguous_ && stride_;
289 }
290
291 std::optional<size_t> stride_index_;
292 std::optional<bool> contiguous_;
293 std::optional<size_t> stride_;
294 };
295
296 template <>
merge_primitive(const std::optional<Stride> & a,const std::optional<Stride> & b)297 inline std::optional<Stride> merge_primitive(
298 const std::optional<Stride>& a,
299 const std::optional<Stride>& b) {
300 std::optional<Stride> left = a;
301 std::optional<Stride> right = b;
302 if (!left.has_value()) {
303 left = {Stride()};
304 }
305 if (!right.has_value()) {
306 right = {Stride()};
307 }
308
309 auto merged_index =
310 merge_primitive(left->stride_index_, right->stride_index_);
311 auto merged_cont = merge_primitive(left->contiguous_, right->contiguous_);
312 auto merged_stride = merge_primitive(left->stride_, right->stride_);
313 auto r = Stride(merged_index, merged_cont, merged_stride);
314 // normalize
315 if (!r.stride_index_.has_value() && !r.contiguous_.has_value() &&
316 !r.stride_.has_value()) {
317 return std::optional<Stride>{};
318 }
319
320 return r;
321 }
322
323 struct TORCH_API ShapeSymbol {
324 // needed for use in `std::map`
ShapeSymbolShapeSymbol325 ShapeSymbol() : value_(-1) {}
326 // is this symbol a fixed/static dimension
is_staticShapeSymbol327 bool is_static() const {
328 return value_ >= 0;
329 };
330 bool operator==(const ShapeSymbol& b) const {
331 return value_ == b.value_;
332 }
333 bool operator<(const ShapeSymbol& b) const {
334 return value_ < b.value_;
335 }
336
fromStaticSizeShapeSymbol337 static ShapeSymbol fromStaticSize(int64_t val) {
338 return ShapeSymbol(val);
339 }
static_sizeShapeSymbol340 int64_t static_size() const {
341 TORCH_CHECK(is_static());
342 return value_;
343 };
344
valueShapeSymbol345 int64_t value() const {
346 return value_;
347 };
348
newSymbolShapeSymbol349 static ShapeSymbol newSymbol() {
350 return fromStaticSize(-static_cast<int64_t>(++num_symbols));
351 };
352 friend TORCH_API std::ostream& operator<<(
353 std::ostream& os,
354 const ShapeSymbol& s);
355
356 private:
ShapeSymbolShapeSymbol357 ShapeSymbol(int64_t val) : value_(val) {}
358 int64_t value_;
359 static std::atomic<size_t> num_symbols;
360 };
361
// ShapeSymbol overload. A static size survives only when both sides carry the
// same static size; otherwise the result is a brand-new symbol.
inline ShapeSymbol merge_primitive(
    const ShapeSymbol& a,
    const ShapeSymbol& b) {
  const bool same_static = a.is_static() && b.is_static() && a == b;
  if (!same_static) {
    return ShapeSymbol::newSymbol();
  }
  return a;
}
370
371 // Shape of a Tensor represented with ShapeSymbol's. Unranked, ranked unknown
372 // dims, partially known and fully known shapes are all supported.
373 struct TORCH_API SymbolicShape {
374 // Unranked shape constructor.
SymbolicShapeSymbolicShape375 SymbolicShape() : dims_(std::nullopt) {}
376
377 // Known rank but unknown dimentions.
SymbolicShapeSymbolicShape378 SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) {
379 if(!rank) {
380 return;
381 }
382
383 std::vector<ShapeSymbol> shape_symbols;
384 shape_symbols.reserve(*rank);
385 for(size_t i = 0; i < *rank; ++i) {
386 shape_symbols.push_back(ShapeSymbol::newSymbol());
387 }
388 dims_ = shape_symbols;
389 }
390
391 // Mix of known and unknown ranks
SymbolicShapeSymbolicShape392 SymbolicShape(const std::vector<std::optional<int64_t>>& dims) {
393 std::vector<ShapeSymbol> shape_symbols;
394 shape_symbols.reserve(dims.size());
395 for(std::optional<int64_t> dim: dims) {
396 if(!dim) {
397 shape_symbols.push_back(ShapeSymbol::newSymbol());
398 } else {
399 shape_symbols.push_back(ShapeSymbol::fromStaticSize(*dim));
400 }
401 }
402 dims_ = shape_symbols;
403 }
404
405 void dump() const;
406
SymbolicShapeSymbolicShape407 SymbolicShape(std::vector<ShapeSymbol> dims) : dims_(std::move(dims)) {}
408
SymbolicShapeSymbolicShape409 SymbolicShape(c10::IntArrayRef dims) {
410 std::vector<ShapeSymbol> shape_symbols;
411 shape_symbols.reserve(dims.size());
412 for(int64_t dim : dims) {
413 shape_symbols.push_back(ShapeSymbol::fromStaticSize(dim));
414 }
415 dims_ = shape_symbols;
416 }
417
418 ShapeSymbol operator[](size_t i) const {
419 if (!dims_) {
420 throw std::runtime_error("Rank isn't fixed");
421 }
422 return (*dims_).at(i);
423 }
424
atSymbolicShape425 ShapeSymbol at(size_t i) const {
426 if (!dims_) {
427 throw std::runtime_error("Rank isn't fixed");
428 }
429 return (*dims_).at(i);
430 }
431
432 // Returns rank or nullopt in case of unranked shape.
rankSymbolicShape433 std::optional<size_t> rank() const {
434 if(!dims_) {
435 return std::nullopt;
436 }
437 return dims_->size();
438 }
439
sizesSymbolicShape440 std::optional<std::vector<ShapeSymbol>> sizes() const {
441 return dims_;
442 }
443
symbolicDimsSymbolicShape444 std::optional<std::vector<bool>> symbolicDims() const {
445 if (!dims_) {
446 return std::nullopt;
447 }
448 auto symbolic_dims = std::vector<bool>();
449 for (const ShapeSymbol& s : *dims_) {
450 symbolic_dims.push_back(!s.is_static());
451 }
452 return symbolic_dims;
453 }
454
455 // Checks whether the shape is fully defined/complete, ie. rank and sizes
456 // of every dimension are known.
isCompleteSymbolicShape457 bool isComplete() const {
458 if(!dims_) {
459 return false;
460 }
461 for(auto d : *dims_) {
462 if(!d.is_static()) {
463 return false;
464 }
465 }
466 return true;
467 }
468
469 // Create new SymbolicShape that is result of merging self and another
470 // SymbolicShape. Only dimensions that are static and equal will be
471 // preserved.
472 // If either of two shapes are of unknown rank or they have unmatching rank,
473 // result will be unranked.
474 SymbolicShape merge(const SymbolicShape& other) const;
475
476 friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) {
477 return lhs.dims_ == rhs.dims_;
478 }
479
480 friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) {
481 return !(lhs == rhs);
482 }
483
484 private:
485 std::optional<std::vector<ShapeSymbol>> dims_;
486 };
487
488 namespace detail {
isComplete(const Stride & s)489 inline bool isComplete(const Stride& s) {
490 return s.isComplete();
491 }
492
493 template<typename T>
isComplete(const T &)494 inline bool isComplete(const T& /*t*/) {
495 return true;
496 }
497 }
498
499 template <typename T>
500 struct VaryingShape {
501 using ListOfOptionalElements = std::vector<std::optional<T>>;
VaryingShapeVaryingShape502 VaryingShape(const std::vector<T>& vec)
503 : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
504
VaryingShapeVaryingShape505 VaryingShape(c10::ArrayRef<T> vec)
506 : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
507
dims_VaryingShape508 VaryingShape(std::optional<size_t> size = std::nullopt) : dims_(std::nullopt) {
509 if (size) {
510 dims_ = ListOfOptionalElements(*size);
511 }
512 }
513
VaryingShapeVaryingShape514 VaryingShape(ListOfOptionalElements dims) : dims_(std::move(dims)) {}
515
VaryingShapeVaryingShape516 VaryingShape(size_t size) : VaryingShape(std::optional<size_t>(size)) {}
517
518 bool operator==(const VaryingShape& other) const {
519 return dims_ == other.dims_;
520 }
521
522 const std::optional<T> &operator[](size_t i) const {
523 if (!dims_) {
524 throw std::runtime_error("Rank isn't fixed");
525 }
526 return (*dims_).at(i);
527 }
528
sizeVaryingShape529 std::optional<size_t> size() const {
530 if (!dims_) {
531 return std::nullopt;
532 }
533 const auto& dims = dims_.value();
534 return dims.size();
535 }
536
sizesVaryingShape537 const std::optional<ListOfOptionalElements>& sizes() const {
538 return dims_;
539 }
540
541 TORCH_API VaryingShape merge(const VaryingShape& other) const;
542
concrete_sizesVaryingShape543 std::optional<std::vector<T>> concrete_sizes() const {
544 if (!dims_) {
545 return std::nullopt;
546 }
547 std::vector<T> sizes;
548 sizes.reserve(dims_.value().size());
549 for (auto d : *dims_) {
550 if (!d) {
551 return std::nullopt;
552 }
553 sizes.push_back(d.value());
554 }
555 return sizes;
556 }
557
isCompleteVaryingShape558 bool isComplete() const {
559 if (!dims_) {
560 return false;
561 }
562 for (auto d : *dims_) {
563 if (!d || !detail::isComplete(*d)) {
564 return false;
565 }
566 }
567 return true;
568 }
569
570 private:
571 std::optional<ListOfOptionalElements> dims_;
572 };
573
574 struct TensorType;
575 // TODO: investigate making this SingletonOrSharedTypePtr<TensorType>
576 using TensorTypePtr = std::shared_ptr<TensorType>;
577 // This type represents a single Tensor with a specific size
578 struct TORCH_API TensorType : public SharedType {
579 static TensorTypePtr create(const at::Tensor& t);
580
581 // used by TensorType::create(size_t dim) which in turn used by
582 // shape_analysis.cpp
583 static TensorTypePtr create(
584 std::optional<at::ScalarType> scalar_type,
585 std::optional<Device> device,
586 const VaryingShape<int64_t>& sizes,
587 const VaryingShape<int64_t>& strides,
588 std::optional<bool> requires_grad,
589 std::optional<bool> undefined = false,
590 bool tensor_contiguity = false);
591
592 static TensorTypePtr create(
593 std::optional<at::ScalarType> scalar_type,
594 std::optional<Device> device,
595 const SymbolicShape& sizes,
596 const VaryingShape<Stride>& stride_,
597 std::optional<bool> requires_grad,
598 std::optional<bool> undefined = false);
599
600 static TensorTypePtr create(
601 std::optional<at::ScalarType> scalar_type,
602 std::optional<Device> device,
603 std::optional<size_t> dim,
604 std::optional<bool> requires_grad);
605
606 // overloaded create variadic template argument as it could not distinguish
607 // initializer list
608 static TensorTypePtr createContiguous(
609 at::ScalarType scalar_type,
610 at::Device device,
611 at::IntArrayRef sizes);
612
613 static TypePtr fromNumberType(const Type& typ);
614 static TypePtr fromBoolType();
615
dimTensorType616 std::optional<size_t> dim() const {
617 return sizes().size();
618 }
619
620 VaryingShape<int64_t> sizes() const;
621
622 VaryingShape<int64_t> strides() const;
623
stride_propertiesTensorType624 const VaryingShape<Stride>& stride_properties() const {
625 return strides_;
626 }
627
deviceTensorType628 std::optional<at::Device> device() const {
629 return device_;
630 }
scalarTypeTensorType631 std::optional<at::ScalarType> scalarType() const {
632 return scalar_type_;
633 }
requiresGradTensorType634 std::optional<bool> requiresGrad() const {
635 return requires_grad_;
636 }
requires_gradTensorType637 bool requires_grad() const override {
638 return requires_grad_ ? *requires_grad_ : true;
639 }
640
641 bool equals(const Type& rhs) const override;
642 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
643
644 std::string str() const override;
645
repr_strTensorType646 std::string repr_str() const override {
647 if (isInferredType()) {
648 return str() + " (inferred)";
649 } else {
650 return str();
651 }
652 }
653
numelTensorType654 std::optional<size_t> numel() const {
655 size_t prod = 1;
656 const auto& shape = sizes();
657
658 for (size_t i = 0; i < shape.size(); i++) {
659 if (!shape[i]) {
660 return std::optional<size_t>{};
661 }
662 prod *= shape[i].value();
663 }
664 return prod;
665 }
666
withRequiresGradTensorType667 TensorTypePtr withRequiresGrad(std::optional<bool> s) {
668 auto copy = clone();
669 copy->requires_grad_ = s;
670 return copy;
671 }
672
withScalarTypeTensorType673 TensorTypePtr withScalarType(std::optional<ScalarType> st) {
674 auto copy = clone();
675 copy->scalar_type_ = st;
676 return copy;
677 }
678
withDimTensorType679 TensorTypePtr withDim(std::optional<size_t> d) {
680 auto copy = clone();
681 // withDim is only used by the legacy executor
682 // that only cares about the rank, so create dummy symbols)) :
683 copy->sizes_ = SymbolicShape(d);
684 copy->strides_ = VaryingShape<Stride>(d);
685 return copy;
686 }
687
withStridesTensorType688 TensorTypePtr withStrides(VaryingShape<Stride> sstrides) const {
689 auto cloned = clone();
690 cloned->strides_ = std::move(sstrides);
691 return cloned;
692 }
693
withSizesStridesTensorType694 TensorTypePtr withSizesStrides(
695 at::IntArrayRef sizes,
696 at::IntArrayRef strides) const {
697 auto cloned = clone();
698 auto ssizes = SymbolicShape(sizes);
699 cloned->sizes_ = ssizes;
700 cloned->strides_ = computeStrideProps(sizes, strides);
701 return cloned;
702 }
703
withSymbolicShapesTensorType704 TensorTypePtr withSymbolicShapes(SymbolicShape ssizes) const {
705 auto cloned = clone();
706 cloned->sizes_ = std::move(ssizes);
707 return cloned;
708 }
709
withSizesTensorType710 TensorTypePtr withSizes(at::IntArrayRef sizes) const {
711 return withSizesStrides(
712 sizes, contiguousStridesOf(sizes));
713 }
714
withDeviceTensorType715 TensorTypePtr withDevice(const std::optional<at::Device> device) const {
716 auto copy = clone();
717 copy->device_ = device;
718 return copy;
719 }
720
dimensionedOnlyTensorType721 TensorTypePtr dimensionedOnly() const {
722 auto copy = clone();
723 copy->sizes_ = SymbolicShape(sizes().size());
724 copy->strides_ = VaryingShape<Stride>(sizes().size());
725 return copy;
726 }
727
contiguousTensorType728 TensorTypePtr contiguous() const {
729 auto cloned = clone();
730 TORCH_INTERNAL_ASSERT(sizes().concrete_sizes().has_value());
731 auto strides = computeStrideProps(
732 *sizes().concrete_sizes(),
733 contiguousStridesOf(*sizes().concrete_sizes()));
734 cloned->strides_ = strides;
735 return cloned;
736 }
737
738 const SymbolicShape& symbolic_sizes() const;
739
740 TensorTypePtr merge(const TensorType& other, bool merge_sizes = true) const;
741
742 bool matchTensor(const at::Tensor& t);
743
744 // is all information about the type specified except for autograd?
745 // This replaces the notion of a 'CompleteTensorType' that used to exist
746 // in the type-hierarchy. Excluding require_grad and undefined allows
747 // this to match the old behavior.
isCompleteTensorType748 bool isComplete() const {
749 return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete();
750 }
751
isInferredTypeTensorType752 bool isInferredType() const {
753 return is_inferred_;
754 }
755
getInferredTensorType756 static TensorTypePtr getInferred() {
757 static auto valueInferred = TensorType::create(
758 /*scalar_type=*/{},
759 /*device=*/{},
760 /*sizes=*/SymbolicShape(),
761 /*stride=*/VaryingShape<Stride>{},
762 /*requires_grad=*/{},
763 /*undefined=*/false);
764 valueInferred->is_inferred_ = true;
765 return valueInferred;
766 }
767
768 // this property is used by GuardElimination
769 // please see `checkInputs` for more details
isSummarizedTensorType770 bool isSummarized() const {
771 return !(isComplete() && requiresGrad().has_value() &&
772 undefined().has_value());
773 }
774
withUndefinedTensorType775 TensorTypePtr withUndefined() {
776 auto r = clone();
777 r->undefined_ = true;
778 return r;
779 }
780
withPossiblyUndefinedTensorType781 TensorTypePtr withPossiblyUndefined() {
782 auto r = clone();
783 r->undefined_ = std::nullopt;
784 return r;
785 }
786
undefinedTensorType787 std::optional<bool> undefined() const { return undefined_; }
788
789 static const TensorTypePtr& get();
790
791 static const TypeKind Kind = TypeKind::TensorType;
792
793 static std::vector<int64_t> contiguousStridesOf(
794 at::IntArrayRef in_sizes,
795 at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
796 auto contiguous_fn = [](const at::IntArrayRef& sizes,
797 const std::vector<int64_t>& dim_order) {
798 std::vector<int64_t> strides(sizes.size());
799 if (sizes.empty()) // zero-dim case
800 return strides;
801
802 strides[dim_order[0]] = 1;
803 for (size_t i = 1; i < dim_order.size(); i++) {
804 auto cur_dim = dim_order[i];
805 auto pre_dim = dim_order[i - 1];
806 strides[cur_dim] = strides[pre_dim] * sizes[pre_dim];
807 }
808 return strides;
809 };
810
811 std::vector<int64_t> dim_order(in_sizes.size());
812 if (memory_format == MemoryFormat::ChannelsLast) {
813 dim_order = {1, 3, 2, 0};
814 } else if (memory_format == MemoryFormat::ChannelsLast3d) {
815 dim_order = {1, 4, 3, 2, 0};
816 } else {
817 auto ndims = in_sizes.size();
818 for (size_t i = 0; i < ndims; i++) {
819 dim_order[i] = static_cast<int64_t>(ndims - i - 1); // Reverse
820 }
821 }
822 return contiguous_fn(in_sizes, dim_order);
823 }
824
825 private:
826 TensorType(
827 std::optional<at::ScalarType> scalar_type,
828 std::optional<Device> device,
829 SymbolicShape sizes,
830 VaryingShape<Stride> strides,
831 std::optional<bool> requires_grad,
832 std::optional<bool> undefined = false);
833
cloneTensorType834 TensorTypePtr clone() const {
835 return TensorTypePtr(new TensorType(
836 scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_));
837 }
838
839 static VaryingShape<Stride> computeStrideProps(
840 at::IntArrayRef sizes,
841 at::IntArrayRef strides,
842 bool tensor_contiguity = false);
843
844 std::optional<at::ScalarType> scalar_type_;
845 std::optional<at::Device> device_;
846 SymbolicShape sizes_;
847 VaryingShape<Stride> strides_;
848 std::optional<bool> requires_grad_;
849 // we exploit the fact certain tensors must be zero in the autograd to
850 // optimize gradient computation. Such zero tensors are currently implemented
851 // with `UndefinedTensorImpl.` They can be handled only by special operators
852 // (e.g. `AutogradAdd`) and their `Tensor::defined()` property returns false.
853 // Normally, `undefined_` is set to false, unless a type was created
854 // with `withUndefined`
855 // This will also mean that `undefined` tensors will fail
856 // `subtypeOf(TensorType::get())` check
857 // undefined_ may become `std::nullopt` if the tensor was observed to be both
858 // defined and undefined. However, no tensor type starts out with
859 // `undefined_` set to `std::nullopt`
860 std::optional<bool> undefined_;
861 // Represents whether or not this type was inferred.
862 bool is_inferred_ = false;
863 };
864
865 struct ListType;
866 using ListTypePtr = std::shared_ptr<ListType>;
867 struct TORCH_API ListType
868 : public SingleElementType<TypeKind::ListType, ListType> {
869 // It's not exactly a singleton, but there should be exactly one instance of
870 // List[T] for every T
871 friend struct Type;
872 template <typename... T>
createListType873 static ListTypePtr create(T&&... all) {
874 return ListTypePtr(
875 new ListType(std::forward<T>(all)...)); // NOLINT(modernize-make-shared)
876 }
877
strListType878 std::string str() const override {
879 std::stringstream ss;
880 ss << getElementType()->str() << "[]";
881 return ss.str();
882 }
createWithContainedListType883 TypePtr createWithContained(
884 std::vector<TypePtr> contained_types) const override {
885 return create(std::move(contained_types.at(0)));
886 }
887
888 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
889
890 // global singleton
891 // Given an inner type T and an identifier,
892 // this function wil return the global singleton type pointer
893 // the type List<T>.
894 // The extra "identifier" argument is needed beccause we have multiple container types
895 // that all re-use this function (List<T>, array<T, N>, etc.)
896 static TypePtr get(const std::string& identifier, TypePtr inner);
897
898 // common cast List[Tensor]
899 static ListTypePtr ofTensors();
900 static ListTypePtr ofOptionalTensors();
901 static ListTypePtr ofInts();
902 static ListTypePtr ofSymInts();
903 static ListTypePtr ofFloats();
904 static ListTypePtr ofComplexDoubles();
905 static ListTypePtr ofBools();
906 static ListTypePtr ofStrings();
907 static ListTypePtr ofNumbers();
908
909 private:
ListTypeListType910 ListType(TypePtr elem) : SingleElementType(std::move(elem)) {}
911
912 std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
913 std::stringstream ss;
914 ss << "List[" << getElementType()->annotation_str(printer) << "]";
915 return ss.str();
916 }
917 };
918
919 struct DictType;
920 using DictTypePtr = std::shared_ptr<DictType>;
921 struct TORCH_API DictType : public SharedType {
922 friend struct Type;
923 static const TypeKind Kind = TypeKind::DictType;
924
createDictType925 static DictTypePtr create(TypePtr key, TypePtr value) {
926 auto kind = key->kind();
927 if (auto dyn = key->castRaw<DynamicType>()) {
928 kind = dyn->dynamicKind();
929 }
930 switch (kind) {
931 case TypeKind::AnyType:
932 case TypeKind::IntType:
933 case TypeKind::BoolType:
934 case TypeKind::FloatType:
935 case TypeKind::ComplexType:
936 case TypeKind::StringType:
937 case TypeKind::TensorType:
938 case TypeKind::DeviceObjType:
939 return DictTypePtr(new DictType(std::move(key), std::move(value)));
940 default:
941 AT_ERROR(
942 "Cannot create dict for key type '",
943 key->str(),
944 "', only int, float, complex, Tensor, device and string keys are supported");
945 }
946 }
947
948 // aligned with the format in FunctionSchema
strDictType949 std::string str() const override {
950 std::stringstream ss;
951 ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
952 << ")";
953 return ss.str();
954 }
955
createWithContainedDictType956 TypePtr createWithContained(
957 std::vector<TypePtr> contained_types) const override {
958 if (contained_types.size() != 2) {
959 throw std::runtime_error("Expected 2 contained types");
960 }
961 return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
962 }
963
getKeyTypeDictType964 const TypePtr& getKeyType() const {
965 return types.at(0);
966 }
967
getValueTypeDictType968 const TypePtr& getValueType() const {
969 return types.at(1);
970 }
971
hasFreeVariablesDictType972 bool hasFreeVariables() const override {
973 return has_free_variables;
974 }
975
containedTypesDictType976 at::ArrayRef<TypePtr> containedTypes() const override {
977 return types;
978 }
979
equalsDictType980 bool equals(const Type& rhs) const override {
981 if (auto* dict_rhs = rhs.castRaw<DictType>()) {
982 return *getKeyType() == *(dict_rhs->getKeyType()) &&
983 *getValueType() == *(dict_rhs->getValueType());
984 }
985 return false;
986 }
987
988 // global singleton
989 // Given an inner type T and an identifier,
990 // this function will return the global singleton type pointer
991 // the type List<T>.
992 // The extra "identifier" argument is needed because we have multiple container types
993 // that all re-use this function (Dict<K, V> and unordered_map<K, V>)
994 static TypePtr get(const std::string& identifier, TypePtr key, TypePtr val);
995
996 private:
DictTypeDictType997 DictType(TypePtr key, TypePtr value)
998 : SharedType(TypeKind::DictType),
999 has_free_variables(
1000 key->hasFreeVariables() || value->hasFreeVariables()) {
1001 types.reserve(2);
1002 types.push_back(std::move(key));
1003 types.push_back(std::move(value));
1004 }
1005
1006 std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override;
1007
1008 std::vector<TypePtr> types;
1009 bool has_free_variables;
1010 };
1011
1012 struct FutureType;
1013 using FutureTypePtr = std::shared_ptr<FutureType>;
1014
1015 struct TORCH_API FutureType
1016 : public SingleElementType<TypeKind::FutureType, FutureType> {
1017 friend struct Type;
1018 template <typename... T>
createFutureType1019 static FutureTypePtr create(TypePtr elem) {
1020 return FutureTypePtr(
1021 new FutureType(std::move(elem))); // NOLINT(modernize-make-shared)
1022 }
1023
strFutureType1024 std::string str() const override {
1025 std::stringstream ss;
1026 ss << "Future(" << getElementType()->str() << ")";
1027 return ss.str();
1028 }
createWithContainedFutureType1029 TypePtr createWithContained(
1030 std::vector<TypePtr> contained_types) const override {
1031 return create(std::move(contained_types.at(0)));
1032 }
1033
isSubtypeOfExtFutureType1034 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1035 if (Type::isSubtypeOfExt(rhs, why_not)) {
1036 return true;
1037 }
1038 if (auto rhs_ = rhs.castRaw<FutureType>()) {
1039 return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
1040 }
1041 return false;
1042 }
1043
1044 private:
FutureTypeFutureType1045 FutureType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1046
1047 std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
1048 std::stringstream ss;
1049 ss << "Future[" << getElementType()->annotation_str(printer) << "]";
1050 return ss.str();
1051 }
1052 };
1053
1054 struct AwaitType;
1055 using AwaitTypePtr = std::shared_ptr<AwaitType>;
1056
1057 struct TORCH_API AwaitType
1058 : public SingleElementType<TypeKind::AwaitType, AwaitType> {
1059 friend struct Type;
1060 template <typename... T>
createAwaitType1061 static AwaitTypePtr create(TypePtr elem) {
1062 return AwaitTypePtr(
1063 new AwaitType(std::move(elem))); // NOLINT(modernize-make-shared)
1064 }
1065
strAwaitType1066 std::string str() const override {
1067 std::stringstream ss;
1068 ss << "Await(" << getElementType()->str() << ")";
1069 return ss.str();
1070 }
createWithContainedAwaitType1071 TypePtr createWithContained(
1072 std::vector<TypePtr> contained_types) const override {
1073 return create(std::move(contained_types.at(0)));
1074 }
1075
isSubtypeOfExtAwaitType1076 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1077 if (Type::isSubtypeOfExt(rhs, why_not)) {
1078 return true;
1079 }
1080 if (auto rhs_ = rhs.castRaw<AwaitType>()) {
1081 return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
1082 }
1083 return false;
1084 }
1085
1086 private:
AwaitTypeAwaitType1087 AwaitType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1088
1089 std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
1090 std::stringstream ss;
1091 ss << "Await[" << getElementType()->annotation_str(printer) << "]";
1092 return ss.str();
1093 }
1094 };
1095
1096 struct RRefType;
1097 using RRefTypePtr = std::shared_ptr<RRefType>;
1098
1099 struct TORCH_API RRefType
1100 : public SingleElementType<TypeKind::RRefType, RRefType> {
1101 friend struct Type;
1102 template <typename... T>
createRRefType1103 static RRefTypePtr create(TypePtr elem) {
1104 return RRefTypePtr(
1105 new RRefType(std::move(elem))); // NOLINT(modernize-make-shared)
1106 }
1107
strRRefType1108 std::string str() const override {
1109 std::stringstream ss;
1110 ss << "RRef(" << getElementType()->str() << ")";
1111 return ss.str();
1112 }
createWithContainedRRefType1113 TypePtr createWithContained(
1114 std::vector<TypePtr> contained_types) const override {
1115 return create(std::move(contained_types.at(0)));
1116 }
1117
1118 private:
RRefTypeRRefType1119 RRefType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1120
1121 std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
1122 std::stringstream ss;
1123 ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
1124 return ss.str();
1125 }
1126 };
1127
// Any should never appear in a named type like a class, namedtuple or
// interface. If it does, then dynamic type information will be lost in the
// Pickler, leading to hard-to-track-down bugs that will only occur
// after saving or loading a model. This is because we rely on the
// static types in named types to reconstruct type tags of loaded
// values. Lifting this restriction requires solving the serialization
// problem first.
//
// Parameters (meanings inferred from their names; TODO confirm against the
// out-of-line definition):
//   base     - the named type whose attribute is being checked
//   what     - a short description of that kind of type, for the error text
//   attrname - the name of the attribute/field being checked
//   attrtype - the attribute/field type that must not contain Any
TORCH_API void checkNoAny(
    const Type& base,
    const char* what,
    const std::string& attrname,
    const TypePtr& attrtype);
1140
struct TupleType;
using TupleTypePtr = std::shared_ptr<TupleType>;
using NameList = std::vector<std::string>;
// This type represents a Tuple. It holds the element types and may also carry
// a qualified name and a FunctionSchema (presumably describing the fields of a
// named tuple; create() passes neither).
struct TORCH_API TupleType : public NamedType {

  // Builds a named tuple type from field names, field types and field default
  // values. NOTE(review): field_defaults is a non-const reference; whether
  // the definition modifies it is not visible here.
  static TupleTypePtr createNamed(const std::optional<c10::QualifiedName>& name,
      const std::vector<std::string>& field_names,
      const std::vector<TypePtr>& field_types,
      std::vector<IValue>& field_defaults);

  // Named tuple type without field defaults.
  static TupleTypePtr createNamed(const std::optional<c10::QualifiedName>& name,
      const std::vector<std::string>& field_names,
      const std::vector<TypePtr>& field_types);

  // Same as above, taking field names as string views.
  static TupleTypePtr createNamed(const std::optional<c10::QualifiedName>& name,
      const std::vector<c10::string_view>& field_names,
      const std::vector<TypePtr>& field_types);

  // Unnamed tuple with the given element types (no name, no schema).
  static TupleTypePtr create(
      std::vector<TypePtr> types) {
    return TupleTypePtr(new TupleType(
        std::move(types),
        std::nullopt,
        nullptr)); // NOLINT(modernize-make-shared)
  }
  // The empty unnamed tuple type.
  static TupleTypePtr create() {
    return create({});
  }

  // Element types in positional order.
  at::ArrayRef<TypePtr> elements() const {
    return elements_;
  }

  bool equals(const Type& rhs) const override;
  bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;

  std::string str() const override;
  // Cached in has_free_variables_ (set by the out-of-line constructor).
  bool hasFreeVariables() const override {
    return has_free_variables_;
  }
  // Same view as elements().
  at::ArrayRef<TypePtr> containedTypes() const override {
    return elements_;
  }
  // New tuple with replaced element types; keeps this tuple's name and schema.
  TypePtr createWithContained(
      std::vector<TypePtr> contained_types) const override {
    return std::shared_ptr<TupleType>(
        new TupleType(std::move(contained_types), name(), schema()));
  }
  // May be null (create() passes nullptr).
  const std::shared_ptr<FunctionSchema>& schema() const {
    return schema_;
  }
  // Field names of a named tuple; presumably nullopt for unnamed tuples —
  // TODO confirm against the definition.
  std::optional<std::vector<c10::string_view>> names() const;

  static const TypeKind Kind = TypeKind::TupleType;

 private:
  template <typename S>
  static TupleTypePtr createWithSpec(
      const std::optional<c10::QualifiedName>& name,
      const std::vector<S>& field_names,
      const std::vector<TypePtr>& field_types,
      std::vector<IValue>& field_defaults);

  TupleType(
      std::vector<TypePtr> elements_,
      std::optional<c10::QualifiedName> name,
      std::shared_ptr<FunctionSchema> schema);

  // Element-wise comparison against another tuple using `fn` on each pair
  // of element types. Returns false if rhs is not a tuple or the arities
  // differ. Names and schemas are not compared here.
  bool compare(
      const Type& rhs,
      const std::function<bool(const Type&, const Type&)>& fn) const {
    if (rhs.kind() != kind()) {
      return false;
    }

    // Safe: the kind check above guarantees rhs is a TupleType.
    const auto& l_elements = elements();
    const auto& r_elements = rhs.castRaw<TupleType>()->elements();
    if (l_elements.size() != r_elements.size())
      return false;
    for (size_t i = 0; i < l_elements.size(); ++i) {
      if (!fn(*l_elements[i], *r_elements[i]))
        return false;
    }
    return true;
  }

  std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override;

  std::vector<TypePtr> elements_;
  bool has_free_variables_;
  std::shared_ptr<FunctionSchema> schema_;
};
1234
1235 // the common supertype of all Enums, only used in operator registraion.
1236 // EnumType <: AnyEnumType for all Enums
1237 struct AnyEnumType;
1238 using AnyEnumTypePtr = SingletonTypePtr<AnyEnumType>;
1239 struct TORCH_API AnyEnumType final : public Type {
equalsfinal1240 bool equals(const Type& rhs) const override {
1241 return rhs.kind() == kind();
1242 }
strfinal1243 std::string str() const override {
1244 return "AnyEnumType";
1245 }
1246 static const TypeKind Kind = TypeKind::AnyEnumType;
1247 // global singleton
1248 static AnyEnumTypePtr get();
1249 private:
AnyEnumTypefinal1250 AnyEnumType()
1251 : Type(TypeKind::AnyEnumType) {}
1252 };
1253
1254 struct NumberType;
1255 using NumberTypePtr = SingletonTypePtr<NumberType>;
1256 // This type represents a Python number
1257 // Subtype hierarchy for Number Types (NumberType as the base type):
1258 // IntType <: NumberType
1259 // FloatType <: NumberType
1260 // ComplexType <:NumberType
1261 //
1262 // WARNING: if you add a new subtype of NumberType that is not
1263 // represented by a global singleton, you need to change NumberTypePtr
1264 // to a SingletonOrSharedTypePtr and deal with NumberType needing to
1265 // both inherit and not inherit from SharedType!
1266 struct TORCH_API NumberType : public Type {
1267 bool equals(const Type& rhs) const override;
1268
1269 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
1270
strNumberType1271 std::string str() const override {
1272 return "Scalar"; // match what PythonArgParser says for clarity
1273 }
1274 static const TypeKind Kind = TypeKind::NumberType;
1275 // global singleton
1276 static NumberTypePtr get();
1277
1278 protected:
TypeNumberType1279 NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {}
1280
1281 std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1282 return "number"; // technically not a valid python type, but
1283 // we need to use it when parsing back in annotations
1284 // for implicit conversions
1285 }
1286 };
1287
1288 struct FloatType;
1289 using FloatTypePtr = SingletonTypePtr<FloatType>;
1290 // This type represents a Python float number
1291 struct TORCH_API FloatType : public NumberType {
equalsFloatType1292 bool equals(const Type& rhs) const override {
1293 return rhs.kind() == kind();
1294 }
strFloatType1295 std::string str() const override {
1296 return "float";
1297 }
isSubtypeOfExtFloatType1298 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1299 // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1300 return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1301 }
1302 static const TypeKind Kind = TypeKind::FloatType;
1303 // global singleton
1304 static FloatTypePtr get();
1305
1306 private:
FloatTypeFloatType1307 FloatType() : NumberType(TypeKind::FloatType) {}
1308 std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1309 return "float";
1310 }
1311 };
1312
1313 struct ComplexType;
1314 using ComplexTypePtr = SingletonTypePtr<ComplexType>;
1315 // This type represents a Python float number
1316 struct TORCH_API ComplexType : public NumberType {
equalsComplexType1317 bool equals(const Type& rhs) const override {
1318 return rhs.kind() == kind();
1319 }
strComplexType1320 std::string str() const override {
1321 return "complex";
1322 }
isSubtypeOfExtComplexType1323 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1324 // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1325 return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1326 }
1327 static const TypeKind Kind = TypeKind::ComplexType;
1328 // global singleton
1329 static ComplexTypePtr get();
1330
1331 private:
ComplexTypeComplexType1332 ComplexType() : NumberType(TypeKind::ComplexType) {}
1333 std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1334 return "complex";
1335 }
1336 };
1337
1338 // We need to introduce `SymIntType` to represent the `SymInt` type
1339 // used in function schemas e.g. `aten::narrow_copy(... SymInt length)
1340 // `SymInt` will be used to enable tracing arithmetic operations on
1341 // dimension values. Please see [SymInt.h] for more information
1342 struct SymIntType;
1343 using SymIntTypePtr = SingletonTypePtr<SymIntType>;
1344 struct TORCH_API SymIntType : public Type {
equalsSymIntType1345 bool equals(const Type& rhs) const override {
1346 return rhs.kind() == kind();
1347 }
strSymIntType1348 std::string str() const override {
1349 return "SymInt";
1350 }
1351 std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override {
1352 return "int";
1353 }
1354 static const TypeKind Kind = TypeKind::SymIntType;
1355 // global singleton
1356 static SymIntTypePtr get();
1357
1358 private:
SymIntTypeSymIntType1359 SymIntType() : Type(TypeKind::SymIntType) {}
1360 };
1361
1362 struct SymFloatType;
1363 using SymFloatTypePtr = SingletonTypePtr<SymFloatType>;
1364 struct TORCH_API SymFloatType : public Type {
equalsSymFloatType1365 bool equals(const Type& rhs) const override {
1366 return rhs.kind() == kind();
1367 }
strSymFloatType1368 std::string str() const override {
1369 return "SymFloat";
1370 }
1371 std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override {
1372 return "float";
1373 }
1374 static const TypeKind Kind = TypeKind::SymFloatType;
1375 // global singleton
1376 static SymFloatTypePtr get();
1377
1378 private:
SymFloatTypeSymFloatType1379 SymFloatType() : Type(TypeKind::SymFloatType) {}
1380 };
1381
1382 struct SymBoolType;
1383 using SymBoolTypePtr = SingletonTypePtr<SymBoolType>;
1384 struct TORCH_API SymBoolType : public Type {
equalsSymBoolType1385 bool equals(const Type& rhs) const override {
1386 return rhs.kind() == kind();
1387 }
strSymBoolType1388 std::string str() const override {
1389 return "SymBool";
1390 }
1391 std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override {
1392 return "bool";
1393 }
1394 static const TypeKind Kind = TypeKind::SymBoolType;
1395 // global singleton
1396 static SymBoolTypePtr get();
1397
1398 private:
SymBoolTypeSymBoolType1399 SymBoolType() : Type(TypeKind::SymBoolType) {}
1400 };
1401
1402 struct IntType;
1403 using IntTypePtr = SingletonTypePtr<IntType>;
1404 // This type represents a Python int number
1405 struct TORCH_API IntType : public NumberType {
equalsIntType1406 bool equals(const Type& rhs) const override {
1407 return rhs.kind() == kind();
1408 }
strIntType1409 std::string str() const override {
1410 return "int";
1411 }
isSubtypeOfExtIntType1412 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1413 // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1414 return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1415 }
1416 static const TypeKind Kind = TypeKind::IntType;
1417 // global singleton
1418 static IntTypePtr get();
1419
1420 private:
IntTypeIntType1421 IntType() : NumberType(TypeKind::IntType) {}
1422 std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1423 return "int";
1424 }
1425 };
1426
1427 struct BoolType;
1428 using BoolTypePtr = SingletonTypePtr<BoolType>;
1429 // This node represents a Python bool value
1430 struct TORCH_API BoolType : public Type {
equalsBoolType1431 bool equals(const Type& rhs) const override {
1432 return rhs.kind() == kind();
1433 }
strBoolType1434 std::string str() const override {
1435 return "bool";
1436 }
1437 static const TypeKind Kind = TypeKind::BoolType;
1438 // global singleton
1439 static BoolTypePtr get();
1440
1441 private:
BoolTypeBoolType1442 BoolType() : Type(TypeKind::BoolType) {}
1443 };
1444
1445 struct StringType;
1446 using StringTypePtr = SingletonTypePtr<StringType>;
1447 // This type represents a Python string
1448 struct TORCH_API StringType : public Type {
equalsStringType1449 bool equals(const Type& rhs) const override {
1450 return rhs.kind() == kind();
1451 }
strStringType1452 std::string str() const override {
1453 // we only use "str" (not "string") in both FunctionSchema and script
1454 return annotation_str();
1455 }
1456 std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1457 return "str";
1458 }
1459 static const TypeKind Kind = TypeKind::StringType;
1460 // global singleton
1461 static StringTypePtr get();
1462
1463 private:
StringTypeStringType1464 StringType() : Type(TypeKind::StringType) {}
1465 };
1466
1467 struct StorageType;
1468 using StorageTypePtr = SingletonTypePtr<StorageType>;
1469 struct TORCH_API StorageType : public Type {
equalsStorageType1470 bool equals(const Type& rhs) const override {
1471 return rhs.kind() == kind();
1472 }
strStorageType1473 std::string str() const override {
1474 return annotation_str();
1475 }
1476 std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1477 return "Storage";
1478 }
1479 static const TypeKind Kind = TypeKind::StorageType;
1480 // global singleton
1481 static StorageTypePtr get();
1482
1483 private:
StorageTypeStorageType1484 StorageType() : Type(TypeKind::StorageType) {}
1485 };
1486
1487 struct FunctionType;
1488 using FunctionTypePtr = std::shared_ptr<FunctionType>;
1489 struct TORCH_API FunctionType : public NamedType {
createFunctionType1490 static FunctionTypePtr create(torch::jit::Function* function) {
1491 return FunctionTypePtr(
1492 new FunctionType(function)); // NOLINT(modernize-make-shared)
1493 }
equalsFunctionType1494 bool equals(const Type& rhs) const override {
1495 if (auto func_type = rhs.cast<FunctionType>()) {
1496 return func_type->function_ == function_;
1497 }
1498
1499 return false;
1500 }
strFunctionType1501 std::string str() const override {
1502 return "Function";
1503 }
functionFunctionType1504 torch::jit::Function* function() const {
1505 return function_;
1506 }
1507 static const TypeKind Kind = TypeKind::FunctionType;
1508
1509 private:
1510 FunctionType(torch::jit::Function* function);
1511 std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1512 const auto& n = name().value();
1513 return n.qualifiedName();
1514 }
1515 torch::jit::Function* function_;
1516 };
1517
1518 struct NoneType;
1519 using NoneTypePtr = SingletonTypePtr<NoneType>;
1520 // This type represents a Python None
1521 struct TORCH_API NoneType : public Type {
equalsNoneType1522 bool equals(const Type& rhs) const override {
1523 return rhs.kind() == kind();
1524 }
strNoneType1525 std::string str() const override {
1526 return "NoneType";
1527 }
1528 bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override;
1529
1530 static const TypeKind Kind = TypeKind::NoneType;
1531 // global singleton
1532 static NoneTypePtr get();
1533
1534 private:
NoneTypeNoneType1535 NoneType() : Type(TypeKind::NoneType) {}
1536 };
1537
1538 struct GeneratorType;
1539 using GeneratorTypePtr = SingletonTypePtr<GeneratorType>;
1540 // This type represents a Generator
1541 struct TORCH_API GeneratorType : public Type {
equalsGeneratorType1542 bool equals(const Type& rhs) const override {
1543 return rhs.kind() == kind();
1544 }
strGeneratorType1545 std::string str() const override {
1546 return "Generator";
1547 }
1548 static const TypeKind Kind = TypeKind::GeneratorType;
1549 // global singleton
1550 static GeneratorTypePtr get();
1551
1552 private:
GeneratorTypeGeneratorType1553 GeneratorType() : Type(TypeKind::GeneratorType) {}
1554 };
1555
1556 struct QuantizerType;
1557 using QuantizerTypePtr = SingletonTypePtr<QuantizerType>;
1558 // This type represents a Quantizer
1559 struct TORCH_API QuantizerType : public Type {
equalsQuantizerType1560 bool equals(const Type& rhs) const override {
1561 return rhs.kind() == kind();
1562 }
strQuantizerType1563 std::string str() const override {
1564 return "Quantizer";
1565 }
1566 static const TypeKind Kind = TypeKind::QuantizerType;
1567 // global singleton
1568 static QuantizerTypePtr get();
1569
1570 private:
QuantizerTypeQuantizerType1571 QuantizerType() : Type(TypeKind::QuantizerType) {}
1572 };
1573
1574 struct QSchemeType;
1575 using QSchemeTypePtr = SingletonTypePtr<QSchemeType>;
1576 // This type represents a QScheme
1577 struct TORCH_API QSchemeType : public Type {
equalsQSchemeType1578 bool equals(const Type& rhs) const override {
1579 return rhs.kind() == kind();
1580 }
strQSchemeType1581 std::string str() const override {
1582 return "QScheme";
1583 }
1584 static const TypeKind Kind = TypeKind::QSchemeType;
1585 // global singleton
1586 static QSchemeTypePtr get();
1587
1588 private:
QSchemeTypeQSchemeType1589 QSchemeType() : Type(TypeKind::QSchemeType) {}
1590 };
1591
1592 struct DeviceObjType;
1593 using DeviceObjTypePtr = SingletonTypePtr<DeviceObjType>;
1594 // This type represents a Device
1595 struct TORCH_API DeviceObjType : public Type {
equalsDeviceObjType1596 bool equals(const Type& rhs) const override {
1597 return rhs.kind() == kind();
1598 }
strDeviceObjType1599 std::string str() const override {
1600 return "Device";
1601 }
1602 static const TypeKind Kind = TypeKind::DeviceObjType;
1603 // global singleton
1604 static DeviceObjTypePtr get();
1605
1606 private:
DeviceObjTypeDeviceObjType1607 DeviceObjType() : Type(TypeKind::DeviceObjType) {}
1608 };
1609
1610 struct StreamObjType;
1611 using StreamObjTypePtr = SingletonTypePtr<StreamObjType>;
1612 // This type represents a Generator
1613 struct TORCH_API StreamObjType : public Type {
equalsStreamObjType1614 bool equals(const Type& rhs) const override {
1615 return rhs.kind() == kind();
1616 }
strStreamObjType1617 std::string str() const override {
1618 return "Stream";
1619 }
1620 static const TypeKind Kind = TypeKind::StreamObjType;
1621 // global singleton
1622 static StreamObjTypePtr get();
1623
1624 private:
StreamObjTypeStreamObjType1625 StreamObjType() : Type(TypeKind::StreamObjType) {}
1626 };
1627
1628 struct VarType;
1629 using VarTypePtr = std::shared_ptr<VarType>;
1630 // This type represents a type variable, used in FunctionSchema
1631 struct VarType : public SharedType {
createVarType1632 static VarTypePtr create(std::string name_) {
1633 return VarTypePtr(new VarType(std::move(name_)));
1634 }
equalsVarType1635 bool equals(const Type& rhs) const override {
1636 return rhs.kind() == kind();
1637 }
strVarType1638 std::string str() const override {
1639 return name();
1640 }
nameVarType1641 const std::string& name() const {
1642 return name_;
1643 }
hasFreeVariablesVarType1644 bool hasFreeVariables() const override {
1645 return true;
1646 }
1647 static const TypeKind Kind = TypeKind::VarType;
1648
1649 private:
VarTypeVarType1650 VarType(std::string name_)
1651 : SharedType(TypeKind::VarType), name_(std::move(name_)) {}
1652 std::string name_;
1653 };
1654
1655 struct CapsuleType;
1656 using CapsuleTypePtr = SingletonTypePtr<CapsuleType>;
1657 // This type represents a Python Capsule.
1658 // It does not appear in the IR and is only used during runtime
1659 struct TORCH_API CapsuleType : public Type {
equalsCapsuleType1660 bool equals(const Type& rhs) const override {
1661 return rhs.kind() == kind();
1662 }
strCapsuleType1663 std::string str() const override {
1664 return "Capsule";
1665 }
1666 static const TypeKind Kind = TypeKind::CapsuleType;
1667 // global singleton
1668 static CapsuleTypePtr get();
1669 private:
CapsuleTypeCapsuleType1670 CapsuleType()
1671 : Type(TypeKind::CapsuleType) {}
1672 };
1673
1674 struct PyObjectType;
1675 using PyObjectTypePtr = SingletonTypePtr<PyObjectType>;
1676 // This type represents a PyObject Type
1677 struct TORCH_API PyObjectType : public Type {
equalsPyObjectType1678 bool equals(const Type& rhs) const override {
1679 return rhs.kind() == kind();
1680 }
strPyObjectType1681 std::string str() const override {
1682 return "PyObject";
1683 }
1684 static const TypeKind Kind = TypeKind::PyObjectType;
1685 // global singleton
1686 static PyObjectTypePtr get();
1687 private:
PyObjectTypePyObjectType1688 PyObjectType()
1689 : Type(TypeKind::PyObjectType) {}
1690 };
1691
// Verbosity levels for printing types. The enumerator values are written out
// explicitly; they equal the implicit numbering of the declaration order.
// `Default` is an alias for `Full`.
enum class TypeVerbosity {
  None = 0,
  Type = 1,
  TypeAndStride = 2,
  Full = 3,
  Symbolic = 4,
  Default = Full,
};
1700
// Presumably the TypeVerbosity currently in effect for printing types —
// TODO confirm against the definition.
TORCH_API TypeVerbosity type_verbosity();

// Stream printers for Type and for the shape/stride helper types.
TORCH_API std::ostream& operator<<(std::ostream& out, const Type& t);
template <typename T>
TORCH_API std::ostream& operator<<(
    std::ostream& out,
    const VaryingShape<T>& t);
TORCH_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s);
TORCH_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s);
TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
1711 // what is the type, ignoring extra size/shape information?
1712 // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
1713
1714 // `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
1715 // subtypes as simply "Tensor"; we also create a new version of any
1716 // container types in which internal Tensors have undergone the same
1717 // operation. This is used for type comparisons between two Tensor types
1718 // (`unshapedType` means that we don't falsely return `false` for e.g.
1719 // Tensors of different dimensions). It's also used in the alias
1720 // analysis pass.
1721 // Be careful with calls because this can be very slow. If calling this
1722 // on a graph, use `EraseShapeInformation` in shape_analysis.h
unshapedType(const TypePtr & type)1723 inline TypePtr unshapedType(const TypePtr& type) {
1724 if (type->isSubtypeOf(*TensorType::get())) {
1725 return TensorType::get();
1726 }
1727 at::ArrayRef<TypePtr> contained = type->containedTypes();
1728 if (contained.empty()) {
1729 return type;
1730 }
1731 return type->withContained(fmap(type->containedTypes(), unshapedType));
1732 }
1733
fromNumberType(const Type & typ)1734 inline TypePtr TensorType::fromNumberType(const Type& typ) {
1735 if (typ.isSubtypeOf(*IntType::get())) {
1736 return TensorType::createContiguous(at::kLong, at::kCPU, {});
1737 } else if (typ.isSubtypeOf(*FloatType::get())) {
1738 return TensorType::createContiguous(at::kDouble, at::kCPU, {});
1739 } else if (typ.isSubtypeOf(*BoolType::get())) {
1740 return TensorType::createContiguous(at::kBool, at::kCPU, {});
1741 } else if (typ.kind() == NumberType::Kind) {
1742 return TensorType::create(std::nullopt, at::kCPU, {}, std::nullopt);
1743 }
1744 TORCH_CHECK(false, "Unknown number type: ", typ.str());
1745 }
fromBoolType()1746 inline TypePtr TensorType::fromBoolType() {
1747 return TensorType::createContiguous(at::kBool, at::kCPU, {});
1748 }
1749
tryScalarTypeFromJitType(const Type & type)1750 inline std::optional<c10::ScalarType> tryScalarTypeFromJitType(const Type& type) {
1751 if (type == *FloatType::get()) {
1752 return at::typeMetaToScalarType(c10::get_default_dtype());
1753 } else if (type == *IntType::get()) {
1754 return at::ScalarType::Long;
1755 } else if (type == *BoolType::get()) {
1756 return at::ScalarType::Bool;
1757 }
1758 return std::nullopt;
1759 }
1760
scalarTypeFromJitType(const Type & type)1761 inline at::ScalarType scalarTypeFromJitType(const Type& type) {
1762 auto result = tryScalarTypeFromJitType(type);
1763 TORCH_CHECK(
1764 result,
1765 "Add new condition, expected Float, Complex, Int, or Bool but got",
1766 type.str());
1767 return *result;
1768 }
1769
// Attempt to find the correct supertype of the two types `t1` and `t2`.
// If no supertype is found, then nullopt will be returned if
// `default_to_union` is false, and `Union[t1, t2]` will be returned
// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
// then `t2` will be returned (and vice versa).
//
// Two different tensortypes will return dynamic.
//
// Currently we chose not to support returning a NumberType for
// two types from the set of {FloatType, IntType, ComplexType}, because
// there is a lack of operator support for NumberType.
//
// If `type_hint` is an `InterfaceType`, then we can use that as a
// potential supertype for `ClassType`s in the list. Otherwise, we have
// no way to find and use some common interface type
TORCH_API std::optional<TypePtr> unifyTypes(
    const TypePtr& t1,
    const TypePtr& t2,
    bool default_to_union = false,
    const TypePtr& type_hint = nullptr);

// List form of unifyTypes: tries to find one supertype for all `elements`.
// `default_to_union` and `type_hint` presumably mean the same as for
// unifyTypes. `why_not` presumably receives an explanation when unification
// fails — TODO confirm against the definition.
TORCH_API std::optional<TypePtr> unifyTypeList(
    at::ArrayRef<TypePtr> elements,
    std::ostream& why_not,
    bool default_to_union = false,
    const TypePtr& type_hint = nullptr);
1796
1797 namespace detail {
// getTypePtr_<T>::call() returns the JIT type for C++ type T.
template <typename T>
struct getTypePtr_ final {
  // Fallback for types without a dedicated specialization: looks T up via
  // getCustomClassType<T>(). If that throws c10::Error, the error is replaced
  // by a TORCH_CHECK failure naming T's fully qualified type name.
  // The immediately-invoked lambda has a deduced (decayed, by-value) return
  // type, so call() returns the type pointer by value, not by reference.
  static decltype(auto) call() {
    return ([]() {
      try {
        return getCustomClassType<T>();
      } catch(const c10::Error&) {
        TORCH_CHECK(
          false,
          "Type ",
          c10::util::get_fully_qualified_type_name<T>(),
          " could not be converted to any of the known types."
        );
      }
    }());
  }
};
1815
// Like getTypePtr_<T>, but specializations (below) map the symbolic types
// SymInt/SymFloat/SymBool to int/float/bool when `fake` is true.
template <typename T, bool fake>
struct getMaybeFakeTypePtr_ final {
  // Default: `fake` makes no difference; forward getTypePtr_<T>'s result with
  // its exact return type (decltype(auto)).
  static decltype(auto) call() {
    return getTypePtr_<T>::call();
  }
};
1822
// Fixed C++ -> JIT type mappings. Each call() uses decltype(auto), so it
// returns exactly what the corresponding get() returns (value or reference).

// IValue can hold anything -> Any.
template <>
struct getTypePtr_<at::IValue> final {
  static decltype(auto) call() {
    return AnyType::get();
  }
};

template <>
struct getTypePtr_<at::Tensor> final {
  static decltype(auto) call() {
    return TensorType::get();
  }
};
template <>
struct getTypePtr_<c10::Storage> final {
  static decltype(auto) call() {
    return StorageType::get();
  }
};
template <>
struct getTypePtr_<c10::Stream> final {
  static decltype(auto) call() {
    return StreamObjType::get();
  }
};
// C++ double is the JIT `float` type.
template <>
struct getTypePtr_<double> final {
  static decltype(auto) call() {
    return FloatType::get();
  }
};
template <>
struct getTypePtr_<c10::complex<double>> final {
  static decltype(auto) call() {
    return ComplexType::get();
  }
};
// C++ int64_t is the JIT `int` type.
template <>
struct getTypePtr_<int64_t> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};

// Device indices are also exposed as `int`.
template <>
struct getTypePtr_<DeviceIndex> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
1873
// Symbolic scalars: with fake == false they keep their symbolic JIT type;
// with fake == true they map to the plain int/float/bool type.
template <>
struct getMaybeFakeTypePtr_<SymInt, false> final {
  static decltype(auto) call() {
    return SymIntType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<SymInt, true> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};

template <>
struct getMaybeFakeTypePtr_<SymFloat, false> final {
  static decltype(auto) call() {
    return SymFloatType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<SymFloat, true> final {
  static decltype(auto) call() {
    return FloatType::get();
  }
};

template <>
struct getMaybeFakeTypePtr_<SymBool, false> final {
  static decltype(auto) call() {
    return SymBoolType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<SymBool, true> final {
  static decltype(auto) call() {
    return BoolType::get();
  }
};
1912
template <>
struct getTypePtr_<c10::Device> final {
  static decltype(auto) call() {
    return DeviceObjType::get();
  }
};
template <>
struct getTypePtr_<bool> final {
  static decltype(auto) call() {
    return BoolType::get();
  }
};
// at::Scalar is the generic `Scalar`/number type.
template <>
struct getTypePtr_<at::Scalar> final {
  static decltype(auto) call() {
    return NumberType::get();
  }
};
template <>
struct getTypePtr_<c10::QScheme> final {
  static decltype(auto) call() {
    return QSchemeType::get();
  }
};
// at::Generator maps to Optional[Generator], not plain Generator. A new
// OptionalType is created through TypeFactory on every call (not cached).
template <>
struct getTypePtr_<at::Generator> final {
  static decltype(auto) call() {
    return TypeFactory::create<OptionalType>(
        TypeFactory::get<GeneratorType>());
  }
};
// Owning strings, string views and Dimnames all map to `str`.
template <>
struct getTypePtr_<std::string> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};
template <>
struct getTypePtr_<c10::string_view> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};
template <>
struct getTypePtr_<at::Dimname> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};
1962 template <class T, bool fake>
1963 struct getMaybeFakeTypePtr_<std::vector<T>, fake> final {
1964 static const auto& call() {
1965 static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
1966 // The "per vector<T>" static singleton needs to live in a .cpp file,
1967 // otherwise we'll end up with one singleton instance per shared library.
1968 static auto type = ListType::get("vector", inner_type);
1969 return type;
1970 }
1971 };
1972 template <class T, bool fake>
1973 struct getMaybeFakeTypePtr_<c10::ArrayRef<T>, fake> final {
1974 static const auto& call() {
1975 static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
1976 // The "per ArrayRef<T>" static singleton needs to live in a .cpp file,
1977 // otherwise we'll end up with one singleton instance per shared library.
1978 static auto type = ListType::get("ArrayRef", inner_type);
1979 return type;
1980 }
1981 };
1982 template <bool fake>
1983 struct getMaybeFakeTypePtr_<c10::SymIntArrayRef, fake> final {
1984 static const auto& call() {
1985 static auto type = ListType::create(getMaybeFakeTypePtr_<c10::SymInt, fake>::call());
1986 return type;
1987 }
1988 };
1989 template <class T, bool fake>
1990 struct getMaybeFakeTypePtr_<c10::List<T>, fake> final {
1991 static const auto& call() {
1992 static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
1993 // The "per List<T>" static singleton needs to live in a .cpp file,
1994 // otherwise we'll end up with one singleton instance per shared library.
1995 static auto type = ListType::get("List", inner_type);
1996 return type;
1997 }
1998 };
1999 template <class T, bool fake>
2000 struct getMaybeFakeTypePtr_<c10::IListRef<T>, fake> final {
2001 static const auto& call() {
2002 static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
2003 static auto type = ListType::get("List", inner_type);
2004 return type;
2005 }
2006 };
2007 template <class T, size_t N, bool fake>
2008 struct getMaybeFakeTypePtr_<std::array<T, N>, fake> final {
2009 static const auto& call() {
2010 static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
2011 // The "per array<T, N>" static singleton needs to live in a .cpp file,
2012 // otherwise we'll end up with one singleton instance per shared library.
2013 // (Concatenating the length onto the end of the string because we want a unique
2014 // type_ptr created for every std::array<T, N> type).
2015 static auto type = ListType::get(std::string("array") + std::to_string(N), inner_type);
2016 return type;
2017 }
2018 };
2019 template <class K, class V, bool fake>
2020 struct getMaybeFakeTypePtr_<std::unordered_map<K, V>, fake> final {
2021 static const auto& call() {
2022 static auto inner_key_type = getMaybeFakeTypePtr_<K, fake>::call();
2023 static auto inner_val_type = getMaybeFakeTypePtr_<V, fake>::call();
2024 // The "per unordered_map<K, V>" static singleton needs to live in a .cpp file,
2025 // otherwise we'll end up with one singleton instance per shared library.
2026 static auto type = DictType::get("unordered_map", inner_key_type, inner_val_type);
2027 return type;
2028 }
2029 };
2030 template <class K, class V, bool fake>
2031 struct getMaybeFakeTypePtr_<c10::Dict<K, V>, fake> final {
2032 static const auto& call() {
2033 static auto inner_key_type = getMaybeFakeTypePtr_<K, fake>::call();
2034 static auto inner_val_type = getMaybeFakeTypePtr_<V, fake>::call();
2035 // The "per Dict<K, V>" static singleton needs to live in a .cpp file,
2036 // otherwise we'll end up with one singleton instance per shared library.
2037 static auto type = DictType::get("Dict", inner_key_type, inner_val_type);
2038 return type;
2039 }
2040 };
2041
2042 template <class T, bool fake>
2043 struct getMaybeFakeTypePtr_<std::optional<T>, fake> final {
2044 static const auto& call() {
2045 static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
2046 // The "per std::optional<T>" static singleton needs to live in a .cpp file,
2047 // otherwise we'll end up with one singleton instance per shared library.
2048 static auto type = OptionalType::get(inner_type);
2049 return type;
2050 }
2051 };
2052
2053
2054 template<>
2055 struct getTypePtr_<at::OptionalIntArrayRef> final {
2056 static const auto& call() {
2057 static auto inner_type = getMaybeFakeTypePtr_<IntArrayRef, false>::call();
2058 // The "per std::optional<T>" static singleton needs to live in a .cpp file,
2059 // otherwise we'll end up with one singleton instance per shared library.
2060 static auto type = OptionalType::get(inner_type);
2061 return type;
2062 }
2063 };
2064
2065 template <bool fake>
2066 struct getMaybeFakeTypePtr_<at::OptionalSymIntArrayRef, fake> final {
2067 static const auto& call() {
2068 // The "per std::optional<T>" static singleton needs to live in a .cpp file,
2069 // otherwise we'll end up with one singleton instance per shared library.
2070 static auto inner_type = getMaybeFakeTypePtr_<SymIntArrayRef, fake>::call();
2071 static auto type = OptionalType::get(inner_type);
2072 return type;
2073 }
2074 };
2075
2076 template <class... Contained, bool fake>
2077 struct getMaybeFakeTypePtr_<std::tuple<Contained...>, fake> final {
2078 static const auto& call() {
2079 static auto type = ([]() {
2080 std::vector<TypePtr> contained_types = {
2081 (getMaybeFakeTypePtr_<Contained, fake>::call())...
2082 };
2083 return TupleType::create(std::move(contained_types));
2084 })();
2085 return type;
2086 }
2087 };
2088 template <>
2089 struct getTypePtr_<void> final {
2090 static decltype(auto) call() {
2091 return NoneType::get();
2092 }
2093 };
2094 } // namespace detail
2095 template <class T>
2096 inline decltype(auto) getTypePtr() {
2097 // TODO: static_assert that a templated function exists, and throw a friendly
2098 // error message if not
2099 return detail::getMaybeFakeTypePtr_<T, false>::call();
2100 }
2101
2102 template <class T>
2103 inline TypePtr getTypePtrCopy() {
2104 // TODO: static_assert that a templated function exists, and throw a friendly
2105 // error message if not
2106 return getTypePtr<T>();
2107 }
2108
2109 template <class T>
2110 inline decltype(auto) getFakeTypePtr() {
2111 return detail::getMaybeFakeTypePtr_<T, true>::call();
2112 }
2113
2114 template <class T>
2115 inline TypePtr getFakeTypePtrCopy() {
2116 return getFakeTypePtr<T>();
2117 }
2118
// Bindings from a type-variable name to the type bound to it.
// matchTypeVariables adds entries; tryEvalTypeVariables reads them.
using TypeEnv = std::unordered_map<std::string, TypePtr>;
// Result of matchTypeVariables: either success, or a failure that carries a
// human-readable explanation.
struct MatchTypeReturn {
  // Failure result; implicitly constructible from the reason string so that
  // callers can simply return the explanation.
  /* implicit */ MatchTypeReturn(std::string reason)
      : reason_(std::move(reason)) {}

  // The successful (reason-less) result.
  static MatchTypeReturn Success() {
    return MatchTypeReturn();
  }

  bool success() const {
    return !reason_;
  }

  // Why matching failed. Only meaningful when success() is false; on a
  // successful result std::optional::value() throws std::bad_optional_access.
  const std::string& reason() const {
    return reason_.value();
  }

 private:
  MatchTypeReturn() = default;

  // Empty on success; if there is no match, holds the reason.
  std::optional<std::string> reason_;
};
2137
// Attempts to match the type variables in `formal` against `actual`, adding
// the resulting bindings to `type_env`.
// If no match is possible this returns a MatchTypeReturn with r.success() == false
// and a r.reason() that describes why it could not match.
// Note: it is possible to successfully match a formal while some of its type
// variables remain undefined. In particular, None matches Optional[T] but does
// not define the value of T.
TORCH_API MatchTypeReturn
matchTypeVariables(const TypePtr& formal, const TypePtr& actual, TypeEnv& type_env);

// Replaces the type variables appearing in `type` with their values from
// `type_env`. Returns nullptr if a variable used in `type` does not appear in
// `type_env`.
// NOTE(review): type_env is taken by non-const reference; whether it is
// mutated is not visible here. Confirm before relying on it being unchanged.
TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, TypeEnv& type_env);

// NOTE(review): defined out of line. Judging by the name, this reports whether
// a container's element type `elem_type` can be inferred from the container's
// members rather than requiring an explicit annotation. TODO confirm.
TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type);
2153
struct InterfaceType;
using InterfaceTypePtr = std::shared_ptr<InterfaceType>;

// Interfaces are a list of abstract methods that a class might meet.
// If a class provides those methods, it implicitly meets the interface
// (no explicit "implements" declaration is required).

// Subtype relations for Interface with ClassType:
// lhs (ClassType or InterfaceType) is a subtype of rhs if:
// 1. lhs methods are a superset of rhs methods
// 2. if rhs is module interface, the lhs must be module interface or module itself
struct TORCH_API InterfaceType : public NamedType {
  // Creates an interface type named `qualifiedName`. `is_module` marks it as a
  // module interface (reported by is_module()). Defined out of line.
  static InterfaceTypePtr create(
      QualifiedName qualifiedName, bool is_module=false);

  // Two interfaces are equal when each is a subtype of the other. A
  // non-interface rhs is never equal.
  bool equals(const Type& rhs) const override {
    if (auto user_rhs = rhs.castRaw<InterfaceType>()) {
      return isSubTypeImpl(*this, *user_rhs, nullptr) &&
          isSubTypeImpl(*user_rhs, *this, nullptr);
    }
    return false;
  }

  // Human-readable form "InterfaceType<...>", built from name()->name(). By
  // contrast, annotation_str_impl uses the fully qualified name.
  std::string str() const override {
    return std::string("InterfaceType<") + name()->name() + ">";
  }

  // Subtype check following the rules above; defined out of line. When
  // why_not is non-null it presumably receives the failure explanation.
  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;

  // try to find a method of this interface,
  // returns nullptr if not found.
  const FunctionSchema* getMethod(const std::string& name) const;
  // Presumably adds `schema` to this interface's methods (defined out of line).
  void addMethod(FunctionSchema schema);
  // Dereferences methods_ unconditionally. This assumes the constructor always
  // allocates it (TODO confirm).
  const std::vector<FunctionSchema>& methods() const {
    return *methods_;
  }

  // True if this interface was created as a module interface.
  bool is_module() const override{
    return is_module_;
  }
  static const TypeKind Kind = TypeKind::InterfaceType;
  // Defined out of line, presumably so that FunctionSchema (only
  // forward-declared in this header) is complete where methods_ is destroyed.
  ~InterfaceType() override;
 private:
  InterfaceType(QualifiedName name, bool is_module);
  // Returns whether lhs is a subtype of rhs; why_not may be nullptr (equals()
  // passes nullptr). Defined out of line.
  static bool isSubTypeImpl(
      const InterfaceType& lhs,
      const InterfaceType& rhs,
      std::ostream* why_not);

  // Annotations use the fully qualified name; the printer is ignored.
  std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
    return name()->qualifiedName();
  }

  // shared_ptr so that this header does not have to depend on
  // FunctionSchema.h
  std::shared_ptr<std::vector<FunctionSchema>> methods_;
  // flag to distinguish if it's an interface type from a module or not
  bool is_module_;
};
2212
2213 template <TypeKind K>
2214 struct EnumerationType : public Type {
2215 static const TypeKind Kind = K;
2216
2217 bool equals(const Type& rhs) const override {
2218 return rhs.kind() == kind();
2219 }
2220
2221 protected:
2222 EnumerationType() : Type(Kind) {}
2223 };
2224
2225 // WARNING: These enumeration types below DO NOT actually get parsed out
2226 // from the logical schema strings, instead they are mapped as ints. To
2227 // observe these types, use real_type() instead of type() on Argument
2228
2229 struct ScalarTypeType;
2230 using ScalarTypeTypePtr = SingletonTypePtr<ScalarTypeType>;
2231 struct TORCH_API ScalarTypeType : public EnumerationType<TypeKind::ScalarTypeType> {
2232 std::string str() const override {
2233 return "ScalarType";
2234 }
2235 static const TypeKind Kind = TypeKind::ScalarTypeType;
2236 // global singleton
2237 static ScalarTypeTypePtr get();
2238
2239 private:
2240 ScalarTypeType() : EnumerationType() {}
2241 };
2242
2243 struct MemoryFormatType;
2244 using MemoryFormatTypePtr = SingletonTypePtr<MemoryFormatType>;
2245 struct TORCH_API MemoryFormatType : public EnumerationType<TypeKind::MemoryFormatType> {
2246 std::string str() const override {
2247 return "MemoryFormat";
2248 }
2249 static const TypeKind Kind = TypeKind::MemoryFormatType;
2250 // global singleton
2251 static MemoryFormatTypePtr get();
2252
2253 private:
2254 MemoryFormatType() : EnumerationType() {}
2255 };
2256
2257 struct LayoutType;
2258 using LayoutTypePtr = SingletonTypePtr<LayoutType>;
2259 struct TORCH_API LayoutType : public EnumerationType<TypeKind::LayoutType> {
2260 std::string str() const override {
2261 return "Layout";
2262 }
2263 static const TypeKind Kind = TypeKind::LayoutType;
2264 // global singleton
2265 static LayoutTypePtr get();
2266
2267 private:
2268 LayoutType() : EnumerationType() {}
2269 };
2270
2271 namespace detail {
2272 template <>
2273 struct getMaybeFakeTypePtr_<c10::ScalarType, false> final {
2274 static decltype(auto) call() {
2275 return ScalarTypeType::get();
2276 }
2277 };
2278 template <>
2279 struct getMaybeFakeTypePtr_<c10::Layout, false> final {
2280 static decltype(auto) call() {
2281 return LayoutType::get();
2282 }
2283 };
2284 template <>
2285 struct getMaybeFakeTypePtr_<c10::MemoryFormat, false> final {
2286 static decltype(auto) call() {
2287 return MemoryFormatType::get();
2288 }
2289 };
2290 template <>
2291 struct getMaybeFakeTypePtr_<c10::ScalarType, true> final {
2292 static decltype(auto) call() {
2293 return IntType::get();
2294 }
2295 };
2296 template <>
2297 struct getMaybeFakeTypePtr_<c10::Layout, true> final {
2298 static decltype(auto) call() {
2299 return IntType::get();
2300 }
2301 };
2302 template <>
2303 struct getMaybeFakeTypePtr_<c10::MemoryFormat, true> final {
2304 static decltype(auto) call() {
2305 return IntType::get();
2306 }
2307 };
2308 } // namespace detail
2309
2310 // the common supertype of all lists,
2311 // List[T] <: AnyList for all T
2312 struct AnyListType;
2313 using AnyListTypePtr = SingletonTypePtr<AnyListType>;
2314 struct TORCH_API AnyListType : public Type {
2315 bool equals(const Type& rhs) const override {
2316 return rhs.kind() == kind();
2317 }
2318 std::string str() const override {
2319 return "list";
2320 }
2321 static const TypeKind Kind = TypeKind::AnyListType;
2322 // global singleton
2323 static AnyListTypePtr get();
2324 private:
2325 AnyListType()
2326 : Type(TypeKind::AnyListType) {}
2327 };
2328
2329 // the common supertype of all tuples,
2330 // Tuple[T...] <: AnyTuple for all T
2331 struct AnyTupleType;
2332 using AnyTupleTypePtr = SingletonTypePtr<AnyTupleType>;
2333 struct TORCH_API AnyTupleType : public Type {
2334 bool equals(const Type& rhs) const override {
2335 return rhs.kind() == kind();
2336 }
2337
2338 std::string str() const override {
2339 return "tuple";
2340 }
2341 static const TypeKind Kind = TypeKind::AnyTupleType;
2342
2343 // global singleton
2344 static AnyTupleTypePtr get();
2345 private:
2346 AnyTupleType()
2347 : Type(TypeKind::AnyTupleType) {}
2348 };
2349
2350 // the common supertype of all classes,
2351 // ClassType <: AnyClassType for all classes
2352 struct AnyClassType;
2353 using AnyClassTypePtr = SingletonTypePtr<AnyClassType>;
2354 struct TORCH_API AnyClassType : public Type {
2355 bool equals(const Type& rhs) const override {
2356 return rhs.kind() == kind();
2357 }
2358 std::string str() const override {
2359 return "AnyClassType";
2360 }
2361 static const TypeKind Kind = TypeKind::AnyClassType;
2362 // global singleton
2363 static AnyClassTypePtr get();
2364 private:
2365 AnyClassType()
2366 : Type(TypeKind::AnyClassType) {}
2367 };
2368
2369 template<>
2370 inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
2371 if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2372 kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2373 return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this());
2374 }
2375 return nullptr;
2376 }
2377
2378 template<>
2379 inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
2380 if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2381 kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2382 return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this());
2383 }
2384 return nullptr;
2385 }
2386
2387 template<>
2388 inline const NamedType* Type::castRaw<NamedType>() const {
2389 if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2390 kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2391 return static_cast<const NamedType*>(this);
2392 }
2393 return nullptr;
2394 }
2395
2396 // Used as a return type when inferring the IValue type of a Python object.
2397 struct InferredType {
2398 /* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {}
2399 /* implicit */ InferredType(std::string reason)
2400 : type_(nullptr), reason_(std::move(reason)) {}
2401 TypePtr type() const {
2402 TORCH_INTERNAL_ASSERT(
2403 type_,
2404 "Tried to get the type from an InferredType but the type is null. ",
2405 "Reason: ",
2406 reason_);
2407 return type_;
2408 }
2409 bool success() const {
2410 return type_ != nullptr;
2411 }
2412 const std::string& reason() const {
2413 TORCH_INTERNAL_ASSERT(!type_);
2414 return reason_;
2415 }
2416
2417 private:
2418 TypePtr type_;
2419 std::string reason_;
2420 };
2421
// Defined out of line. Presumably returns true if `type` is AnyType or (for
// container types) contains it transitively. TODO confirm.
TORCH_API bool containsAnyType(const TypePtr& type);
2423
2424 } // namespace c10
2425