xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/custom_class.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/function_schema.h>
2 #include <ATen/core/functional.h>
3 #include <ATen/core/jit_type.h>
4 #include <ATen/core/type_factory.h>
5 #include <ATen/record_function.h>
6 #include <c10/util/flat_hash_map.h>
7 #include <torch/custom_class.h>
8 #include <torch/custom_class_detail.h>
9 
10 #include <unordered_map>
11 
12 namespace c10 {
13 
14 static ska::flat_hash_map<std::type_index, c10::ClassTypePtr>&
getCustomClassTypeMap()15 getCustomClassTypeMap() {
16   static ska::flat_hash_map<std::type_index, c10::ClassTypePtr> tmap;
17   return tmap;
18 }
19 
getCustomClassTypeImpl(const std::type_index & tindex)20 c10::ClassTypePtr getCustomClassTypeImpl(const std::type_index& tindex) {
21   auto& tmap = c10::getCustomClassTypeMap();
22   auto res = tmap.find(tindex);
23   if (C10_UNLIKELY(res == tmap.end())) {
24     // type_index is not guaranteed to be unique across shared libraries on some
25     // platforms For example see
26     // https://github.com/llvm-mirror/libcxx/blob/78d6a7767ed57b50122a161b91f59f19c9bd0d19/include/typeinfo#L133
27     // Also, this is not the case if RTLD_LOCAL option is used, see
28     // https://github.com/pybind/pybind11/blob/f791dc8648e1f6ec33f402d679b6b116a76d4e1b/include/pybind11/detail/internals.h#L101-L106
29     // Take a slow path of iterating over all registered types and compare their
30     // names
31     auto class_name = std::string(tindex.name());
32     for (const auto& it : tmap) {
33       if (class_name == it.first.name()) {
34         // Do not modify existing type map here as this template is supposed to
35         // be called only once per type from getCustomClassTypeImpl()
36         return it.second;
37       }
38     }
39     TORCH_CHECK(
40         false,
41         "Can't find class id in custom class type map for ",
42         tindex.name());
43   }
44   return res->second;
45 }
46 
47 } // namespace c10
48 
49 namespace torch {
50 
51 namespace detail {
52 
53 #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
record_custom_class(std::string name)54 void record_custom_class(std::string name) {
55   RECORD_FUNCTION_WITH_SCOPE(
56       at::RecordScope::CUSTOM_CLASS,
57       std::move(name),
58       c10::ArrayRef<const c10::IValue>{});
59 }
60 #endif
61 
62 } // namespace detail
63 
customClasses()64 static std::unordered_map<std::string, at::ClassTypePtr>& customClasses() {
65   static std::unordered_map<std::string, at::ClassTypePtr> customClasses;
66   return customClasses;
67 }
68 
registerCustomClass(at::ClassTypePtr class_type)69 void registerCustomClass(at::ClassTypePtr class_type) {
70   TORCH_INTERNAL_ASSERT(class_type->name());
71   // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
72   auto name = class_type->name()->qualifiedName();
73   TORCH_CHECK(
74       !customClasses().count(name),
75       "Custom class with name ",
76       name,
77       " is already registered. Ensure that registration with torch::class_ is only called once.");
78   customClasses()[name] = std::move(class_type);
79 }
80 
getCustomClass(const std::string & class_name)81 at::ClassTypePtr getCustomClass(const std::string& class_name) {
82   auto ret =
83       customClasses().count(class_name) ? customClasses()[class_name] : nullptr;
84   if (ret) {
85     RECORD_CUSTOM_CLASS(class_name);
86   }
87   return ret;
88 }
89 
getAllCustomClassesNames()90 const std::unordered_set<std::string> getAllCustomClassesNames() {
91   std::unordered_set<std::string> ret;
92   for (const auto& kv : customClasses()) {
93     ret.insert(kv.first);
94   }
95   return ret;
96 }
97 
isCustomClass(const c10::IValue & v)98 bool isCustomClass(const c10::IValue& v) {
99   return v.isObject() && v.toObject()->type()->name() &&
100       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
101       getCustomClass(v.toObject()->type()->name()->qualifiedName());
102 }
103 
customClassMethods()104 static std::vector<std::unique_ptr<jit::Function>>& customClassMethods() {
105   static std::vector<std::unique_ptr<jit::Function>> customClassMethods;
106   return customClassMethods;
107 }
108 
registerCustomClassMethod(std::unique_ptr<jit::Function> fn)109 void registerCustomClassMethod(std::unique_ptr<jit::Function> fn) {
110   customClassMethods().emplace_back(std::move(fn));
111 }
112 
customClassSchemasForBCCheck()113 std::vector<c10::FunctionSchema> customClassSchemasForBCCheck() {
114   auto& methods = customClassMethods();
115   return c10::fmap(methods, [](const std::unique_ptr<jit::Function>& fn) {
116     return fn->getSchema();
117   });
118 }
119 
120 namespace detail {
class_base(const std::string & namespaceName,const std::string & className,std::string doc_string,const std::type_info & intrusivePtrClassTypeid,const std::type_info & taggedCapsuleClassTypeid)121 class_base::class_base(
122     const std::string& namespaceName,
123     const std::string& className,
124     std::string doc_string,
125     const std::type_info& intrusivePtrClassTypeid,
126     const std::type_info& taggedCapsuleClassTypeid)
127     : qualClassName(
128           "__torch__.torch.classes." + namespaceName + '.' + className),
129       classTypePtr(at::ClassType::create(
130           c10::QualifiedName(qualClassName),
131           std::weak_ptr<jit::CompilationUnit>(),
132           /*is_module=*/false,
133           std::move(doc_string))) {
134   detail::checkValidIdent(namespaceName, "Namespace name");
135   detail::checkValidIdent(className, "Class name");
136   classTypePtr->addAttribute(
137       "capsule", c10::TypeFactory::get<c10::CapsuleType>());
138   c10::getCustomClassTypeMap().insert(
139       {std::type_index(intrusivePtrClassTypeid), classTypePtr});
140   c10::getCustomClassTypeMap().insert(
141       {std::type_index(taggedCapsuleClassTypeid), classTypePtr});
142 
143   registerCustomClass(classTypePtr);
144 }
145 
// Returns a clone of `schema` whose non-self arguments take their names and
// default values from `default_args`.
//
// The first argument of `schema` (self) is copied unchanged. Each entry of
// `default_args` is then paired, in order, with the next argument of `schema`:
// the new argument keeps the old type, real type and N, and takes its name and
// default value from the entry. Arguments of `schema` beyond those paired with
// an entry are not carried into the result.
//
// Fails through TORCH_CHECK if `schema` has no arguments or if `default_args`
// has more entries than `schema` has non-self arguments; without this guard
// either case would index past the end of the argument list.
c10::FunctionSchema class_base::withNewArguments(
    const c10::FunctionSchema& schema,
    std::initializer_list<arg> default_args) {
  const auto& old_args = schema.arguments();
  TORCH_CHECK(
      !old_args.empty() && default_args.size() < old_args.size(),
      "Cannot apply ",
      default_args.size(),
      " default argument(s) to a schema with ",
      old_args.size(),
      " argument(s) (including self)");

  std::vector<c10::Argument> new_args;
  new_args.reserve(old_args.size());

  // Slot 0 is self and is carried over as-is.
  new_args.emplace_back(old_args[0]);
  // Skip self.
  size_t argIdx = 1;
  for (const auto& default_arg : default_args) {
    const auto& old_arg = old_args[argIdx++];
    new_args.emplace_back(
        default_arg.name_,
        old_arg.type(),
        old_arg.real_type(),
        old_arg.N(),
        default_arg.value_);
  }
  return schema.cloneWithArguments(std::move(new_args));
}
167 
168 } // namespace detail
169 } // namespace torch
170