xref: /aosp_15_r20/external/pytorch/torch/custom_class_detail.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
4 #include <ATen/core/function.h>
5 #include <c10/util/Metaprogramming.h>
6 #include <c10/util/TypeTraits.h>
7 #include <c10/util/irange.h>
8 
9 namespace torch {
10 
namespace detail {
/**
 * In the Facebook internal build (using BUCK), this macro is enabled by
 * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
 * binary. In every other build RECORD_CUSTOM_CLASS expands to nothing.
 */
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
TORCH_API void record_custom_class(std::string name);

/**
 * Record an instance of a custom class being loaded.
 *
 * Only the portion of NAME after the final '.' is recorded, as this seemingly
 * aligns with how users name their custom classes. Example:
 *   __torch__.torch.classes.xnnpack.Conv2dOpContext  ->  Conv2dOpContext
 * If NAME contains no '.', find_last_of returns npos and npos + 1 wraps to 0,
 * so the whole string is recorded.
 *
 * The expansion is a single do { ... } while (false) statement. This means:
 *  - it is safe as the body of an unbraced if/else;
 *  - its temporary lives in its own scope and cannot collide with (or, as
 *    with a caller variable called `name`, self-initialize from) caller names;
 *  - the call is fully qualified, so it works outside namespace torch.
 * Callers terminate it with a semicolon: RECORD_CUSTOM_CLASS(qualified_name);
 */
#define RECORD_CUSTOM_CLASS(NAME)                                   \
  do {                                                              \
    const std::string torch_record_custom_class_name_(NAME);        \
    ::torch::detail::record_custom_class(                           \
        torch_record_custom_class_name_.substr(                     \
            torch_record_custom_class_name_.find_last_of('.') + 1)); \
  } while (false)
#else
#define RECORD_CUSTOM_CLASS(NAME)
#endif
} // namespace detail
33 
34 /// This struct is used to represent default values for arguments
35 /// when registering methods for custom classes.
36 ///     static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
37 ///       .def("myMethod", &Foo::myMethod, {torch::arg("name") = name});
38 struct arg {
39   // Static method for representing a default value of None. This is meant to
40   // be used like so:
41   //     torch::arg("name") = torch::arg::none
42   // and is identical to:
43   //     torch::arg("name") = IValue()
nonearg44   static c10::IValue none() {
45     return c10::IValue();
46   }
47 
48   // Explicit constructor.
argarg49   explicit arg(std::string name)
50       : name_(std::move(name)), value_(std::nullopt) {}
51   // Assignment operator. This enables the pybind-like syntax of
52   // torch::arg("name") = value.
53   arg& operator=(const c10::IValue& rhs) {
54     value_ = rhs;
55     return *this;
56   }
57 
58   // The name of the argument. This is copied to the schema; argument
59   // names cannot be extracted from the C++ declaration.
60   std::string name_;
61   // IValue's default constructor makes it None, which is not distinguishable
62   // from an actual, user-provided default value that is None. This boolean
63   // helps distinguish between the two cases.
64   std::optional<c10::IValue> value_;
65 };
66 
67 namespace detail {
68 
// Argument type utilities.
// types<R, Ts...> is an empty tag type bundling a leading type R (named for a
// return type) with a pack of further types; its nested `type` alias names
// the tag type itself.
template <typename R, typename... Ts>
struct types {
  using type = types<R, Ts...>;
};
74 
75 template <typename Method>
76 struct WrapMethod;
77 
78 template <typename R, typename CurrClass, typename... Args>
79 struct WrapMethod<R (CurrClass::*)(Args...)> {
80   WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {}
81 
82   R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
83     return c10::guts::invoke(m, *cur, args...);
84   }
85 
86   R (CurrClass::*m)(Args...);
87 };
88 
89 template <typename R, typename CurrClass, typename... Args>
90 struct WrapMethod<R (CurrClass::*)(Args...) const> {
91   WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {}
92 
93   R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
94     return c10::guts::invoke(m, *cur, args...);
95   }
96 
97   R (CurrClass::*m)(Args...) const;
98 };
99 
100 // Adapter for different callable types
101 template <
102     typename CurClass,
103     typename Func,
104     std::enable_if_t<
105         std::is_member_function_pointer_v<std::decay_t<Func>>,
106         bool> = false>
107 WrapMethod<Func> wrap_func(Func f) {
108   return WrapMethod<Func>(std::move(f));
109 }
110 
111 template <
112     typename CurClass,
113     typename Func,
114     std::enable_if_t<
115         !std::is_member_function_pointer_v<std::decay_t<Func>>,
116         bool> = false>
117 Func wrap_func(Func f) {
118   return f;
119 }
120 
121 template <
122     class Functor,
123     bool AllowDeprecatedTypes,
124     size_t... ivalue_arg_indices>
125 typename c10::guts::infer_function_traits_t<Functor>::return_type
126 call_torchbind_method_from_stack(
127     Functor& functor,
128     jit::Stack& stack,
129     std::index_sequence<ivalue_arg_indices...>) {
130   (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
131                  // be unused and we have to silence the compiler warning.
132 
133   constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices);
134 
135   using IValueArgTypes =
136       typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
137   // TODO We shouldn't use c10::impl stuff directly here. We should use the
138   // KernelFunction API instead.
139   return (functor)(c10::impl::ivalue_to_arg<
140                    typename c10::impl::decay_if_not_tensor<
141                        c10::guts::typelist::
142                            element_t<ivalue_arg_indices, IValueArgTypes>>::type,
143                    AllowDeprecatedTypes>::
144                        call(torch::jit::peek(
145                            stack, ivalue_arg_indices, num_ivalue_args))...);
146 }
147 
148 template <class Functor, bool AllowDeprecatedTypes>
149 typename c10::guts::infer_function_traits_t<Functor>::return_type
150 call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) {
151   constexpr size_t num_ivalue_args =
152       c10::guts::infer_function_traits_t<Functor>::number_of_parameters;
153   return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>(
154       functor, stack, std::make_index_sequence<num_ivalue_args>());
155 }
156 
157 template <class RetType, class Func>
158 struct BoxedProxy;
159 
160 template <class RetType, class Func>
161 struct BoxedProxy {
162   void operator()(jit::Stack& stack, Func& func) {
163     auto retval = call_torchbind_method_from_stack<Func, false>(func, stack);
164     constexpr size_t num_ivalue_args =
165         c10::guts::infer_function_traits_t<Func>::number_of_parameters;
166     torch::jit::drop(stack, num_ivalue_args);
167     stack.emplace_back(c10::ivalue::from(std::move(retval)));
168   }
169 };
170 
171 template <class Func>
172 struct BoxedProxy<void, Func> {
173   void operator()(jit::Stack& stack, Func& func) {
174     call_torchbind_method_from_stack<Func, false>(func, stack);
175     constexpr size_t num_ivalue_args =
176         c10::guts::infer_function_traits_t<Func>::number_of_parameters;
177     torch::jit::drop(stack, num_ivalue_args);
178     stack.emplace_back();
179   }
180 };
181 
// Returns true if character `n` may appear at position `i` of a Python/C++
// identifier: an ASCII letter or '_' at any position, or an ASCII digit at
// any position except the first.
//
// The checks are spelled out instead of calling <cctype>'s isalpha/isdigit:
// those require an argument representable as unsigned char (or EOF), so a
// plain `char` holding a non-ASCII byte is undefined behavior on platforms
// where char is signed, and isalpha's answer also depends on the C locale.
inline bool validIdent(size_t i, char n) {
  const bool is_letter = (n >= 'a' && n <= 'z') || (n >= 'A' && n <= 'Z');
  const bool is_digit = n >= '0' && n <= '9';
  return is_letter || n == '_' || (i > 0 && is_digit);
}
185 
186 inline void checkValidIdent(const std::string& str, const char* type) {
187   for (const auto i : c10::irange(str.size())) {
188     TORCH_CHECK(
189         validIdent(i, str[i]),
190         type,
191         " must be a valid Python/C++ identifier."
192         " Character '",
193         str[i],
194         "' at index ",
195         i,
196         " is illegal.");
197   }
198 }
199 
200 class TORCH_API class_base {
201  protected:
202   explicit class_base(
203       const std::string& namespaceName,
204       const std::string& className,
205       std::string doc_string,
206       const std::type_info& intrusivePtrClassTypeid,
207       const std::type_info& taggedCapsuleClass);
208 
209   static c10::FunctionSchema withNewArguments(
210       const c10::FunctionSchema& schema,
211       std::initializer_list<arg> default_args);
212   std::string qualClassName;
213   at::ClassTypePtr classTypePtr;
214 };
215 
216 } // namespace detail
217 
218 TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
219 TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);
220 
221 // Given a qualified name (e.g. __torch__.torch.classes.Foo), return
222 // the ClassType pointer to the Type that describes that custom class,
223 // or nullptr if no class by that name was found.
224 TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);
225 
226 // Given an IValue, return true if the object contained in that IValue
227 // is a custom C++ class, otherwise return false.
228 TORCH_API bool isCustomClass(const c10::IValue& v);
229 
230 // This API is for testing purposes ONLY. It should not be used in
231 // any load-bearing code.
232 TORCH_API std::vector<c10::FunctionSchema> customClassSchemasForBCCheck();
233 
234 namespace jit {
235 using ::torch::registerCustomClass;
236 using ::torch::registerCustomClassMethod;
237 } // namespace jit
238 
239 } // namespace torch
240