#pragma once #include #include #include #include #include namespace torch { namespace detail { /** * In the Facebook internal build (using BUCK), this macro is enabled by * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer * binary. */ #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE TORCH_API void record_custom_class(std::string name); /** * Record an instance of a custom class being loaded * grab portion of string after final '.' from qualified name * as this seemingly aligns with how users name their custom classes * example: __torch__.torch.classes.xnnpack.Conv2dOpContext */ #define RECORD_CUSTOM_CLASS(NAME) \ auto name = std::string(NAME); \ detail::record_custom_class(name.substr(name.find_last_of(".") + 1)); #else #define RECORD_CUSTOM_CLASS(NAME) #endif } // namespace detail /// This struct is used to represent default values for arguments /// when registering methods for custom classes. /// static auto register_foo = torch::class_("myclasses", "Foo") /// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name}); struct arg { // Static method for representing a default value of None. This is meant to // be used like so: // torch::arg("name") = torch::arg::none // and is identical to: // torch::arg("name") = IValue() static c10::IValue none() { return c10::IValue(); } // Explicit constructor. explicit arg(std::string name) : name_(std::move(name)), value_(std::nullopt) {} // Assignment operator. This enables the pybind-like syntax of // torch::arg("name") = value. arg& operator=(const c10::IValue& rhs) { value_ = rhs; return *this; } // The name of the argument. This is copied to the schema; argument // names cannot be extracted from the C++ declaration. std::string name_; // IValue's default constructor makes it None, which is not distinguishable // from an actual, user-provided default value that is None. This boolean // helps distinguish between the two cases. std::optional value_; }; namespace detail { // Argument type utilities template struct types { using type = types; }; template struct WrapMethod; template struct WrapMethod { WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {} R operator()(c10::intrusive_ptr cur, Args... args) { return c10::guts::invoke(m, *cur, args...); } R (CurrClass::*m)(Args...); }; template struct WrapMethod { WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {} R operator()(c10::intrusive_ptr cur, Args... args) { return c10::guts::invoke(m, *cur, args...); } R (CurrClass::*m)(Args...) const; }; // Adapter for different callable types template < typename CurClass, typename Func, std::enable_if_t< std::is_member_function_pointer_v>, bool> = false> WrapMethod wrap_func(Func f) { return WrapMethod(std::move(f)); } template < typename CurClass, typename Func, std::enable_if_t< !std::is_member_function_pointer_v>, bool> = false> Func wrap_func(Func f) { return f; } template < class Functor, bool AllowDeprecatedTypes, size_t... ivalue_arg_indices> typename c10::guts::infer_function_traits_t::return_type call_torchbind_method_from_stack( Functor& functor, jit::Stack& stack, std::index_sequence) { (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would // be unused and we have to silence the compiler warning. constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices); using IValueArgTypes = typename c10::guts::infer_function_traits_t::parameter_types; // TODO We shouldn't use c10::impl stuff directly here. We should use the // KernelFunction API instead. return (functor)(c10::impl::ivalue_to_arg< typename c10::impl::decay_if_not_tensor< c10::guts::typelist:: element_t>::type, AllowDeprecatedTypes>:: call(torch::jit::peek( stack, ivalue_arg_indices, num_ivalue_args))...); } template typename c10::guts::infer_function_traits_t::return_type call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) { constexpr size_t num_ivalue_args = c10::guts::infer_function_traits_t::number_of_parameters; return call_torchbind_method_from_stack( functor, stack, std::make_index_sequence()); } template struct BoxedProxy; template struct BoxedProxy { void operator()(jit::Stack& stack, Func& func) { auto retval = call_torchbind_method_from_stack(func, stack); constexpr size_t num_ivalue_args = c10::guts::infer_function_traits_t::number_of_parameters; torch::jit::drop(stack, num_ivalue_args); stack.emplace_back(c10::ivalue::from(std::move(retval))); } }; template struct BoxedProxy { void operator()(jit::Stack& stack, Func& func) { call_torchbind_method_from_stack(func, stack); constexpr size_t num_ivalue_args = c10::guts::infer_function_traits_t::number_of_parameters; torch::jit::drop(stack, num_ivalue_args); stack.emplace_back(); } }; inline bool validIdent(size_t i, char n) { return isalpha(n) || n == '_' || (i > 0 && isdigit(n)); } inline void checkValidIdent(const std::string& str, const char* type) { for (const auto i : c10::irange(str.size())) { TORCH_CHECK( validIdent(i, str[i]), type, " must be a valid Python/C++ identifier." " Character '", str[i], "' at index ", i, " is illegal."); } } class TORCH_API class_base { protected: explicit class_base( const std::string& namespaceName, const std::string& className, std::string doc_string, const std::type_info& intrusivePtrClassTypeid, const std::type_info& taggedCapsuleClass); static c10::FunctionSchema withNewArguments( const c10::FunctionSchema& schema, std::initializer_list default_args); std::string qualClassName; at::ClassTypePtr classTypePtr; }; } // namespace detail TORCH_API void registerCustomClass(at::ClassTypePtr class_type); TORCH_API void registerCustomClassMethod(std::unique_ptr method); // Given a qualified name (e.g. __torch__.torch.classes.Foo), return // the ClassType pointer to the Type that describes that custom class, // or nullptr if no class by that name was found. TORCH_API at::ClassTypePtr getCustomClass(const std::string& name); // Given an IValue, return true if the object contained in that IValue // is a custom C++ class, otherwise return false. TORCH_API bool isCustomClass(const c10::IValue& v); // This API is for testing purposes ONLY. It should not be used in // any load-bearing code. TORCH_API std::vector customClassSchemasForBCCheck(); namespace jit { using ::torch::registerCustomClass; using ::torch::registerCustomClassMethod; } // namespace jit } // namespace torch