1 #pragma once 2 3 #include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h> 4 #include <ATen/core/function.h> 5 #include <c10/util/Metaprogramming.h> 6 #include <c10/util/TypeTraits.h> 7 #include <c10/util/irange.h> 8 9 namespace torch { 10 11 namespace detail { 12 /** 13 * In the Facebook internal build (using BUCK), this macro is enabled by 14 * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer 15 * binary. 16 */ 17 #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE 18 TORCH_API void record_custom_class(std::string name); 19 20 /** 21 * Record an instance of a custom class being loaded 22 * grab portion of string after final '.' from qualified name 23 * as this seemingly aligns with how users name their custom classes 24 * example: __torch__.torch.classes.xnnpack.Conv2dOpContext 25 */ 26 #define RECORD_CUSTOM_CLASS(NAME) \ 27 auto name = std::string(NAME); \ 28 detail::record_custom_class(name.substr(name.find_last_of(".") + 1)); 29 #else 30 #define RECORD_CUSTOM_CLASS(NAME) 31 #endif 32 } // namespace detail 33 34 /// This struct is used to represent default values for arguments 35 /// when registering methods for custom classes. 36 /// static auto register_foo = torch::class_<Foo>("myclasses", "Foo") 37 /// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name}); 38 struct arg { 39 // Static method for representing a default value of None. This is meant to 40 // be used like so: 41 // torch::arg("name") = torch::arg::none 42 // and is identical to: 43 // torch::arg("name") = IValue() nonearg44 static c10::IValue none() { 45 return c10::IValue(); 46 } 47 48 // Explicit constructor. argarg49 explicit arg(std::string name) 50 : name_(std::move(name)), value_(std::nullopt) {} 51 // Assignment operator. This enables the pybind-like syntax of 52 // torch::arg("name") = value. 53 arg& operator=(const c10::IValue& rhs) { 54 value_ = rhs; 55 return *this; 56 } 57 58 // The name of the argument. This is copied to the schema; argument 59 // names cannot be extracted from the C++ declaration. 60 std::string name_; 61 // IValue's default constructor makes it None, which is not distinguishable 62 // from an actual, user-provided default value that is None. This boolean 63 // helps distinguish between the two cases. 64 std::optional<c10::IValue> value_; 65 }; 66 67 namespace detail { 68 69 // Argument type utilities 70 template <class R, class...> 71 struct types { 72 using type = types; 73 }; 74 75 template <typename Method> 76 struct WrapMethod; 77 78 template <typename R, typename CurrClass, typename... Args> 79 struct WrapMethod<R (CurrClass::*)(Args...)> { 80 WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {} 81 82 R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) { 83 return c10::guts::invoke(m, *cur, args...); 84 } 85 86 R (CurrClass::*m)(Args...); 87 }; 88 89 template <typename R, typename CurrClass, typename... Args> 90 struct WrapMethod<R (CurrClass::*)(Args...) const> { 91 WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {} 92 93 R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) { 94 return c10::guts::invoke(m, *cur, args...); 95 } 96 97 R (CurrClass::*m)(Args...) const; 98 }; 99 100 // Adapter for different callable types 101 template < 102 typename CurClass, 103 typename Func, 104 std::enable_if_t< 105 std::is_member_function_pointer_v<std::decay_t<Func>>, 106 bool> = false> 107 WrapMethod<Func> wrap_func(Func f) { 108 return WrapMethod<Func>(std::move(f)); 109 } 110 111 template < 112 typename CurClass, 113 typename Func, 114 std::enable_if_t< 115 !std::is_member_function_pointer_v<std::decay_t<Func>>, 116 bool> = false> 117 Func wrap_func(Func f) { 118 return f; 119 } 120 121 template < 122 class Functor, 123 bool AllowDeprecatedTypes, 124 size_t... ivalue_arg_indices> 125 typename c10::guts::infer_function_traits_t<Functor>::return_type 126 call_torchbind_method_from_stack( 127 Functor& functor, 128 jit::Stack& stack, 129 std::index_sequence<ivalue_arg_indices...>) { 130 (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would 131 // be unused and we have to silence the compiler warning. 132 133 constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices); 134 135 using IValueArgTypes = 136 typename c10::guts::infer_function_traits_t<Functor>::parameter_types; 137 // TODO We shouldn't use c10::impl stuff directly here. We should use the 138 // KernelFunction API instead. 139 return (functor)(c10::impl::ivalue_to_arg< 140 typename c10::impl::decay_if_not_tensor< 141 c10::guts::typelist:: 142 element_t<ivalue_arg_indices, IValueArgTypes>>::type, 143 AllowDeprecatedTypes>:: 144 call(torch::jit::peek( 145 stack, ivalue_arg_indices, num_ivalue_args))...); 146 } 147 148 template <class Functor, bool AllowDeprecatedTypes> 149 typename c10::guts::infer_function_traits_t<Functor>::return_type 150 call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) { 151 constexpr size_t num_ivalue_args = 152 c10::guts::infer_function_traits_t<Functor>::number_of_parameters; 153 return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>( 154 functor, stack, std::make_index_sequence<num_ivalue_args>()); 155 } 156 157 template <class RetType, class Func> 158 struct BoxedProxy; 159 160 template <class RetType, class Func> 161 struct BoxedProxy { 162 void operator()(jit::Stack& stack, Func& func) { 163 auto retval = call_torchbind_method_from_stack<Func, false>(func, stack); 164 constexpr size_t num_ivalue_args = 165 c10::guts::infer_function_traits_t<Func>::number_of_parameters; 166 torch::jit::drop(stack, num_ivalue_args); 167 stack.emplace_back(c10::ivalue::from(std::move(retval))); 168 } 169 }; 170 171 template <class Func> 172 struct BoxedProxy<void, Func> { 173 void operator()(jit::Stack& stack, Func& func) { 174 call_torchbind_method_from_stack<Func, false>(func, stack); 175 constexpr size_t num_ivalue_args = 176 c10::guts::infer_function_traits_t<Func>::number_of_parameters; 177 torch::jit::drop(stack, num_ivalue_args); 178 stack.emplace_back(); 179 } 180 }; 181 182 inline bool validIdent(size_t i, char n) { 183 return isalpha(n) || n == '_' || (i > 0 && isdigit(n)); 184 } 185 186 inline void checkValidIdent(const std::string& str, const char* type) { 187 for (const auto i : c10::irange(str.size())) { 188 TORCH_CHECK( 189 validIdent(i, str[i]), 190 type, 191 " must be a valid Python/C++ identifier." 192 " Character '", 193 str[i], 194 "' at index ", 195 i, 196 " is illegal."); 197 } 198 } 199 200 class TORCH_API class_base { 201 protected: 202 explicit class_base( 203 const std::string& namespaceName, 204 const std::string& className, 205 std::string doc_string, 206 const std::type_info& intrusivePtrClassTypeid, 207 const std::type_info& taggedCapsuleClass); 208 209 static c10::FunctionSchema withNewArguments( 210 const c10::FunctionSchema& schema, 211 std::initializer_list<arg> default_args); 212 std::string qualClassName; 213 at::ClassTypePtr classTypePtr; 214 }; 215 216 } // namespace detail 217 218 TORCH_API void registerCustomClass(at::ClassTypePtr class_type); 219 TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method); 220 221 // Given a qualified name (e.g. __torch__.torch.classes.Foo), return 222 // the ClassType pointer to the Type that describes that custom class, 223 // or nullptr if no class by that name was found. 224 TORCH_API at::ClassTypePtr getCustomClass(const std::string& name); 225 226 // Given an IValue, return true if the object contained in that IValue 227 // is a custom C++ class, otherwise return false. 228 TORCH_API bool isCustomClass(const c10::IValue& v); 229 230 // This API is for testing purposes ONLY. It should not be used in 231 // any load-bearing code. 232 TORCH_API std::vector<c10::FunctionSchema> customClassSchemasForBCCheck(); 233 234 namespace jit { 235 using ::torch::registerCustomClass; 236 using ::torch::registerCustomClassMethod; 237 } // namespace jit 238 239 } // namespace torch 240