xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/jit_type.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/custom_class.h>
4 #include <ATen/core/jit_type_base.h>
5 #include <ATen/core/TensorBody.h>
6 #include <ATen/core/functional.h>
7 #include <ATen/core/symbol.h>
8 #include <ATen/core/type_factory.h>
9 #include <ATen/core/qualified_name.h>
10 #include <c10/util/TypeList.h>
11 #include <optional>
12 #include <c10/core/SymFloat.h>
13 #include <c10/core/SymBool.h>
14 #include <c10/core/Device.h>
15 
16 #include <array>
17 #include <memory>
18 #include <ostream>
19 #include <sstream>
20 #include <utility>
21 
22 
// Forward declaration only: lets this header name torch::jit::Function
// without needing its definition.
namespace torch::jit {
struct Function;
} // namespace torch::jit
26 
27 
28 namespace c10 {
29 
30 template<class Key, class Value>
31 class Dict;
32 struct IValue;
33 struct FunctionSchema;
34 struct NamedType;
35 using OptNameList = std::optional<std::vector<std::string>>;
36 
37 void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
38 void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
39 
is_contiguous_strides(const IntArrayRef sizes,const IntArrayRef strides)40 inline bool is_contiguous_strides(
41     const IntArrayRef sizes,
42     const IntArrayRef strides) {
43   int n_dim = static_cast<int>(sizes.size());
44   if (n_dim == 0) {
45     return true;
46   }
47 
48   if (strides[n_dim - 1] != 1) {
49     return false;
50   }
51 
52   for (int i = n_dim - 2; i >= 0; i--) {
53     if (strides[i] != strides[i + 1] * sizes[i + 1]) {
54       return false;
55     }
56   }
57   return true;
58 }
59 
60 struct AnyType;
61 using AnyTypePtr = SingletonTypePtr<AnyType>;
62 // Any is the top of the type hierarchy, all other types are subtypes
63 // T <: Any, forall T
64 struct TORCH_API AnyType : public Type {
equalsAnyType65   bool equals(const Type& rhs) const override {
66     return rhs.kind() == kind();
67   }
strAnyType68   std::string str() const override {
69     return "Any";
70   }
71   static const TypeKind Kind = TypeKind::AnyType;
72   // global singleton
73   static AnyTypePtr get();
74 
75  private:
AnyTypeAnyType76   AnyType() : Type(TypeKind::AnyType) {}
77 };
78 
toString(const Type & type)79 inline std::string toString(const Type& type) {
80   return type.str();
81 }
82 
83 // Shim for compatibility with code that uses TypePtr.
toString(const TypePtr & typePtr)84 inline std::string toString(const TypePtr& typePtr) {
85   return toString(*typePtr);
86 }
87 
88 inline bool operator!=(const Type& lhs, const Type& rhs) {
89   return !(lhs == rhs);
90 }
91 
92 // common base for all types that have a single sub element
93 // e.g. Future[T], Optional[T], List[T]
94 template <TypeKind K, typename T>
95 struct SingleElementType : public SharedType {
96   static const TypeKind Kind = K;
97 
getElementTypeSingleElementType98   const TypePtr& getElementType() const {
99     return elem;
100   }
101 
hasFreeVariablesSingleElementType102   bool hasFreeVariables() const override {
103     return getElementType()->hasFreeVariables();
104   }
105 
containedTypesSingleElementType106   at::ArrayRef<TypePtr> containedTypes() const override {
107     return elem;
108   }
109 
equalsSingleElementType110   bool equals(const Type& rhs) const override {
111     if (auto rhs_ = rhs.cast<T>()) {
112       return *getElementType() == *rhs_->getElementType();
113     }
114     return false;
115   }
116 
117  protected:
SingleElementTypeSingleElementType118   SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
119     if (!this->elem) {
120       throw std::runtime_error(c10::str(
121             "Can not create ", typeKindToString(Kind), " with None type"));
122     }
123   }
124 
125  private:
126   TypePtr elem;
127 };
128 
129 struct UnionType;
130 using UnionTypePtr = std::shared_ptr<UnionType>;
131 struct TORCH_API UnionType : public SharedType {
132   friend struct Type;
133 
134   static const TypeKind Kind = TypeKind::UnionType;
135 
136   bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;
137 
138   std::string str() const override;
139 
140   static UnionTypePtr create(std::vector<TypePtr> reference);
141 
142   bool equals(const Type& rhs) const override;
143 
isUnionTypeUnionType144   bool isUnionType() const override {
145     return true;
146   }
147 
containedTypesUnionType148   at::ArrayRef<TypePtr> containedTypes() const override {
149     return types_;
150   }
151 
152   // For testing purposes only
getTypesUnionType153   at::ArrayRef<TypePtr> getTypes() const {
154     return types_;
155   }
156 
createWithContainedUnionType157   TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
158     return create(std::move(contained_types));
159   }
160 
161   bool canHoldType(const Type& type) const;
162 
hasFreeVariablesUnionType163   bool hasFreeVariables() const override {
164     return has_free_variables_;
165   }
166 
167   std::optional<TypePtr> toOptional() const;
168 
169   std::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;
170 
171  protected:
172     explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
173     std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override;
174     std::string unionStr(
175         const TypePrinter& printer = nullptr,
176         bool is_annotation_str = false) const;
177     // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
178     bool has_free_variables_;
179     // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
180     std::vector<TypePtr> types_;
181     // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
182     bool can_hold_none_;
183 
184 };
185 
186 struct OptionalType;
187 using OptionalTypePtr = std::shared_ptr<OptionalType>;
188 // This type represents an optional type. There is one `Optional` for
189 // each element type. `Optional[T]` can accept both `T` and
190 // `None`(`std::nullopt` in C++)
191 // Subtype hierarchy for Optional:
192 //     - Optional[T] <: Optional[R] iff T <: R
193 //     - T <: Optional[R] if T <: R
194 //     - None <: Optional[T] for all T
195 //     - Optional[T] == Union[T, None] for all T
196 struct TORCH_API OptionalType : public UnionType {
197   static OptionalTypePtr create(const TypePtr& contained);
198 
199   static const TypeKind Kind = TypeKind::OptionalType;
200 
201   friend struct Type;
202 
203   bool equals(const Type& rhs) const override;
204 
getElementTypeOptionalType205   const TypePtr& getElementType() const {
206     return contained_;
207   }
208 
containedTypesOptionalType209   at::ArrayRef<TypePtr> containedTypes() const override {
210     return contained_;
211   }
212 
strOptionalType213   std::string str() const override {
214     std::stringstream ss;
215     ss << getElementType()->str() << "?";
216     return ss.str();
217   }
218 
createWithContainedOptionalType219   TypePtr createWithContained(
220       std::vector<TypePtr> contained_types) const override {
221     AT_ASSERT(contained_types.size() == 1);
222     return create(contained_types[0]);
223   }
224 
225   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
226 
isUnionTypeOptionalType227   bool isUnionType() const override {
228     return true;
229   }
230 
231   // common cast Optional[Tensor] for undefined tensor type
232   static TypePtr ofTensor();
233   //
234   // global singleton
235   static TypePtr get(TypePtr inner);
236 
237  private:
238   explicit OptionalType(const TypePtr& contained);
239 
240   TypePtr contained_;
241 
242   std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
243     std::stringstream ss;
244     ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
245     return ss.str();
246   }
247 };
248 
// Merges two optional observations of the same property.
// Returns `a` when both sides hold a value and those values compare equal;
// returns an empty optional in every other case (either side missing, or the
// values disagree).
template <typename T>
inline std::optional<T> merge_primitive(
    const std::optional<T>& a,
    const std::optional<T>& b) {
  if (a.has_value() && b.has_value() && a.value() == b.value()) {
    return a;
  }
  return std::optional<T>{};
}
258 
259 // If we see `a + b + c`  and know that a, b, and c are the same size and have
260 // two dimensions (WxH), then we can generate a fused kernel for them. That
261 // fused kernel would likely have indexing math to handling both the W and H
262 // dimensions. However, if we knew the WxH dimensions were contiguous, we can
263 // pretend like we only have a single dimension, simplifying the indexing logic.
264 // This can be performed even if the dimensions are transposed,
265 // as long as a, b, and c are transposed in the same way.
266 // We'd like to have the compiler be able to do this dimensionality reduction,
267 // but simply knowing sizes is not enough.
268 // We can extend profiling to also record stride information.
269 // Rather than recording specific strides,
270 // we can simply order the strides from smallest to largest with
271 // `stride_indices` A contiguity marker on the smallest stride (c0) indicates
272 // the stride is precisely 1, otherwise a contiguity marker means that $stride_n
273 // = size_{n-1}*stride_{n-1}$
274 struct TORCH_API Stride {
275   Stride() = default;
StrideStride276   Stride(
277       const std::optional<size_t>& stride_index,
278       std::optional<bool> contiguous,
279       const std::optional<size_t>& stride)
280       : stride_index_(stride_index), contiguous_(contiguous), stride_(stride) {}
281 
282   bool operator==(const Stride& b) const {
283     return stride_index_ == b.stride_index_ && contiguous_ == b.contiguous_ &&
284         stride_ == b.stride_;
285   }
286 
isCompleteStride287   bool isComplete() const {
288     return stride_index_ && contiguous_ && stride_;
289   }
290 
291   std::optional<size_t> stride_index_;
292   std::optional<bool> contiguous_;
293   std::optional<size_t> stride_;
294 };
295 
296 template <>
merge_primitive(const std::optional<Stride> & a,const std::optional<Stride> & b)297 inline std::optional<Stride> merge_primitive(
298     const std::optional<Stride>& a,
299     const std::optional<Stride>& b) {
300   std::optional<Stride> left = a;
301   std::optional<Stride> right = b;
302   if (!left.has_value()) {
303     left = {Stride()};
304   }
305   if (!right.has_value()) {
306     right = {Stride()};
307   }
308 
309   auto merged_index =
310       merge_primitive(left->stride_index_, right->stride_index_);
311   auto merged_cont = merge_primitive(left->contiguous_, right->contiguous_);
312   auto merged_stride = merge_primitive(left->stride_, right->stride_);
313   auto r = Stride(merged_index, merged_cont, merged_stride);
314   // normalize
315   if (!r.stride_index_.has_value() && !r.contiguous_.has_value() &&
316       !r.stride_.has_value()) {
317     return std::optional<Stride>{};
318   }
319 
320   return r;
321 }
322 
323 struct TORCH_API ShapeSymbol {
324   // needed for use in `std::map`
ShapeSymbolShapeSymbol325   ShapeSymbol() : value_(-1) {}
326   // is this symbol a fixed/static dimension
is_staticShapeSymbol327   bool is_static() const {
328     return value_ >= 0;
329   };
330   bool operator==(const ShapeSymbol& b) const {
331     return value_ == b.value_;
332   }
333   bool operator<(const ShapeSymbol& b) const {
334     return value_ < b.value_;
335   }
336 
fromStaticSizeShapeSymbol337   static ShapeSymbol fromStaticSize(int64_t val) {
338     return ShapeSymbol(val);
339   }
static_sizeShapeSymbol340   int64_t static_size() const {
341     TORCH_CHECK(is_static());
342     return value_;
343   };
344 
valueShapeSymbol345   int64_t value() const {
346     return value_;
347   };
348 
newSymbolShapeSymbol349   static ShapeSymbol newSymbol() {
350     return fromStaticSize(-static_cast<int64_t>(++num_symbols));
351   };
352   friend TORCH_API std::ostream& operator<<(
353       std::ostream& os,
354       const ShapeSymbol& s);
355 
356  private:
ShapeSymbolShapeSymbol357   ShapeSymbol(int64_t val) : value_(val) {}
358   int64_t value_;
359   static std::atomic<size_t> num_symbols;
360 };
361 
// Merges two dimension symbols. The dimension is preserved only when both
// sides are static and equal; in every other case a fresh symbolic dimension
// is created via ShapeSymbol::newSymbol().
inline ShapeSymbol merge_primitive(
    const ShapeSymbol& a,
    const ShapeSymbol& b) {
  if (a.is_static() && b.is_static() && a == b) {
    return a;
  }
  return ShapeSymbol::newSymbol();
}
370 
371 // Shape of a Tensor represented with ShapeSymbol's. Unranked, ranked unknown
372 // dims, partially known and fully known shapes are all supported.
373 struct TORCH_API SymbolicShape {
374   // Unranked shape constructor.
SymbolicShapeSymbolicShape375   SymbolicShape() : dims_(std::nullopt) {}
376 
377   // Known rank but unknown dimentions.
SymbolicShapeSymbolicShape378   SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) {
379     if(!rank) {
380       return;
381     }
382 
383     std::vector<ShapeSymbol> shape_symbols;
384     shape_symbols.reserve(*rank);
385     for(size_t i = 0; i < *rank; ++i) {
386       shape_symbols.push_back(ShapeSymbol::newSymbol());
387     }
388     dims_ = shape_symbols;
389   }
390 
391   // Mix of known and unknown ranks
SymbolicShapeSymbolicShape392   SymbolicShape(const std::vector<std::optional<int64_t>>& dims) {
393     std::vector<ShapeSymbol> shape_symbols;
394     shape_symbols.reserve(dims.size());
395     for(std::optional<int64_t> dim: dims) {
396       if(!dim) {
397         shape_symbols.push_back(ShapeSymbol::newSymbol());
398       } else {
399         shape_symbols.push_back(ShapeSymbol::fromStaticSize(*dim));
400       }
401     }
402     dims_ = shape_symbols;
403   }
404 
405   void dump() const;
406 
SymbolicShapeSymbolicShape407   SymbolicShape(std::vector<ShapeSymbol> dims) : dims_(std::move(dims)) {}
408 
SymbolicShapeSymbolicShape409   SymbolicShape(c10::IntArrayRef dims) {
410     std::vector<ShapeSymbol> shape_symbols;
411     shape_symbols.reserve(dims.size());
412     for(int64_t dim : dims) {
413       shape_symbols.push_back(ShapeSymbol::fromStaticSize(dim));
414     }
415     dims_ = shape_symbols;
416   }
417 
418   ShapeSymbol operator[](size_t i) const {
419     if (!dims_) {
420       throw std::runtime_error("Rank isn't fixed");
421     }
422     return (*dims_).at(i);
423   }
424 
atSymbolicShape425   ShapeSymbol at(size_t i) const {
426     if (!dims_) {
427       throw std::runtime_error("Rank isn't fixed");
428     }
429     return (*dims_).at(i);
430   }
431 
432   // Returns rank or nullopt in case of unranked shape.
rankSymbolicShape433   std::optional<size_t> rank() const {
434     if(!dims_) {
435       return std::nullopt;
436     }
437     return dims_->size();
438   }
439 
sizesSymbolicShape440   std::optional<std::vector<ShapeSymbol>> sizes() const {
441     return dims_;
442   }
443 
symbolicDimsSymbolicShape444   std::optional<std::vector<bool>> symbolicDims() const {
445     if (!dims_) {
446       return std::nullopt;
447     }
448     auto symbolic_dims = std::vector<bool>();
449     for (const ShapeSymbol& s : *dims_) {
450       symbolic_dims.push_back(!s.is_static());
451     }
452     return symbolic_dims;
453   }
454 
455   // Checks whether the shape is fully defined/complete, ie. rank and sizes
456   // of every dimension are known.
isCompleteSymbolicShape457   bool isComplete() const {
458     if(!dims_) {
459       return false;
460     }
461     for(auto d : *dims_) {
462       if(!d.is_static()) {
463         return false;
464       }
465     }
466     return true;
467   }
468 
469   // Create new SymbolicShape that is result of merging self and another
470   // SymbolicShape. Only dimensions that are static and equal will be
471   // preserved.
472   // If either of two shapes are of unknown rank or they have unmatching rank,
473   // result will be unranked.
474   SymbolicShape merge(const SymbolicShape& other) const;
475 
476   friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) {
477     return lhs.dims_ == rhs.dims_;
478   }
479 
480   friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) {
481     return !(lhs == rhs);
482   }
483 
484   private:
485     std::optional<std::vector<ShapeSymbol>> dims_;
486 };
487 
488 namespace detail {
isComplete(const Stride & s)489 inline bool isComplete(const Stride& s) {
490   return s.isComplete();
491 }
492 
493 template<typename T>
isComplete(const T &)494 inline bool isComplete(const T& /*t*/) {
495   return true;
496 }
497 }
498 
499 template <typename T>
500 struct VaryingShape {
501   using ListOfOptionalElements = std::vector<std::optional<T>>;
VaryingShapeVaryingShape502   VaryingShape(const std::vector<T>& vec)
503       : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
504 
VaryingShapeVaryingShape505   VaryingShape(c10::ArrayRef<T> vec)
506       : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
507 
dims_VaryingShape508   VaryingShape(std::optional<size_t> size = std::nullopt) : dims_(std::nullopt) {
509     if (size) {
510       dims_ = ListOfOptionalElements(*size);
511     }
512   }
513 
VaryingShapeVaryingShape514   VaryingShape(ListOfOptionalElements dims) : dims_(std::move(dims)) {}
515 
VaryingShapeVaryingShape516   VaryingShape(size_t size) : VaryingShape(std::optional<size_t>(size)) {}
517 
518   bool operator==(const VaryingShape& other) const {
519     return dims_ == other.dims_;
520   }
521 
522   const std::optional<T> &operator[](size_t i) const {
523     if (!dims_) {
524       throw std::runtime_error("Rank isn't fixed");
525     }
526     return (*dims_).at(i);
527   }
528 
sizeVaryingShape529   std::optional<size_t> size() const {
530     if (!dims_) {
531       return std::nullopt;
532     }
533     const auto& dims = dims_.value();
534     return dims.size();
535   }
536 
sizesVaryingShape537   const std::optional<ListOfOptionalElements>& sizes() const {
538     return dims_;
539   }
540 
541   TORCH_API VaryingShape merge(const VaryingShape& other) const;
542 
concrete_sizesVaryingShape543   std::optional<std::vector<T>> concrete_sizes() const {
544     if (!dims_) {
545       return std::nullopt;
546     }
547     std::vector<T> sizes;
548     sizes.reserve(dims_.value().size());
549     for (auto d : *dims_) {
550       if (!d) {
551         return std::nullopt;
552       }
553       sizes.push_back(d.value());
554     }
555     return sizes;
556   }
557 
isCompleteVaryingShape558   bool isComplete() const {
559     if (!dims_) {
560       return false;
561     }
562     for (auto d : *dims_) {
563       if (!d || !detail::isComplete(*d)) {
564         return false;
565       }
566     }
567     return true;
568   }
569 
570  private:
571   std::optional<ListOfOptionalElements> dims_;
572 };
573 
574 struct TensorType;
575 // TODO: investigate making this SingletonOrSharedTypePtr<TensorType>
576 using TensorTypePtr = std::shared_ptr<TensorType>;
577 // This type represents a single Tensor with a specific size
578 struct TORCH_API TensorType : public SharedType {
579   static TensorTypePtr create(const at::Tensor& t);
580 
581   // used by TensorType::create(size_t dim) which in turn used by
582   // shape_analysis.cpp
583   static TensorTypePtr create(
584       std::optional<at::ScalarType> scalar_type,
585       std::optional<Device> device,
586       const VaryingShape<int64_t>& sizes,
587       const VaryingShape<int64_t>& strides,
588       std::optional<bool> requires_grad,
589       std::optional<bool> undefined = false,
590       bool tensor_contiguity = false);
591 
592   static TensorTypePtr create(
593       std::optional<at::ScalarType> scalar_type,
594       std::optional<Device> device,
595       const SymbolicShape& sizes,
596       const VaryingShape<Stride>& stride_,
597       std::optional<bool> requires_grad,
598       std::optional<bool> undefined = false);
599 
600   static TensorTypePtr create(
601       std::optional<at::ScalarType> scalar_type,
602       std::optional<Device> device,
603       std::optional<size_t> dim,
604       std::optional<bool> requires_grad);
605 
606   // overloaded create variadic template argument as it could not distinguish
607   // initializer list
608   static TensorTypePtr createContiguous(
609       at::ScalarType scalar_type,
610       at::Device device,
611       at::IntArrayRef sizes);
612 
613   static TypePtr fromNumberType(const Type& typ);
614   static TypePtr fromBoolType();
615 
dimTensorType616   std::optional<size_t> dim() const {
617     return sizes().size();
618   }
619 
620   VaryingShape<int64_t> sizes() const;
621 
622   VaryingShape<int64_t> strides() const;
623 
stride_propertiesTensorType624   const VaryingShape<Stride>& stride_properties() const {
625     return strides_;
626   }
627 
deviceTensorType628   std::optional<at::Device> device() const {
629     return device_;
630   }
scalarTypeTensorType631   std::optional<at::ScalarType> scalarType() const {
632     return scalar_type_;
633   }
requiresGradTensorType634   std::optional<bool> requiresGrad() const {
635     return requires_grad_;
636   }
requires_gradTensorType637   bool requires_grad() const override {
638     return requires_grad_ ? *requires_grad_ : true;
639   }
640 
641   bool equals(const Type& rhs) const override;
642   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
643 
644   std::string str() const override;
645 
repr_strTensorType646   std::string repr_str() const override {
647     if (isInferredType()) {
648       return str() + " (inferred)";
649     } else {
650       return str();
651     }
652   }
653 
numelTensorType654   std::optional<size_t> numel() const {
655     size_t prod = 1;
656     const auto& shape = sizes();
657 
658     for (size_t i = 0; i < shape.size(); i++) {
659       if (!shape[i]) {
660         return std::optional<size_t>{};
661       }
662       prod *= shape[i].value();
663     }
664     return prod;
665   }
666 
withRequiresGradTensorType667   TensorTypePtr withRequiresGrad(std::optional<bool> s) {
668     auto copy = clone();
669     copy->requires_grad_ = s;
670     return copy;
671   }
672 
withScalarTypeTensorType673   TensorTypePtr withScalarType(std::optional<ScalarType> st) {
674     auto copy = clone();
675     copy->scalar_type_ = st;
676     return copy;
677   }
678 
withDimTensorType679   TensorTypePtr withDim(std::optional<size_t> d) {
680     auto copy = clone();
681     // withDim is only used by the legacy executor
682     // that only cares about the rank, so create dummy symbols)) :
683     copy->sizes_ = SymbolicShape(d);
684     copy->strides_ = VaryingShape<Stride>(d);
685     return copy;
686   }
687 
withStridesTensorType688   TensorTypePtr withStrides(VaryingShape<Stride> sstrides) const {
689     auto cloned = clone();
690     cloned->strides_ = std::move(sstrides);
691     return cloned;
692   }
693 
withSizesStridesTensorType694   TensorTypePtr withSizesStrides(
695       at::IntArrayRef sizes,
696       at::IntArrayRef strides) const {
697     auto cloned = clone();
698     auto ssizes = SymbolicShape(sizes);
699     cloned->sizes_ = ssizes;
700     cloned->strides_ = computeStrideProps(sizes, strides);
701     return cloned;
702   }
703 
withSymbolicShapesTensorType704   TensorTypePtr withSymbolicShapes(SymbolicShape ssizes) const {
705     auto cloned = clone();
706     cloned->sizes_ = std::move(ssizes);
707     return cloned;
708   }
709 
withSizesTensorType710   TensorTypePtr withSizes(at::IntArrayRef sizes) const {
711     return withSizesStrides(
712         sizes, contiguousStridesOf(sizes));
713   }
714 
withDeviceTensorType715   TensorTypePtr withDevice(const std::optional<at::Device> device) const {
716     auto copy = clone();
717     copy->device_ = device;
718     return copy;
719   }
720 
dimensionedOnlyTensorType721   TensorTypePtr dimensionedOnly() const {
722     auto copy = clone();
723     copy->sizes_ = SymbolicShape(sizes().size());
724     copy->strides_ = VaryingShape<Stride>(sizes().size());
725     return copy;
726   }
727 
contiguousTensorType728   TensorTypePtr contiguous() const {
729     auto cloned = clone();
730     TORCH_INTERNAL_ASSERT(sizes().concrete_sizes().has_value());
731     auto strides = computeStrideProps(
732         *sizes().concrete_sizes(),
733         contiguousStridesOf(*sizes().concrete_sizes()));
734     cloned->strides_ = strides;
735     return cloned;
736   }
737 
738   const SymbolicShape& symbolic_sizes() const;
739 
740   TensorTypePtr merge(const TensorType& other, bool merge_sizes = true) const;
741 
742   bool matchTensor(const at::Tensor& t);
743 
744   // is all information about the type specified except for autograd?
745   // This replaces the notion of a 'CompleteTensorType' that used to exist
746   // in the type-hierarchy. Excluding require_grad and undefined allows
747   // this to match the old behavior.
isCompleteTensorType748   bool isComplete() const {
749     return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete();
750   }
751 
isInferredTypeTensorType752   bool isInferredType() const {
753     return is_inferred_;
754   }
755 
getInferredTensorType756   static TensorTypePtr getInferred() {
757     static auto valueInferred = TensorType::create(
758         /*scalar_type=*/{},
759         /*device=*/{},
760         /*sizes=*/SymbolicShape(),
761         /*stride=*/VaryingShape<Stride>{},
762         /*requires_grad=*/{},
763         /*undefined=*/false);
764     valueInferred->is_inferred_ = true;
765     return valueInferred;
766   }
767 
768   // this property is used by GuardElimination
769   // please see `checkInputs` for more details
isSummarizedTensorType770   bool isSummarized() const {
771     return !(isComplete() && requiresGrad().has_value() &&
772              undefined().has_value());
773   }
774 
withUndefinedTensorType775   TensorTypePtr withUndefined() {
776     auto r = clone();
777     r->undefined_ = true;
778     return r;
779   }
780 
withPossiblyUndefinedTensorType781   TensorTypePtr withPossiblyUndefined() {
782     auto r = clone();
783     r->undefined_ = std::nullopt;
784     return r;
785   }
786 
undefinedTensorType787   std::optional<bool> undefined() const { return undefined_; }
788 
789   static const TensorTypePtr& get();
790 
791   static const TypeKind Kind = TypeKind::TensorType;
792 
793   static std::vector<int64_t> contiguousStridesOf(
794       at::IntArrayRef in_sizes,
795       at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
796     auto contiguous_fn = [](const at::IntArrayRef& sizes,
797                             const std::vector<int64_t>& dim_order) {
798       std::vector<int64_t> strides(sizes.size());
799       if (sizes.empty()) // zero-dim case
800         return strides;
801 
802       strides[dim_order[0]] = 1;
803       for (size_t i = 1; i < dim_order.size(); i++) {
804         auto cur_dim = dim_order[i];
805         auto pre_dim = dim_order[i - 1];
806         strides[cur_dim] = strides[pre_dim] * sizes[pre_dim];
807       }
808       return strides;
809     };
810 
811     std::vector<int64_t> dim_order(in_sizes.size());
812     if (memory_format == MemoryFormat::ChannelsLast) {
813       dim_order = {1, 3, 2, 0};
814     } else if (memory_format == MemoryFormat::ChannelsLast3d) {
815       dim_order = {1, 4, 3, 2, 0};
816     } else {
817       auto ndims = in_sizes.size();
818       for (size_t i = 0; i < ndims; i++) {
819         dim_order[i] = static_cast<int64_t>(ndims - i - 1); // Reverse
820       }
821     }
822     return contiguous_fn(in_sizes, dim_order);
823   }
824 
825  private:
826   TensorType(
827       std::optional<at::ScalarType> scalar_type,
828       std::optional<Device> device,
829       SymbolicShape sizes,
830       VaryingShape<Stride> strides,
831       std::optional<bool> requires_grad,
832       std::optional<bool> undefined = false);
833 
cloneTensorType834   TensorTypePtr clone() const {
835     return TensorTypePtr(new TensorType(
836         scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_));
837   }
838 
839   static VaryingShape<Stride> computeStrideProps(
840       at::IntArrayRef sizes,
841       at::IntArrayRef strides,
842       bool tensor_contiguity = false);
843 
844   std::optional<at::ScalarType> scalar_type_;
845   std::optional<at::Device> device_;
846   SymbolicShape sizes_;
847   VaryingShape<Stride> strides_;
848   std::optional<bool> requires_grad_;
849   // we exploit the fact certain tensors must be zero in the autograd to
850   // optimize gradient computation. Such zero tensors are currently implemented
851   // with `UndefinedTensorImpl.` They can be handled only by special operators
852   // (e.g. `AutogradAdd`) and their `Tensor::defined()` property returns false.
853   // Normally, `undefined_` is set to false, unless a type was created
854   // with `withUndefined`
855   // This will also mean that `undefined` tensors will fail
856   // `subtypeOf(TensorType::get())` check
857   // undefined_ may become `std::nullopt` if the tensor was observed to be both
858   // defined and undefined. However, no tensor type starts out with
859   // `undefined_` set to `std::nullopt`
860   std::optional<bool> undefined_;
861   // Represents whether or not this type was inferred.
862   bool is_inferred_ = false;
863 };
864 
865 struct ListType;
866 using ListTypePtr = std::shared_ptr<ListType>;
867 struct TORCH_API ListType
868     : public SingleElementType<TypeKind::ListType, ListType> {
869   // It's not exactly a singleton, but there should be exactly one instance of
870   // List[T] for every T
871   friend struct Type;
872   template <typename... T>
createListType873   static ListTypePtr create(T&&... all) {
874     return ListTypePtr(
875         new ListType(std::forward<T>(all)...)); // NOLINT(modernize-make-shared)
876   }
877 
strListType878   std::string str() const override {
879     std::stringstream ss;
880     ss << getElementType()->str() << "[]";
881     return ss.str();
882   }
createWithContainedListType883   TypePtr createWithContained(
884       std::vector<TypePtr> contained_types) const override {
885     return create(std::move(contained_types.at(0)));
886   }
887 
888   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
889 
890   // global singleton
891   // Given an inner type T and an identifier,
892   // this function wil return the global singleton type pointer
893   // the type List<T>.
894   // The extra "identifier" argument is needed beccause we have multiple container types
895   // that all re-use this function (List<T>, array<T, N>, etc.)
896   static TypePtr get(const std::string& identifier, TypePtr inner);
897 
898   // common cast List[Tensor]
899   static ListTypePtr ofTensors();
900   static ListTypePtr ofOptionalTensors();
901   static ListTypePtr ofInts();
902   static ListTypePtr ofSymInts();
903   static ListTypePtr ofFloats();
904   static ListTypePtr ofComplexDoubles();
905   static ListTypePtr ofBools();
906   static ListTypePtr ofStrings();
907   static ListTypePtr ofNumbers();
908 
909  private:
ListTypeListType910   ListType(TypePtr elem) : SingleElementType(std::move(elem)) {}
911 
912   std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
913     std::stringstream ss;
914     ss << "List[" << getElementType()->annotation_str(printer) << "]";
915     return ss.str();
916   }
917 };
918 
919 struct DictType;
920 using DictTypePtr = std::shared_ptr<DictType>;
921 struct TORCH_API DictType : public SharedType {
922   friend struct Type;
923   static const TypeKind Kind = TypeKind::DictType;
924 
createDictType925   static DictTypePtr create(TypePtr key, TypePtr value) {
926     auto kind = key->kind();
927     if (auto dyn = key->castRaw<DynamicType>()) {
928       kind = dyn->dynamicKind();
929     }
930     switch (kind) {
931       case TypeKind::AnyType:
932       case TypeKind::IntType:
933       case TypeKind::BoolType:
934       case TypeKind::FloatType:
935       case TypeKind::ComplexType:
936       case TypeKind::StringType:
937       case TypeKind::TensorType:
938       case TypeKind::DeviceObjType:
939         return DictTypePtr(new DictType(std::move(key), std::move(value)));
940       default:
941         AT_ERROR(
942             "Cannot create dict for key type '",
943             key->str(),
944             "', only int, float, complex, Tensor, device and string keys are supported");
945     }
946   }
947 
948   // aligned with the format in FunctionSchema
strDictType949   std::string str() const override {
950     std::stringstream ss;
951     ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
952        << ")";
953     return ss.str();
954   }
955 
createWithContainedDictType956   TypePtr createWithContained(
957       std::vector<TypePtr> contained_types) const override {
958     if (contained_types.size() != 2) {
959       throw std::runtime_error("Expected 2 contained types");
960     }
961     return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
962   }
963 
getKeyTypeDictType964   const TypePtr& getKeyType() const {
965     return types.at(0);
966   }
967 
getValueTypeDictType968   const TypePtr& getValueType() const {
969     return types.at(1);
970   }
971 
hasFreeVariablesDictType972   bool hasFreeVariables() const override {
973     return has_free_variables;
974   }
975 
containedTypesDictType976   at::ArrayRef<TypePtr> containedTypes() const override {
977     return types;
978   }
979 
equalsDictType980   bool equals(const Type& rhs) const override {
981     if (auto* dict_rhs = rhs.castRaw<DictType>()) {
982       return *getKeyType() == *(dict_rhs->getKeyType()) &&
983           *getValueType() == *(dict_rhs->getValueType());
984     }
985     return false;
986   }
987 
988   // global singleton
989   // Given an inner type T and an identifier,
990   // this function will return the global singleton type pointer
991   // the type List<T>.
992   // The extra "identifier" argument is needed because we have multiple container types
993   // that all re-use this function (Dict<K, V> and unordered_map<K, V>)
994   static TypePtr get(const std::string& identifier, TypePtr key, TypePtr val);
995 
996  private:
DictTypeDictType997   DictType(TypePtr key, TypePtr value)
998       : SharedType(TypeKind::DictType),
999         has_free_variables(
1000             key->hasFreeVariables() || value->hasFreeVariables()) {
1001     types.reserve(2);
1002     types.push_back(std::move(key));
1003     types.push_back(std::move(value));
1004   }
1005 
1006   std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override;
1007 
1008   std::vector<TypePtr> types;
1009   bool has_free_variables;
1010 };
1011 
1012 struct FutureType;
1013 using FutureTypePtr = std::shared_ptr<FutureType>;
1014 
1015 struct TORCH_API FutureType
1016     : public SingleElementType<TypeKind::FutureType, FutureType> {
1017   friend struct Type;
1018   template <typename... T>
createFutureType1019   static FutureTypePtr create(TypePtr elem) {
1020     return FutureTypePtr(
1021         new FutureType(std::move(elem))); // NOLINT(modernize-make-shared)
1022   }
1023 
strFutureType1024   std::string str() const override {
1025     std::stringstream ss;
1026     ss << "Future(" << getElementType()->str() << ")";
1027     return ss.str();
1028   }
createWithContainedFutureType1029   TypePtr createWithContained(
1030       std::vector<TypePtr> contained_types) const override {
1031     return create(std::move(contained_types.at(0)));
1032   }
1033 
isSubtypeOfExtFutureType1034   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1035     if (Type::isSubtypeOfExt(rhs, why_not)) {
1036       return true;
1037     }
1038     if (auto rhs_ = rhs.castRaw<FutureType>()) {
1039       return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
1040     }
1041     return false;
1042   }
1043 
1044  private:
FutureTypeFutureType1045   FutureType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1046 
1047   std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
1048     std::stringstream ss;
1049     ss << "Future[" << getElementType()->annotation_str(printer) << "]";
1050     return ss.str();
1051   }
1052 };
1053 
1054 struct AwaitType;
1055 using AwaitTypePtr = std::shared_ptr<AwaitType>;
1056 
1057 struct TORCH_API AwaitType
1058     : public SingleElementType<TypeKind::AwaitType, AwaitType> {
1059   friend struct Type;
1060   template <typename... T>
createAwaitType1061   static AwaitTypePtr create(TypePtr elem) {
1062     return AwaitTypePtr(
1063         new AwaitType(std::move(elem))); // NOLINT(modernize-make-shared)
1064   }
1065 
strAwaitType1066   std::string str() const override {
1067     std::stringstream ss;
1068     ss << "Await(" << getElementType()->str() << ")";
1069     return ss.str();
1070   }
createWithContainedAwaitType1071   TypePtr createWithContained(
1072       std::vector<TypePtr> contained_types) const override {
1073     return create(std::move(contained_types.at(0)));
1074   }
1075 
isSubtypeOfExtAwaitType1076   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1077     if (Type::isSubtypeOfExt(rhs, why_not)) {
1078       return true;
1079     }
1080     if (auto rhs_ = rhs.castRaw<AwaitType>()) {
1081       return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
1082     }
1083     return false;
1084   }
1085 
1086  private:
AwaitTypeAwaitType1087   AwaitType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1088 
1089   std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
1090     std::stringstream ss;
1091     ss << "Await[" << getElementType()->annotation_str(printer) << "]";
1092     return ss.str();
1093   }
1094 };
1095 
1096 struct RRefType;
1097 using RRefTypePtr = std::shared_ptr<RRefType>;
1098 
1099 struct TORCH_API RRefType
1100     : public SingleElementType<TypeKind::RRefType, RRefType> {
1101   friend struct Type;
1102   template <typename... T>
createRRefType1103   static RRefTypePtr create(TypePtr elem) {
1104     return RRefTypePtr(
1105         new RRefType(std::move(elem))); // NOLINT(modernize-make-shared)
1106   }
1107 
strRRefType1108   std::string str() const override {
1109     std::stringstream ss;
1110     ss << "RRef(" << getElementType()->str() << ")";
1111     return ss.str();
1112   }
createWithContainedRRefType1113   TypePtr createWithContained(
1114       std::vector<TypePtr> contained_types) const override {
1115     return create(std::move(contained_types.at(0)));
1116   }
1117 
1118  private:
RRefTypeRRefType1119   RRefType(TypePtr elem) : SingleElementType(std::move(elem)) {}
1120 
1121   std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
1122     std::stringstream ss;
1123     ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
1124     return ss.str();
1125   }
1126 };
1127 
1128 // Any should never appear in a named type like a class, namedtuple or
1129 // interface. If it does, then dynamic type information will be lost in the
1130 // Pickler, leading to hard-to-track-down bugs that will only occur
1131 // after saving or loading a model. This is because we rely on the
1132 // static types in named types to reconstruct type tags of loaded
1133 // values. Lifting this restriction requires solving the serialization
1134 // problem first.
1135 TORCH_API void checkNoAny(
1136     const Type& base,
1137     const char* what,
1138     const std::string& attrname,
1139     const TypePtr& attrtype);
1140 
1141 struct TupleType;
1142 using TupleTypePtr = std::shared_ptr<TupleType>;
1143 using NameList = std::vector<std::string>;
1144 // This type represents a Tuple
1145 struct TORCH_API TupleType : public NamedType {
1146 
1147   static TupleTypePtr createNamed(const std::optional<c10::QualifiedName>& name,
1148       const std::vector<std::string>& field_names,
1149       const std::vector<TypePtr>& field_types,
1150       std::vector<IValue>& field_defaults);
1151 
1152   static TupleTypePtr createNamed(const std::optional<c10::QualifiedName>& name,
1153       const std::vector<std::string>& field_names,
1154       const std::vector<TypePtr>& field_types);
1155 
1156   static TupleTypePtr createNamed(const std::optional<c10::QualifiedName>& name,
1157       const std::vector<c10::string_view>& field_names,
1158       const std::vector<TypePtr>& field_types);
1159 
createTupleType1160   static TupleTypePtr create(
1161       std::vector<TypePtr> types) {
1162     return TupleTypePtr(new TupleType(
1163         std::move(types),
1164         std::nullopt,
1165         nullptr)); // NOLINT(modernize-make-shared)
1166   }
createTupleType1167   static TupleTypePtr create() {
1168     return create({});
1169   }
1170 
elementsTupleType1171   at::ArrayRef<TypePtr> elements() const {
1172     return elements_;
1173   }
1174 
1175   bool equals(const Type& rhs) const override;
1176   bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;
1177 
1178   std::string str() const override;
hasFreeVariablesTupleType1179   bool hasFreeVariables() const override {
1180     return has_free_variables_;
1181   }
containedTypesTupleType1182   at::ArrayRef<TypePtr> containedTypes() const override {
1183     return elements_;
1184   }
createWithContainedTupleType1185   TypePtr createWithContained(
1186       std::vector<TypePtr> contained_types) const override {
1187     return std::shared_ptr<TupleType>(
1188         new TupleType(std::move(contained_types), name(), schema()));
1189   }
schemaTupleType1190   const std::shared_ptr<FunctionSchema>& schema() const {
1191     return schema_;
1192   }
1193   std::optional<std::vector<c10::string_view>> names() const;
1194 
1195   static const TypeKind Kind = TypeKind::TupleType;
1196 
1197  private:
1198   template <typename S>
1199   static TupleTypePtr createWithSpec(
1200       const std::optional<c10::QualifiedName>& name,
1201       const std::vector<S>& field_names,
1202       const std::vector<TypePtr>& field_types,
1203       std::vector<IValue>& field_defaults);
1204 
1205   TupleType(
1206       std::vector<TypePtr> elements_,
1207       std::optional<c10::QualifiedName> name,
1208       std::shared_ptr<FunctionSchema> schema);
1209 
compareTupleType1210   bool compare(
1211       const Type& rhs,
1212       const std::function<bool(const Type&, const Type&)>& fn) const {
1213     if (rhs.kind() != kind()) {
1214       return false;
1215     }
1216 
1217     const auto& l_elements = elements();
1218     const auto& r_elements = rhs.castRaw<TupleType>()->elements();
1219     if (l_elements.size() != r_elements.size())
1220       return false;
1221     for (size_t i = 0; i < l_elements.size(); ++i) {
1222       if (!fn(*l_elements[i], *r_elements[i]))
1223         return false;
1224     }
1225     return true;
1226   }
1227 
1228   std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override;
1229 
1230   std::vector<TypePtr> elements_;
1231   bool has_free_variables_;
1232   std::shared_ptr<FunctionSchema> schema_;
1233 };
1234 
1235 // the common supertype of all Enums, only used in operator registraion.
1236 // EnumType <: AnyEnumType for all Enums
1237 struct AnyEnumType;
1238 using AnyEnumTypePtr = SingletonTypePtr<AnyEnumType>;
1239 struct TORCH_API AnyEnumType final : public Type {
equalsfinal1240   bool equals(const Type& rhs) const override {
1241     return rhs.kind() == kind();
1242   }
strfinal1243   std::string str() const override {
1244     return "AnyEnumType";
1245   }
1246   static const TypeKind Kind = TypeKind::AnyEnumType;
1247   // global singleton
1248   static AnyEnumTypePtr get();
1249 private:
AnyEnumTypefinal1250   AnyEnumType()
1251   : Type(TypeKind::AnyEnumType) {}
1252 };
1253 
1254 struct NumberType;
1255 using NumberTypePtr = SingletonTypePtr<NumberType>;
1256 // This type represents a Python number
1257 // Subtype hierarchy for Number Types (NumberType as the base type):
1258 // IntType <: NumberType
1259 // FloatType <: NumberType
1260 // ComplexType <:NumberType
1261 //
1262 // WARNING: if you add a new subtype of NumberType that is not
1263 // represented by a global singleton, you need to change NumberTypePtr
1264 // to a SingletonOrSharedTypePtr and deal with NumberType needing to
1265 // both inherit and not inherit from SharedType!
1266 struct TORCH_API NumberType : public Type {
1267   bool equals(const Type& rhs) const override;
1268 
1269   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
1270 
strNumberType1271   std::string str() const override {
1272     return "Scalar"; // match what PythonArgParser says for clarity
1273   }
1274   static const TypeKind Kind = TypeKind::NumberType;
1275   // global singleton
1276   static NumberTypePtr get();
1277 
1278  protected:
TypeNumberType1279   NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {}
1280 
1281   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1282     return "number"; // technically not a valid python type, but
1283                      // we need to use it when parsing back in annotations
1284                      // for implicit conversions
1285   }
1286 };
1287 
1288 struct FloatType;
1289 using FloatTypePtr = SingletonTypePtr<FloatType>;
1290 // This type represents a Python float number
1291 struct TORCH_API FloatType : public NumberType {
equalsFloatType1292   bool equals(const Type& rhs) const override {
1293     return rhs.kind() == kind();
1294   }
strFloatType1295   std::string str() const override {
1296     return "float";
1297   }
isSubtypeOfExtFloatType1298   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1299     // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1300     return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1301   }
1302   static const TypeKind Kind = TypeKind::FloatType;
1303   // global singleton
1304   static FloatTypePtr get();
1305 
1306  private:
FloatTypeFloatType1307   FloatType() : NumberType(TypeKind::FloatType) {}
1308   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1309     return "float";
1310   }
1311 };
1312 
1313 struct ComplexType;
1314 using ComplexTypePtr = SingletonTypePtr<ComplexType>;
1315 // This type represents a Python float number
1316 struct TORCH_API ComplexType : public NumberType {
equalsComplexType1317   bool equals(const Type& rhs) const override {
1318     return rhs.kind() == kind();
1319   }
strComplexType1320   std::string str() const override {
1321     return "complex";
1322   }
isSubtypeOfExtComplexType1323   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1324     // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1325     return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1326   }
1327   static const TypeKind Kind = TypeKind::ComplexType;
1328   // global singleton
1329   static ComplexTypePtr get();
1330 
1331  private:
ComplexTypeComplexType1332   ComplexType() : NumberType(TypeKind::ComplexType) {}
1333   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1334     return "complex";
1335   }
1336 };
1337 
1338 // We need to introduce `SymIntType` to represent the `SymInt` type
1339 // used in function schemas e.g. `aten::narrow_copy(... SymInt length)
1340 // `SymInt` will be used to enable tracing arithmetic operations on
1341 // dimension values. Please see [SymInt.h] for more information
1342 struct SymIntType;
1343 using SymIntTypePtr = SingletonTypePtr<SymIntType>;
1344 struct TORCH_API SymIntType : public Type {
equalsSymIntType1345   bool equals(const Type& rhs) const override {
1346     return rhs.kind() == kind();
1347   }
strSymIntType1348   std::string str() const override {
1349     return "SymInt";
1350   }
1351   std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override {
1352     return "int";
1353   }
1354   static const TypeKind Kind = TypeKind::SymIntType;
1355   // global singleton
1356   static SymIntTypePtr get();
1357 
1358  private:
SymIntTypeSymIntType1359   SymIntType() : Type(TypeKind::SymIntType) {}
1360 };
1361 
1362 struct SymFloatType;
1363 using SymFloatTypePtr = SingletonTypePtr<SymFloatType>;
1364 struct TORCH_API SymFloatType : public Type {
equalsSymFloatType1365   bool equals(const Type& rhs) const override {
1366     return rhs.kind() == kind();
1367   }
strSymFloatType1368   std::string str() const override {
1369     return "SymFloat";
1370   }
1371   std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override {
1372     return "float";
1373   }
1374   static const TypeKind Kind = TypeKind::SymFloatType;
1375   // global singleton
1376   static SymFloatTypePtr get();
1377 
1378  private:
SymFloatTypeSymFloatType1379   SymFloatType() : Type(TypeKind::SymFloatType) {}
1380 };
1381 
1382 struct SymBoolType;
1383 using SymBoolTypePtr = SingletonTypePtr<SymBoolType>;
1384 struct TORCH_API SymBoolType : public Type {
equalsSymBoolType1385   bool equals(const Type& rhs) const override {
1386     return rhs.kind() == kind();
1387   }
strSymBoolType1388   std::string str() const override {
1389     return "SymBool";
1390   }
1391   std::string annotation_str_impl(const TypePrinter& printer [[maybe_unused]] = nullptr) const override {
1392     return "bool";
1393   }
1394   static const TypeKind Kind = TypeKind::SymBoolType;
1395   // global singleton
1396   static SymBoolTypePtr get();
1397 
1398  private:
SymBoolTypeSymBoolType1399   SymBoolType() : Type(TypeKind::SymBoolType) {}
1400 };
1401 
1402 struct IntType;
1403 using IntTypePtr = SingletonTypePtr<IntType>;
1404 // This type represents a Python int number
1405 struct TORCH_API IntType : public NumberType {
equalsIntType1406   bool equals(const Type& rhs) const override {
1407     return rhs.kind() == kind();
1408   }
strIntType1409   std::string str() const override {
1410     return "int";
1411   }
isSubtypeOfExtIntType1412   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
1413     // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1414     return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
1415   }
1416   static const TypeKind Kind = TypeKind::IntType;
1417   // global singleton
1418   static IntTypePtr get();
1419 
1420  private:
IntTypeIntType1421   IntType() : NumberType(TypeKind::IntType) {}
1422   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1423     return "int";
1424   }
1425 };
1426 
1427 struct BoolType;
1428 using BoolTypePtr = SingletonTypePtr<BoolType>;
1429 // This node represents a Python bool value
1430 struct TORCH_API BoolType : public Type {
equalsBoolType1431   bool equals(const Type& rhs) const override {
1432     return rhs.kind() == kind();
1433   }
strBoolType1434   std::string str() const override {
1435     return "bool";
1436   }
1437   static const TypeKind Kind = TypeKind::BoolType;
1438   // global singleton
1439   static BoolTypePtr get();
1440 
1441  private:
BoolTypeBoolType1442   BoolType() : Type(TypeKind::BoolType) {}
1443 };
1444 
1445 struct StringType;
1446 using StringTypePtr = SingletonTypePtr<StringType>;
1447 // This type represents a Python string
1448 struct TORCH_API StringType : public Type {
equalsStringType1449   bool equals(const Type& rhs) const override {
1450     return rhs.kind() == kind();
1451   }
strStringType1452   std::string str() const override {
1453     // we only use "str" (not "string") in both FunctionSchema and script
1454     return annotation_str();
1455   }
1456   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1457     return "str";
1458   }
1459   static const TypeKind Kind = TypeKind::StringType;
1460   // global singleton
1461   static StringTypePtr get();
1462 
1463  private:
StringTypeStringType1464   StringType() : Type(TypeKind::StringType) {}
1465 };
1466 
1467 struct StorageType;
1468 using StorageTypePtr = SingletonTypePtr<StorageType>;
1469 struct TORCH_API StorageType : public Type {
equalsStorageType1470   bool equals(const Type& rhs) const override {
1471     return rhs.kind() == kind();
1472   }
strStorageType1473   std::string str() const override {
1474     return annotation_str();
1475   }
1476   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1477     return "Storage";
1478   }
1479   static const TypeKind Kind = TypeKind::StorageType;
1480   // global singleton
1481   static StorageTypePtr get();
1482 
1483  private:
StorageTypeStorageType1484   StorageType() : Type(TypeKind::StorageType) {}
1485 };
1486 
1487 struct FunctionType;
1488 using FunctionTypePtr = std::shared_ptr<FunctionType>;
1489 struct TORCH_API FunctionType : public NamedType {
createFunctionType1490   static FunctionTypePtr create(torch::jit::Function* function) {
1491     return FunctionTypePtr(
1492         new FunctionType(function)); // NOLINT(modernize-make-shared)
1493   }
equalsFunctionType1494   bool equals(const Type& rhs) const override {
1495     if (auto func_type = rhs.cast<FunctionType>()) {
1496       return func_type->function_ == function_;
1497     }
1498 
1499     return false;
1500   }
strFunctionType1501   std::string str() const override {
1502     return "Function";
1503   }
functionFunctionType1504   torch::jit::Function* function() const {
1505     return function_;
1506   }
1507   static const TypeKind Kind = TypeKind::FunctionType;
1508 
1509  private:
1510   FunctionType(torch::jit::Function* function);
1511   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
1512     const auto& n = name().value();
1513     return n.qualifiedName();
1514   }
1515   torch::jit::Function* function_;
1516 };
1517 
1518 struct NoneType;
1519 using NoneTypePtr = SingletonTypePtr<NoneType>;
1520 // This type represents a Python None
1521 struct TORCH_API NoneType : public Type {
equalsNoneType1522   bool equals(const Type& rhs) const override {
1523     return rhs.kind() == kind();
1524   }
strNoneType1525   std::string str() const override {
1526     return "NoneType";
1527   }
1528   bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override;
1529 
1530   static const TypeKind Kind = TypeKind::NoneType;
1531   // global singleton
1532   static NoneTypePtr get();
1533 
1534  private:
NoneTypeNoneType1535   NoneType() : Type(TypeKind::NoneType) {}
1536 };
1537 
1538 struct GeneratorType;
1539 using GeneratorTypePtr = SingletonTypePtr<GeneratorType>;
1540 // This type represents a Generator
1541 struct TORCH_API GeneratorType : public Type {
equalsGeneratorType1542   bool equals(const Type& rhs) const override {
1543     return rhs.kind() == kind();
1544   }
strGeneratorType1545   std::string str() const override {
1546     return "Generator";
1547   }
1548   static const TypeKind Kind = TypeKind::GeneratorType;
1549   // global singleton
1550   static GeneratorTypePtr get();
1551 
1552  private:
GeneratorTypeGeneratorType1553   GeneratorType() : Type(TypeKind::GeneratorType) {}
1554 };
1555 
1556 struct QuantizerType;
1557 using QuantizerTypePtr = SingletonTypePtr<QuantizerType>;
1558 // This type represents a Quantizer
1559 struct TORCH_API QuantizerType : public Type {
equalsQuantizerType1560   bool equals(const Type& rhs) const override {
1561     return rhs.kind() == kind();
1562   }
strQuantizerType1563   std::string str() const override {
1564     return "Quantizer";
1565   }
1566   static const TypeKind Kind = TypeKind::QuantizerType;
1567   // global singleton
1568   static QuantizerTypePtr get();
1569 
1570  private:
QuantizerTypeQuantizerType1571   QuantizerType() : Type(TypeKind::QuantizerType) {}
1572 };
1573 
1574 struct QSchemeType;
1575 using QSchemeTypePtr = SingletonTypePtr<QSchemeType>;
1576 // This type represents a QScheme
1577 struct TORCH_API QSchemeType : public Type {
equalsQSchemeType1578   bool equals(const Type& rhs) const override {
1579     return rhs.kind() == kind();
1580   }
strQSchemeType1581   std::string str() const override {
1582     return "QScheme";
1583   }
1584   static const TypeKind Kind = TypeKind::QSchemeType;
1585   // global singleton
1586   static QSchemeTypePtr get();
1587 
1588  private:
QSchemeTypeQSchemeType1589   QSchemeType() : Type(TypeKind::QSchemeType) {}
1590 };
1591 
1592 struct DeviceObjType;
1593 using DeviceObjTypePtr = SingletonTypePtr<DeviceObjType>;
1594 // This type represents a Device
1595 struct TORCH_API DeviceObjType : public Type {
equalsDeviceObjType1596   bool equals(const Type& rhs) const override {
1597     return rhs.kind() == kind();
1598   }
strDeviceObjType1599   std::string str() const override {
1600     return "Device";
1601   }
1602   static const TypeKind Kind = TypeKind::DeviceObjType;
1603   // global singleton
1604   static DeviceObjTypePtr get();
1605 
1606  private:
DeviceObjTypeDeviceObjType1607   DeviceObjType() : Type(TypeKind::DeviceObjType) {}
1608 };
1609 
1610 struct StreamObjType;
1611 using StreamObjTypePtr = SingletonTypePtr<StreamObjType>;
1612 // This type represents a Generator
1613 struct TORCH_API StreamObjType : public Type {
equalsStreamObjType1614   bool equals(const Type& rhs) const override {
1615     return rhs.kind() == kind();
1616   }
strStreamObjType1617   std::string str() const override {
1618     return "Stream";
1619   }
1620   static const TypeKind Kind = TypeKind::StreamObjType;
1621   // global singleton
1622   static StreamObjTypePtr get();
1623 
1624 private:
StreamObjTypeStreamObjType1625   StreamObjType() : Type(TypeKind::StreamObjType) {}
1626 };
1627 
1628 struct VarType;
1629 using VarTypePtr = std::shared_ptr<VarType>;
1630 // This type represents a type variable, used in FunctionSchema
1631 struct VarType : public SharedType {
createVarType1632   static VarTypePtr create(std::string name_) {
1633     return VarTypePtr(new VarType(std::move(name_)));
1634   }
equalsVarType1635   bool equals(const Type& rhs) const override {
1636     return rhs.kind() == kind();
1637   }
strVarType1638   std::string str() const override {
1639     return name();
1640   }
nameVarType1641   const std::string& name() const {
1642     return name_;
1643   }
hasFreeVariablesVarType1644   bool hasFreeVariables() const override {
1645     return true;
1646   }
1647   static const TypeKind Kind = TypeKind::VarType;
1648 
1649  private:
VarTypeVarType1650   VarType(std::string name_)
1651       : SharedType(TypeKind::VarType), name_(std::move(name_)) {}
1652   std::string name_;
1653 };
1654 
1655 struct CapsuleType;
1656 using CapsuleTypePtr = SingletonTypePtr<CapsuleType>;
1657 // This type represents a Python Capsule.
1658 // It does not appear in the IR and is only used during runtime
1659 struct TORCH_API CapsuleType : public Type {
equalsCapsuleType1660   bool equals(const Type& rhs) const override {
1661     return rhs.kind() == kind();
1662   }
strCapsuleType1663   std::string str() const override {
1664     return "Capsule";
1665   }
1666   static const TypeKind Kind = TypeKind::CapsuleType;
1667   // global singleton
1668   static CapsuleTypePtr get();
1669 private:
CapsuleTypeCapsuleType1670   CapsuleType()
1671   : Type(TypeKind::CapsuleType) {}
1672 };
1673 
1674 struct PyObjectType;
1675 using PyObjectTypePtr = SingletonTypePtr<PyObjectType>;
1676 // This type represents a PyObject Type
1677 struct TORCH_API PyObjectType : public Type {
equalsPyObjectType1678   bool equals(const Type& rhs) const override {
1679     return rhs.kind() == kind();
1680   }
strPyObjectType1681   std::string str() const override {
1682     return "PyObject";
1683   }
1684   static const TypeKind Kind = TypeKind::PyObjectType;
1685   // global singleton
1686   static PyObjectTypePtr get();
1687 private:
PyObjectTypePyObjectType1688   PyObjectType()
1689   : Type(TypeKind::PyObjectType) {}
1690 };
1691 
// Verbosity levels for printing types. The numeric values are the ones the
// enumerators receive implicitly, written out here; Default aliases Full.
enum class TypeVerbosity {
  None = 0,
  Type = 1,
  TypeAndStride = 2,
  Full = 3,
  Symbolic = 4,
  Default = Full,
};
1700 
1701 TORCH_API TypeVerbosity type_verbosity();
1702 
1703 TORCH_API std::ostream& operator<<(std::ostream& out, const Type& t);
1704 template <typename T>
1705 TORCH_API std::ostream& operator<<(
1706     std::ostream& out,
1707     const VaryingShape<T>& t);
1708 TORCH_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s);
1709 TORCH_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s);
1710 TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
1711 // what is the type, ignoring extra size/shape information?
1712 // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
1713 
1714 // `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
1715 // subtypes as simply "Tensor"; we also create a new version of any
1716 // container types in which internal Tensors have undergone the same
1717 // operation. This is used for type comparisons between two Tensor types
1718 // (`unshapedType` means that we don't falsely return `false` for e.g.
1719 // Tensors of different dimensions). It's also used in the alias
1720 // analysis pass.
1721 // Be careful with calls because this can be very slow. If calling this
1722 // on a graph, use `EraseShapeInformation` in shape_analysis.h
unshapedType(const TypePtr & type)1723 inline TypePtr unshapedType(const TypePtr& type) {
1724   if (type->isSubtypeOf(*TensorType::get())) {
1725     return TensorType::get();
1726   }
1727   at::ArrayRef<TypePtr> contained = type->containedTypes();
1728   if (contained.empty()) {
1729     return type;
1730   }
1731   return type->withContained(fmap(type->containedTypes(), unshapedType));
1732 }
1733 
fromNumberType(const Type & typ)1734 inline TypePtr TensorType::fromNumberType(const Type& typ) {
1735   if (typ.isSubtypeOf(*IntType::get())) {
1736     return TensorType::createContiguous(at::kLong, at::kCPU, {});
1737   } else if (typ.isSubtypeOf(*FloatType::get())) {
1738     return TensorType::createContiguous(at::kDouble, at::kCPU, {});
1739   } else if (typ.isSubtypeOf(*BoolType::get())) {
1740     return TensorType::createContiguous(at::kBool, at::kCPU, {});
1741   } else if (typ.kind() == NumberType::Kind) {
1742     return TensorType::create(std::nullopt, at::kCPU, {}, std::nullopt);
1743   }
1744   TORCH_CHECK(false, "Unknown number type: ", typ.str());
1745 }
fromBoolType()1746 inline TypePtr TensorType::fromBoolType() {
1747   return TensorType::createContiguous(at::kBool, at::kCPU, {});
1748 }
1749 
tryScalarTypeFromJitType(const Type & type)1750 inline std::optional<c10::ScalarType> tryScalarTypeFromJitType(const Type& type) {
1751   if (type == *FloatType::get()) {
1752     return at::typeMetaToScalarType(c10::get_default_dtype());
1753   } else if (type == *IntType::get()) {
1754     return at::ScalarType::Long;
1755   } else if (type == *BoolType::get()) {
1756     return at::ScalarType::Bool;
1757   }
1758   return std::nullopt;
1759 }
1760 
scalarTypeFromJitType(const Type & type)1761 inline at::ScalarType scalarTypeFromJitType(const Type& type) {
1762   auto result = tryScalarTypeFromJitType(type);
1763   TORCH_CHECK(
1764       result,
1765       "Add new condition, expected Float, Complex, Int, or Bool but got",
1766       type.str());
1767   return *result;
1768 }
1769 
// Attempt to find the correct supertype of the two types `t1` and `t2`.
// If no supertype is found, the result depends on `default_to_union`:
//   - false (the default): std::nullopt is returned;
//   - true: `Union[t1, t2]` is returned.
// If `t1 == t2`, or `t1` is a type refinement of `t2`, then `t2` will be
// returned (and vice versa).
//
// Two different tensor types will return dynamic.
//
// Currently we chose not to support returning a NumberType for
// two types from the set of {FloatType, IntType, ComplexType}, because
// there is a lack of operator support for NumberType.
//
// If `type_hint` is an `InterfaceType`, then we can use that as a
// potential supertype for `ClassType`s in the list. Otherwise, we have
// no way to find and use some common interface type.
TORCH_API std::optional<TypePtr> unifyTypes(
    const TypePtr& t1,
    const TypePtr& t2,
    bool default_to_union = false,
    const TypePtr& type_hint = nullptr);
1790 
// List counterpart of unifyTypes: takes a sequence of types plus the same
// `default_to_union` / `type_hint` options and returns an optional type.
// NOTE(review): `why_not` presumably receives an explanation when the
// elements cannot be unified -- confirm against the implementation.
TORCH_API std::optional<TypePtr> unifyTypeList(
    at::ArrayRef<TypePtr> elements,
    std::ostream& why_not,
    bool default_to_union = false,
    const TypePtr& type_hint = nullptr);
1796 
1797 namespace detail {
// Primary template: fallback used for any C++ type without a dedicated
// specialization. It asks getCustomClassType<T>() for a registered custom
// class type; if that throws c10::Error, the error is replaced by a
// TORCH_CHECK failure naming the fully qualified C++ type.
template <typename T>
struct getTypePtr_ final {
  static decltype(auto) call() {
    // NOTE: the lambda has no trailing return type, so its result follows
    // plain `auto` deduction. Do not unwrap the lambda without re-checking
    // what decltype(auto) would then deduce for call().
    return ([]() {
      try {
        return getCustomClassType<T>();
      } catch(const c10::Error&) {
        TORCH_CHECK(
            false,
            "Type ",
            c10::util::get_fully_qualified_type_name<T>(),
            " could not be converted to any of the known types."
        );
      }
    }());
  }
};
1815 
// Primary template: by default the `fake` flag is ignored and the lookup is
// forwarded to getTypePtr_<T>. Specializations override this where the real
// and fake mappings differ or where `fake` must reach contained types.
template <typename T, bool fake>
struct getMaybeFakeTypePtr_ final {
  static decltype(auto) call() {
    return getTypePtr_<T>::call();
  }
};
1822 
// at::IValue maps to AnyType, the top of the type hierarchy.
template <>
struct getTypePtr_<at::IValue> final {
  static decltype(auto) call() {
    return AnyType::get();
  }
};
1829 
// at::Tensor maps to the type returned by TensorType::get().
template <>
struct getTypePtr_<at::Tensor> final {
  static decltype(auto) call() {
    return TensorType::get();
  }
};
// c10::Storage maps to StorageType.
template <>
struct getTypePtr_<c10::Storage> final {
  static decltype(auto) call() {
    return StorageType::get();
  }
};
// c10::Stream maps to StreamObjType.
template <>
struct getTypePtr_<c10::Stream> final {
  static decltype(auto) call() {
    return StreamObjType::get();
  }
};
// double maps to FloatType.
template <>
struct getTypePtr_<double> final {
  static decltype(auto) call() {
    return FloatType::get();
  }
};
// c10::complex<double> maps to ComplexType.
template <>
struct getTypePtr_<c10::complex<double>> final {
  static decltype(auto) call() {
    return ComplexType::get();
  }
};
// int64_t maps to IntType.
template <>
struct getTypePtr_<int64_t> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
1866 
// DeviceIndex is reported as IntType, the same JIT type as int64_t.
template <>
struct getTypePtr_<DeviceIndex> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
1873 
// Real (fake == false) mapping: SymInt -> SymIntType.
template <>
struct getMaybeFakeTypePtr_<SymInt, false> final {
  static decltype(auto) call() {
    return SymIntType::get();
  }
};
// Fake (fake == true) mapping: SymInt is reported as plain IntType.
template <>
struct getMaybeFakeTypePtr_<SymInt, true> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
1886 
// Real (fake == false) mapping: SymFloat -> SymFloatType.
template <>
struct getMaybeFakeTypePtr_<SymFloat, false> final {
  static decltype(auto) call() {
    return SymFloatType::get();
  }
};
// Fake (fake == true) mapping: SymFloat is reported as plain FloatType.
template <>
struct getMaybeFakeTypePtr_<SymFloat, true> final {
  static decltype(auto) call() {
    return FloatType::get();
  }
};
1899 
// Real (fake == false) mapping: SymBool -> SymBoolType.
template <>
struct getMaybeFakeTypePtr_<SymBool, false> final {
  static decltype(auto) call() {
    return SymBoolType::get();
  }
};
// Fake (fake == true) mapping: SymBool is reported as plain BoolType.
template <>
struct getMaybeFakeTypePtr_<SymBool, true> final {
  static decltype(auto) call() {
    return BoolType::get();
  }
};
1912 
// c10::Device maps to DeviceObjType.
template <>
struct getTypePtr_<c10::Device> final {
  static decltype(auto) call() {
    return DeviceObjType::get();
  }
};
// bool maps to BoolType.
template <>
struct getTypePtr_<bool> final {
  static decltype(auto) call() {
    return BoolType::get();
  }
};
// at::Scalar maps to NumberType.
template <>
struct getTypePtr_<at::Scalar> final {
  static decltype(auto) call() {
    return NumberType::get();
  }
};
// c10::QScheme maps to QSchemeType.
template <>
struct getTypePtr_<c10::QScheme> final {
  static decltype(auto) call() {
    return QSchemeType::get();
  }
};
// at::Generator maps to an OptionalType wrapping GeneratorType (not to the
// bare GeneratorType); the type is built through TypeFactory on each call.
template <>
struct getTypePtr_<at::Generator> final {
  static decltype(auto) call() {
    return TypeFactory::create<OptionalType>(
        TypeFactory::get<GeneratorType>());
  }
};
// std::string maps to StringType.
template <>
struct getTypePtr_<std::string> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};
// c10::string_view is reported as StringType, the same as std::string.
template <>
struct getTypePtr_<c10::string_view> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};
// at::Dimname is reported as StringType.
template <>
struct getTypePtr_<at::Dimname> final {
  static decltype(auto) call() {
    return StringType::get();
  }
};
// std::vector<T> maps to a ListType over T's (real or fake) type. Both the
// element type and the list type are computed once and kept in
// function-local statics; a reference to the cached list type is returned.
template <class T, bool fake>
struct getMaybeFakeTypePtr_<std::vector<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per vector<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = ListType::get("vector", inner_type);
    return type;
  }
};
// c10::ArrayRef<T> maps to a ListType over T's (real or fake) type, obtained
// from ListType::get under the identifier "ArrayRef" and cached in a
// function-local static.
template <class T, bool fake>
struct getMaybeFakeTypePtr_<c10::ArrayRef<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per ArrayRef<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = ListType::get("ArrayRef", inner_type);
    return type;
  }
};
// c10::SymIntArrayRef maps to a ListType over SymInt's (real or fake) type.
// Unlike the generic ArrayRef<T> case this uses ListType::create directly
// and caches the result in a function-local static.
template <bool fake>
struct getMaybeFakeTypePtr_<c10::SymIntArrayRef, fake> final {
  static const auto& call() {
    static auto type = ListType::create(getMaybeFakeTypePtr_<c10::SymInt, fake>::call());
    return type;
  }
};
// c10::List<T> maps to a ListType over T's (real or fake) type, obtained
// from ListType::get under the identifier "List" and cached in a
// function-local static.
template <class T, bool fake>
struct getMaybeFakeTypePtr_<c10::List<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per List<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = ListType::get("List", inner_type);
    return type;
  }
};
// c10::IListRef<T> maps to a ListType over T's (real or fake) type. It uses
// the same "List" identifier as the c10::List<T> specialization when calling
// ListType::get.
template <class T, bool fake>
struct getMaybeFakeTypePtr_<c10::IListRef<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    static auto type = ListType::get("List", inner_type);
    return type;
  }
};
// std::array<T, N> maps to a ListType over T's (real or fake) type. The
// identifier passed to ListType::get is "array" followed by N, so each array
// length gets its own entry.
template <class T, size_t N, bool fake>
struct getMaybeFakeTypePtr_<std::array<T, N>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per array<T, N>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    // (Concatenating the length onto the end of the string because we want a unique
    // type_ptr created for every std::array<T, N> type).
    static auto type = ListType::get(std::string("array") + std::to_string(N), inner_type);
    return type;
  }
};
// std::unordered_map<K, V> maps to a DictType over the (real or fake) types
// of K and V, obtained from DictType::get under the identifier
// "unordered_map" and cached in a function-local static.
template <class K, class V, bool fake>
struct getMaybeFakeTypePtr_<std::unordered_map<K, V>, fake> final {
  static const auto& call() {
    static auto inner_key_type = getMaybeFakeTypePtr_<K, fake>::call();
    static auto inner_val_type = getMaybeFakeTypePtr_<V, fake>::call();
    // The "per unordered_map<K, V>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = DictType::get("unordered_map", inner_key_type, inner_val_type);
    return type;
  }
};
// c10::Dict<K, V> maps to a DictType over the (real or fake) types of K and
// V, obtained from DictType::get under the identifier "Dict" and cached in a
// function-local static.
template <class K, class V, bool fake>
struct getMaybeFakeTypePtr_<c10::Dict<K, V>, fake> final {
  static const auto& call() {
    static auto inner_key_type = getMaybeFakeTypePtr_<K, fake>::call();
    static auto inner_val_type = getMaybeFakeTypePtr_<V, fake>::call();
    // The "per Dict<K, V>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = DictType::get("Dict", inner_key_type, inner_val_type);
    return type;
  }
};
2041 
// std::optional<T> maps to an OptionalType wrapping T's (real or fake) type,
// cached in a function-local static.
template <class T, bool fake>
struct getMaybeFakeTypePtr_<std::optional<T>, fake> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call();
    // The "per std::optional<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = OptionalType::get(inner_type);
    return type;
  }
};
2052 
2053 
// at::OptionalIntArrayRef maps to an OptionalType wrapping the real
// (fake == false) type of IntArrayRef. This specializes getTypePtr_, so the
// same mapping is used regardless of the `fake` flag.
template<>
struct getTypePtr_<at::OptionalIntArrayRef> final {
  static const auto& call() {
    static auto inner_type = getMaybeFakeTypePtr_<IntArrayRef, false>::call();
    // The "per std::optional<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto type = OptionalType::get(inner_type);
    return type;
  }
};
2064 
// at::OptionalSymIntArrayRef maps to an OptionalType wrapping the (real or
// fake) type of SymIntArrayRef, cached in a function-local static.
template <bool fake>
struct getMaybeFakeTypePtr_<at::OptionalSymIntArrayRef, fake> final {
  static const auto& call() {
    // The "per std::optional<T>" static singleton needs to live in a .cpp file,
    // otherwise we'll end up with one singleton instance per shared library.
    static auto inner_type = getMaybeFakeTypePtr_<SymIntArrayRef, fake>::call();
    static auto type = OptionalType::get(inner_type);
    return type;
  }
};
2075 
// std::tuple<Contained...> maps to a TupleType whose element types are the
// (real or fake) types of each Contained, in declaration order. The tuple
// type is built once by the lambda and cached in a function-local static.
template <class... Contained, bool fake>
struct getMaybeFakeTypePtr_<std::tuple<Contained...>, fake> final {
  static const auto& call() {
    static auto type = ([]() {
      std::vector<TypePtr> contained_types = {
        (getMaybeFakeTypePtr_<Contained, fake>::call())...
      };
      return TupleType::create(std::move(contained_types));
    })();
    return type;
  }
};
// void maps to NoneType.
template <>
struct getTypePtr_<void> final {
  static decltype(auto) call() {
    return NoneType::get();
  }
};
2094 } // namespace detail
// Returns the JIT type for C++ type T using the real (fake == false)
// mappings. The result is forwarded with decltype(auto), so it has whatever
// type (value or reference) the selected specialization's call() returns.
template <class T>
inline decltype(auto) getTypePtr() {
  // TODO: static_assert that a templated function exists, and throw a friendly
  // error message if not
  return detail::getMaybeFakeTypePtr_<T, false>::call();
}
2101 
2102 template <class T>
2103 inline TypePtr getTypePtrCopy() {
2104   // TODO: static_assert that a templated function exists, and throw a friendly
2105   // error message if not
2106   return getTypePtr<T>();
2107 }
2108 
// Returns the JIT type for C++ type T using the fake (fake == true)
// mappings, e.g. SymInt is reported as IntType. The result is forwarded with
// decltype(auto) from the selected specialization's call().
template <class T>
inline decltype(auto) getFakeTypePtr() {
  return detail::getMaybeFakeTypePtr_<T, true>::call();
}
2113 
2114 template <class T>
2115 inline TypePtr getFakeTypePtrCopy() {
2116   return getFakeTypePtr<T>();
2117 }
2118 
// Environment keyed by string with TypePtr values; it is the `type_env`
// argument of matchTypeVariables and tryEvalTypeVariables.
using TypeEnv = std::unordered_map<std::string, TypePtr>;
// Outcome of a type-matching attempt: either success, or a failure that
// carries a textual reason.
struct MatchTypeReturn {
  // Failure: implicitly constructible from the reason text.
  MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {}

  // Success: a result that holds no reason.
  static MatchTypeReturn Success() {
    return MatchTypeReturn();
  }

  // True when no failure reason is stored.
  bool success() const {
    return !reason_;
  }

  // The failure reason. Uses std::optional::value(), so calling this on a
  // successful result throws std::bad_optional_access.
  const std::string& reason() const {
    return reason_.value();
  }

 private:
  // Only reachable through Success(); leaves reason_ empty.
  MatchTypeReturn() = default;

  std::optional<std::string> reason_; // if there is no match, this contains the reason
};
2137 
// Attempt to match the type variables in `formal` to `actual`, adding them to
// `type_env`.
// If no match is possible, this returns a MatchTypeReturn with
// r.success() == false and an r.reason() that describes why it could not
// match.
// Note: it is possible to successfully match a formal, but for type variables
// in the formal to still not be defined. In particular, None matches
// Optional[T] but does not define the value of T.
TORCH_API MatchTypeReturn
matchTypeVariables(const TypePtr& formal, const TypePtr& actual, TypeEnv& type_env);
2146 
// Replace type variables appearing in `type` with the values in
// `type_env`. Returns nullptr if a variable used in `type`
// does not appear in `type_env`.
TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, TypeEnv& type_env);
2151 
// NOTE(review): judging by the name, reports whether `elem_type` is an
// element type that can be inferred from a container's members -- confirm
// against the implementation.
TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type);
2153 
2154 struct InterfaceType;
2155 using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
2156 
2157 // Interfaces are a list of abstract methods that a class might meet.
2158 // If a class provides those methods, it implicitly meets the interface.
2159 
2160 // Subtype relations for Interface with ClassType:
2161 // lhs (ClassType or InterfaceType) is a subtype of rhs if:
2162 // 1. lhs methods are a superset of rhs methods
2163 // 2. if rhs is module interface, the lhs must be module interface or module itself
2164 struct TORCH_API InterfaceType : public NamedType {
2165   static InterfaceTypePtr create(
2166       QualifiedName qualifiedName, bool is_module=false);
2167 
2168   bool equals(const Type& rhs) const override {
2169     if (auto user_rhs = rhs.castRaw<InterfaceType>()) {
2170       return isSubTypeImpl(*this, *user_rhs, nullptr) &&
2171           isSubTypeImpl(*user_rhs, *this, nullptr);
2172     }
2173     return false;
2174   }
2175 
2176   std::string str() const override {
2177     return std::string("InterfaceType<") + name()->name() + ">";
2178   }
2179 
2180   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
2181 
2182   // try to find a method of this interface,
2183   // returns nullptr if not found.
2184   const FunctionSchema* getMethod(const std::string& name) const;
2185   void addMethod(FunctionSchema schema);
2186   const std::vector<FunctionSchema>& methods() const {
2187     return *methods_;
2188   }
2189 
2190   bool is_module() const override{
2191     return is_module_;
2192   }
2193   static const TypeKind Kind = TypeKind::InterfaceType;
2194   ~InterfaceType() override;
2195  private:
2196   InterfaceType(QualifiedName name, bool is_module);
2197   static bool isSubTypeImpl(
2198       const InterfaceType& lhs,
2199       const InterfaceType& rhs,
2200       std::ostream* why_not);
2201 
2202   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
2203     return name()->qualifiedName();
2204   }
2205 
2206   // shared_ptr so that this header does not have to depend on
2207   // FunctionSchema.h
2208   std::shared_ptr<std::vector<FunctionSchema>> methods_;
2209   // flag to distinguish if it's an interface type from a module or not
2210   bool is_module_;
2211 };
2212 
2213 template <TypeKind K>
2214 struct EnumerationType : public Type {
2215 static const TypeKind Kind = K;
2216 
2217 bool equals(const Type& rhs) const override {
2218   return rhs.kind() == kind();
2219 }
2220 
2221 protected:
2222 EnumerationType() : Type(Kind) {}
2223 };
2224 
2225 // WARNING: These enumeration types below DO NOT actually get parsed out
2226 // from the logical schema strings, instead they are mapped as ints.  To
2227 // observe these types, use real_type() instead of type() on Argument
2228 
2229 struct ScalarTypeType;
2230 using ScalarTypeTypePtr = SingletonTypePtr<ScalarTypeType>;
2231 struct TORCH_API ScalarTypeType : public EnumerationType<TypeKind::ScalarTypeType> {
2232 std::string str() const override {
2233 return "ScalarType";
2234 }
2235 static const TypeKind Kind = TypeKind::ScalarTypeType;
2236 // global singleton
2237 static ScalarTypeTypePtr get();
2238 
2239 private:
2240 ScalarTypeType() : EnumerationType() {}
2241 };
2242 
2243 struct MemoryFormatType;
2244 using MemoryFormatTypePtr = SingletonTypePtr<MemoryFormatType>;
2245 struct TORCH_API MemoryFormatType : public EnumerationType<TypeKind::MemoryFormatType> {
2246 std::string str() const override {
2247 return "MemoryFormat";
2248 }
2249 static const TypeKind Kind = TypeKind::MemoryFormatType;
2250 // global singleton
2251 static MemoryFormatTypePtr get();
2252 
2253 private:
2254 MemoryFormatType() : EnumerationType() {}
2255 };
2256 
2257 struct LayoutType;
2258 using LayoutTypePtr = SingletonTypePtr<LayoutType>;
2259 struct TORCH_API LayoutType : public EnumerationType<TypeKind::LayoutType> {
2260 std::string str() const override {
2261 return "Layout";
2262 }
2263 static const TypeKind Kind = TypeKind::LayoutType;
2264 // global singleton
2265 static LayoutTypePtr get();
2266 
2267 private:
2268 LayoutType() : EnumerationType() {}
2269 };
2270 
namespace detail {
// Real (fake == false) mappings: each enum maps to its dedicated
// enumeration type.
template <>
struct getMaybeFakeTypePtr_<c10::ScalarType, false> final {
  static decltype(auto) call() {
    return ScalarTypeType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<c10::Layout, false> final {
  static decltype(auto) call() {
    return LayoutType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<c10::MemoryFormat, false> final {
  static decltype(auto) call() {
    return MemoryFormatType::get();
  }
};
// Fake (fake == true) mappings: all three enums are reported as IntType.
template <>
struct getMaybeFakeTypePtr_<c10::ScalarType, true> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<c10::Layout, true> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
template <>
struct getMaybeFakeTypePtr_<c10::MemoryFormat, true> final {
  static decltype(auto) call() {
    return IntType::get();
  }
};
} // namespace detail
2309 
2310 // the common supertype of all lists,
2311 // List[T] <: AnyList for all T
2312 struct AnyListType;
2313 using AnyListTypePtr = SingletonTypePtr<AnyListType>;
2314 struct TORCH_API AnyListType : public Type {
2315   bool equals(const Type& rhs) const override {
2316     return rhs.kind() == kind();
2317   }
2318   std::string str() const override {
2319     return "list";
2320   }
2321   static const TypeKind Kind = TypeKind::AnyListType;
2322   // global singleton
2323   static AnyListTypePtr get();
2324 private:
2325   AnyListType()
2326   : Type(TypeKind::AnyListType) {}
2327 };
2328 
2329 // the common supertype of all tuples,
2330 // Tuple[T...] <: AnyTuple for all T
2331 struct AnyTupleType;
2332 using AnyTupleTypePtr = SingletonTypePtr<AnyTupleType>;
2333 struct TORCH_API AnyTupleType : public Type {
2334   bool equals(const Type& rhs) const override {
2335     return rhs.kind() == kind();
2336   }
2337 
2338   std::string str() const override {
2339     return "tuple";
2340   }
2341   static const TypeKind Kind = TypeKind::AnyTupleType;
2342 
2343   // global singleton
2344   static AnyTupleTypePtr get();
2345 private:
2346   AnyTupleType()
2347   : Type(TypeKind::AnyTupleType) {}
2348 };
2349 
2350 // the common supertype of all classes,
2351 // ClassType <: AnyClassType for all classes
2352 struct AnyClassType;
2353 using AnyClassTypePtr = SingletonTypePtr<AnyClassType>;
2354 struct TORCH_API AnyClassType : public Type {
2355   bool equals(const Type& rhs) const override {
2356     return rhs.kind() == kind();
2357   }
2358   std::string str() const override {
2359     return "AnyClassType";
2360   }
2361   static const TypeKind Kind = TypeKind::AnyClassType;
2362   // global singleton
2363   static AnyClassTypePtr get();
2364 private:
2365   AnyClassType()
2366   : Type(TypeKind::AnyClassType) {}
2367 };
2368 
2369 template<>
2370 inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
2371   if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2372       kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2373     return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this());
2374   }
2375   return nullptr;
2376 }
2377 
2378 template<>
2379 inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
2380   if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2381       kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2382     return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this());
2383   }
2384   return nullptr;
2385 }
2386 
2387 template<>
2388 inline const NamedType* Type::castRaw<NamedType>() const {
2389   if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
2390       kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
2391     return static_cast<const NamedType*>(this);
2392   }
2393   return nullptr;
2394 }
2395 
2396 // Used as a return type when inferring the IValue type of a Python object.
2397 struct InferredType {
2398   /* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {}
2399   /* implicit */ InferredType(std::string reason)
2400       : type_(nullptr), reason_(std::move(reason)) {}
2401   TypePtr type() const {
2402     TORCH_INTERNAL_ASSERT(
2403         type_,
2404         "Tried to get the type from an InferredType but the type is null. ",
2405         "Reason: ",
2406         reason_);
2407     return type_;
2408   }
2409   bool success() const {
2410     return type_ != nullptr;
2411   }
2412   const std::string& reason() const {
2413     TORCH_INTERNAL_ASSERT(!type_);
2414     return reason_;
2415   }
2416 
2417 private:
2418   TypePtr type_;
2419   std::string reason_;
2420 };
2421 
// NOTE(review): judging by the name, reports whether AnyType occurs in
// `type` (presumably including the types it contains) -- confirm against
// the implementation.
TORCH_API bool containsAnyType(const TypePtr& type);
2423 
2424 } // namespace c10
2425