// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/enum_type.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 
5 #include <utility>
6 
7 namespace c10 {
8 
9 struct EnumType;
10 using EnumTypePtr = std::shared_ptr<EnumType>;
11 using EnumNameValue = std::pair<std::string, IValue>;
12 struct TORCH_API EnumType : public NamedType {
13   friend struct Type;
14   static const TypeKind Kind = TypeKind::EnumType;
15 
createEnumType16   static EnumTypePtr create(
17       const c10::QualifiedName& qualified_class_name,
18       TypePtr value,
19       std::vector<EnumNameValue> enum_names_values,
20       std::weak_ptr<::torch::jit::CompilationUnit> cu) {
21     switch (value->kind()) {
22       case TypeKind::IntType:
23       case TypeKind::FloatType:
24       case TypeKind::StringType:
25         return EnumTypePtr(new EnumType(
26             qualified_class_name,
27             std::move(value),
28             std::move(enum_names_values),
29             std::move(cu)));
30       default:
31         AT_ERROR(
32             "Cannot create Enum with value type '",
33             value->str(),
34             "', only int, float and string are supported");
35     }
36   }
37 
strEnumType38   std::string str() const override {
39     return "Enum<" + annotation_str() + ">";
40   }
41 
repr_strEnumType42   std::string repr_str() const override {
43     return str();
44   }
45 
getValueTypeEnumType46   const TypePtr& getValueType() const {
47     return value_type_;
48   }
49 
equalsEnumType50   bool equals(const Type& rhs) const override {
51     if (auto* enum_rhs = rhs.castRaw<EnumType>()) {
52       return name().value() == enum_rhs->name().value() &&
53           *getValueType() == *(enum_rhs->getValueType()) &&
54           this->compilation_unit() == enum_rhs->compilation_unit();
55     }
56     return false;
57   }
58 
59   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
60 
compilation_unitEnumType61   std::shared_ptr<const ::torch::jit::CompilationUnit> compilation_unit()
62       const {
63     auto cu = cu_.lock();
64     return cu;
65   }
66 
qualifiedClassNameEnumType67   const QualifiedName& qualifiedClassName() const {
68     return name().value();
69   }
70 
containedTypesEnumType71   at::ArrayRef<TypePtr> containedTypes() const override {
72     return value_type_;
73   }
74 
enumNamesValuesEnumType75   const at::ArrayRef<EnumNameValue> enumNamesValues() const {
76     return enum_names_values_;
77   }
78 
79  private:
EnumTypeEnumType80   EnumType(
81       c10::QualifiedName qualified_class_name,
82       TypePtr value_type,
83       std::vector<EnumNameValue> enum_names_values,
84       std::weak_ptr<torch::jit::CompilationUnit> cu)
85       : NamedType(TypeKind::EnumType, std::move(qualified_class_name)),
86         value_type_(std::move(value_type)),
87         enum_names_values_(std::move(enum_names_values)),
88         cu_(std::move(cu)) {}
89 
90   std::string annotation_str_impl(
91       C10_UNUSED const TypePrinter& printer = nullptr) const override {
92     const auto& n = name().value();
93     return n.qualifiedName();
94   }
95 
96   TypePtr value_type_;
97   std::vector<EnumNameValue> enum_names_values_;
98   std::weak_ptr<::torch::jit::CompilationUnit> cu_;
99 };
100 
101 } // namespace c10
102