1 #include <ATen/core/function_schema.h>
2 #include <ATen/core/functional.h>
3 #include <ATen/core/jit_type.h>
4 #include <ATen/core/type_factory.h>
5 #include <ATen/record_function.h>
6 #include <c10/util/flat_hash_map.h>
7 #include <torch/custom_class.h>
8 #include <torch/custom_class_detail.h>
9
10 #include <unordered_map>
11
12 namespace c10 {
13
14 static ska::flat_hash_map<std::type_index, c10::ClassTypePtr>&
getCustomClassTypeMap()15 getCustomClassTypeMap() {
16 static ska::flat_hash_map<std::type_index, c10::ClassTypePtr> tmap;
17 return tmap;
18 }
19
getCustomClassTypeImpl(const std::type_index & tindex)20 c10::ClassTypePtr getCustomClassTypeImpl(const std::type_index& tindex) {
21 auto& tmap = c10::getCustomClassTypeMap();
22 auto res = tmap.find(tindex);
23 if (C10_UNLIKELY(res == tmap.end())) {
24 // type_index is not guaranteed to be unique across shared libraries on some
25 // platforms For example see
26 // https://github.com/llvm-mirror/libcxx/blob/78d6a7767ed57b50122a161b91f59f19c9bd0d19/include/typeinfo#L133
27 // Also, this is not the case if RTLD_LOCAL option is used, see
28 // https://github.com/pybind/pybind11/blob/f791dc8648e1f6ec33f402d679b6b116a76d4e1b/include/pybind11/detail/internals.h#L101-L106
29 // Take a slow path of iterating over all registered types and compare their
30 // names
31 auto class_name = std::string(tindex.name());
32 for (const auto& it : tmap) {
33 if (class_name == it.first.name()) {
34 // Do not modify existing type map here as this template is supposed to
35 // be called only once per type from getCustomClassTypeImpl()
36 return it.second;
37 }
38 }
39 TORCH_CHECK(
40 false,
41 "Can't find class id in custom class type map for ",
42 tindex.name());
43 }
44 return res->second;
45 }
46
47 } // namespace c10
48
49 namespace torch {
50
51 namespace detail {
52
53 #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
record_custom_class(std::string name)54 void record_custom_class(std::string name) {
55 RECORD_FUNCTION_WITH_SCOPE(
56 at::RecordScope::CUSTOM_CLASS,
57 std::move(name),
58 c10::ArrayRef<const c10::IValue>{});
59 }
60 #endif
61
62 } // namespace detail
63
customClasses()64 static std::unordered_map<std::string, at::ClassTypePtr>& customClasses() {
65 static std::unordered_map<std::string, at::ClassTypePtr> customClasses;
66 return customClasses;
67 }
68
registerCustomClass(at::ClassTypePtr class_type)69 void registerCustomClass(at::ClassTypePtr class_type) {
70 TORCH_INTERNAL_ASSERT(class_type->name());
71 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
72 auto name = class_type->name()->qualifiedName();
73 TORCH_CHECK(
74 !customClasses().count(name),
75 "Custom class with name ",
76 name,
77 " is already registered. Ensure that registration with torch::class_ is only called once.");
78 customClasses()[name] = std::move(class_type);
79 }
80
getCustomClass(const std::string & class_name)81 at::ClassTypePtr getCustomClass(const std::string& class_name) {
82 auto ret =
83 customClasses().count(class_name) ? customClasses()[class_name] : nullptr;
84 if (ret) {
85 RECORD_CUSTOM_CLASS(class_name);
86 }
87 return ret;
88 }
89
getAllCustomClassesNames()90 const std::unordered_set<std::string> getAllCustomClassesNames() {
91 std::unordered_set<std::string> ret;
92 for (const auto& kv : customClasses()) {
93 ret.insert(kv.first);
94 }
95 return ret;
96 }
97
isCustomClass(const c10::IValue & v)98 bool isCustomClass(const c10::IValue& v) {
99 return v.isObject() && v.toObject()->type()->name() &&
100 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
101 getCustomClass(v.toObject()->type()->name()->qualifiedName());
102 }
103
customClassMethods()104 static std::vector<std::unique_ptr<jit::Function>>& customClassMethods() {
105 static std::vector<std::unique_ptr<jit::Function>> customClassMethods;
106 return customClassMethods;
107 }
108
registerCustomClassMethod(std::unique_ptr<jit::Function> fn)109 void registerCustomClassMethod(std::unique_ptr<jit::Function> fn) {
110 customClassMethods().emplace_back(std::move(fn));
111 }
112
customClassSchemasForBCCheck()113 std::vector<c10::FunctionSchema> customClassSchemasForBCCheck() {
114 auto& methods = customClassMethods();
115 return c10::fmap(methods, [](const std::unique_ptr<jit::Function>& fn) {
116 return fn->getSchema();
117 });
118 }
119
120 namespace detail {
class_base(const std::string & namespaceName,const std::string & className,std::string doc_string,const std::type_info & intrusivePtrClassTypeid,const std::type_info & taggedCapsuleClassTypeid)121 class_base::class_base(
122 const std::string& namespaceName,
123 const std::string& className,
124 std::string doc_string,
125 const std::type_info& intrusivePtrClassTypeid,
126 const std::type_info& taggedCapsuleClassTypeid)
127 : qualClassName(
128 "__torch__.torch.classes." + namespaceName + '.' + className),
129 classTypePtr(at::ClassType::create(
130 c10::QualifiedName(qualClassName),
131 std::weak_ptr<jit::CompilationUnit>(),
132 /*is_module=*/false,
133 std::move(doc_string))) {
134 detail::checkValidIdent(namespaceName, "Namespace name");
135 detail::checkValidIdent(className, "Class name");
136 classTypePtr->addAttribute(
137 "capsule", c10::TypeFactory::get<c10::CapsuleType>());
138 c10::getCustomClassTypeMap().insert(
139 {std::type_index(intrusivePtrClassTypeid), classTypePtr});
140 c10::getCustomClassTypeMap().insert(
141 {std::type_index(taggedCapsuleClassTypeid), classTypePtr});
142
143 registerCustomClass(classTypePtr);
144 }
145
withNewArguments(const c10::FunctionSchema & schema,std::initializer_list<arg> default_args)146 c10::FunctionSchema class_base::withNewArguments(
147 const c10::FunctionSchema& schema,
148 std::initializer_list<arg> default_args) {
149 const auto& old_args = schema.arguments();
150 std::vector<c10::Argument> new_args;
151 new_args.reserve(old_args.size());
152
153 new_args.emplace_back(old_args[0]);
154 // Skip self.
155 size_t argIdx = 1;
156 for (const auto& default_arg : default_args) {
157 auto& old_arg = old_args[argIdx++];
158 new_args.emplace_back(
159 default_arg.name_,
160 old_arg.type(),
161 old_arg.real_type(),
162 old_arg.N(),
163 default_arg.value_);
164 }
165 return schema.cloneWithArguments(std::move(new_args));
166 }
167
168 } // namespace detail
169 } // namespace torch
170