1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 5 #include <utility> 6 7 namespace c10 { 8 9 struct EnumType; 10 using EnumTypePtr = std::shared_ptr<EnumType>; 11 using EnumNameValue = std::pair<std::string, IValue>; 12 struct TORCH_API EnumType : public NamedType { 13 friend struct Type; 14 static const TypeKind Kind = TypeKind::EnumType; 15 createEnumType16 static EnumTypePtr create( 17 const c10::QualifiedName& qualified_class_name, 18 TypePtr value, 19 std::vector<EnumNameValue> enum_names_values, 20 std::weak_ptr<::torch::jit::CompilationUnit> cu) { 21 switch (value->kind()) { 22 case TypeKind::IntType: 23 case TypeKind::FloatType: 24 case TypeKind::StringType: 25 return EnumTypePtr(new EnumType( 26 qualified_class_name, 27 std::move(value), 28 std::move(enum_names_values), 29 std::move(cu))); 30 default: 31 AT_ERROR( 32 "Cannot create Enum with value type '", 33 value->str(), 34 "', only int, float and string are supported"); 35 } 36 } 37 strEnumType38 std::string str() const override { 39 return "Enum<" + annotation_str() + ">"; 40 } 41 repr_strEnumType42 std::string repr_str() const override { 43 return str(); 44 } 45 getValueTypeEnumType46 const TypePtr& getValueType() const { 47 return value_type_; 48 } 49 equalsEnumType50 bool equals(const Type& rhs) const override { 51 if (auto* enum_rhs = rhs.castRaw<EnumType>()) { 52 return name().value() == enum_rhs->name().value() && 53 *getValueType() == *(enum_rhs->getValueType()) && 54 this->compilation_unit() == enum_rhs->compilation_unit(); 55 } 56 return false; 57 } 58 59 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; 60 compilation_unitEnumType61 std::shared_ptr<const ::torch::jit::CompilationUnit> compilation_unit() 62 const { 63 auto cu = cu_.lock(); 64 return cu; 65 } 66 qualifiedClassNameEnumType67 const QualifiedName& qualifiedClassName() const { 68 return name().value(); 69 } 70 containedTypesEnumType71 at::ArrayRef<TypePtr> containedTypes() const override { 72 return value_type_; 73 } 74 enumNamesValuesEnumType75 const at::ArrayRef<EnumNameValue> enumNamesValues() const { 76 return enum_names_values_; 77 } 78 79 private: EnumTypeEnumType80 EnumType( 81 c10::QualifiedName qualified_class_name, 82 TypePtr value_type, 83 std::vector<EnumNameValue> enum_names_values, 84 std::weak_ptr<torch::jit::CompilationUnit> cu) 85 : NamedType(TypeKind::EnumType, std::move(qualified_class_name)), 86 value_type_(std::move(value_type)), 87 enum_names_values_(std::move(enum_names_values)), 88 cu_(std::move(cu)) {} 89 90 std::string annotation_str_impl( 91 C10_UNUSED const TypePrinter& printer = nullptr) const override { 92 const auto& n = name().value(); 93 return n.qualifiedName(); 94 } 95 96 TypePtr value_type_; 97 std::vector<EnumNameValue> enum_names_values_; 98 std::weak_ptr<::torch::jit::CompilationUnit> cu_; 99 }; 100 101 } // namespace c10 102