// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef USE_XNNPACK
2 
3 #include <torch/library.h>
4 #include <ATen/native/xnnpack/Convolution.h>
5 #include <ATen/native/xnnpack/Linear.h>
6 #include <ATen/native/xnnpack/OpContext.h>
7 #include <torch/custom_class.h>
8 
9 namespace at::native::xnnpack {
10 
11 using internal::linear::createLinearClampPrePackOpContext;
12 using internal::convolution2d::createConv2dClampPrePackOpContext;
13 using internal::convolution2d::createConv2dTransposeClampPrePackOpContext;
14 
// Registers the XNNPACK op-context custom classes
// (__torch__.torch.classes.xnnpack.*) and makes them serializable.
// For each class, __getstate__ calls unpack() to recover the original
// pre-packed arguments as a tuple, and __setstate__ re-runs the matching
// prepack factory so a deserialized model rebuilds its packed operator state.
// The tuple arity differs per class (4 for linear, 8 for conv2d, 9 for
// transpose conv2d, which adds output_padding).
TORCH_LIBRARY(xnnpack, m) {
  m.class_<LinearOpContext>(TORCH_SELECTIVE_CLASS("LinearOpContext"))
    .def_pickle(
        [](const c10::intrusive_ptr<LinearOpContext>& op_context)
            -> SerializationTypeLinearPrePack { // __getstate__
          return op_context->unpack();
        },
        [](SerializationTypeLinearPrePack state)
            -> c10::intrusive_ptr<LinearOpContext> { // __setstate__
          // (weight, bias, output_min, output_max)
          return createLinearClampPrePackOpContext(
              std::get<0>(state),
              std::get<1>(state),
              std::get<2>(state),
              std::get<3>(state));
        })
    .def("unpack", &LinearOpContext::unpack);

  m.class_<Conv2dOpContext>(TORCH_SELECTIVE_CLASS("Conv2dOpContext"))
    .def_pickle(
        [](const c10::intrusive_ptr<Conv2dOpContext>& op_context)
            -> SerializationTypeConv2dPrePack { // __getstate__
          return op_context->unpack();
        },
        [](SerializationTypeConv2dPrePack state)
            -> c10::intrusive_ptr<Conv2dOpContext> { // __setstate__
          // (weight, bias, stride, padding, dilation, groups,
          //  output_min, output_max)
          return createConv2dClampPrePackOpContext(
              std::get<0>(state),
              std::get<1>(state),
              std::get<2>(state),
              std::get<3>(state),
              std::get<4>(state),
              std::get<5>(state),
              std::get<6>(state),
              std::get<7>(state));
        })
    .def("unpack", &Conv2dOpContext::unpack);

  m.class_<TransposeConv2dOpContext>(TORCH_SELECTIVE_CLASS("TransposeConv2dOpContext"))
    .def_pickle(
        [](const c10::intrusive_ptr<TransposeConv2dOpContext>& op_context)
            -> SerializationTypeTransposeConv2dPrePack { // __getstate__
          return op_context->unpack();
        },
        [](SerializationTypeTransposeConv2dPrePack state)
            -> c10::intrusive_ptr<TransposeConv2dOpContext> { // __setstate__
          // (weight, bias, stride, padding, output_padding, dilation,
          //  groups, output_min, output_max)
          return createConv2dTransposeClampPrePackOpContext(
              std::get<0>(state),
              std::get<1>(state),
              std::get<2>(state),
              std::get<3>(state),
              std::get<4>(state),
              std::get<5>(state),
              std::get<6>(state),
              std::get<7>(state),
              std::get<8>(state));
        });

}
73 
74 // Registration using the TORCH_LIBRARY def gives dispatching errors when there is no tensor input
TORCH_LIBRARY(prepacked,m)75 TORCH_LIBRARY(prepacked, m) {
76   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_conv2d(Any W_prepack) -> (Any)"), [](const IValue& inp) { return internal::convolution2d::unpack_prepacked_sizes_conv2d(inp);});
77   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_linear(Any W_prepack) -> (Any)"), [](const IValue& inp) { return internal::linear::unpack_prepacked_sizes_linear(inp);});
78   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext"));
79   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y"));
80   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext"));
81   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_transpose_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.TransposeConv2dOpContext"));
82   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y"));
83   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_transpose_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.TransposeConv2dOpContext W_prepack) -> Tensor Y"));
84 }
85 
// Binds the CPU kernels for the prepacked:: schemas declared above: the
// *_prepack ops map to the op-context factory functions and the *_run ops
// map to the corresponding XNNPACK execution entry points.
TORCH_LIBRARY_IMPL(prepacked, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("prepacked::linear_clamp_prepack"), TORCH_FN(createLinearClampPrePackOpContext));
  m.impl(TORCH_SELECTIVE_NAME("prepacked::linear_clamp_run"), TORCH_FN(internal::linear::linear_clamp_run));
  m.impl(TORCH_SELECTIVE_NAME("prepacked::conv2d_clamp_prepack"), TORCH_FN(createConv2dClampPrePackOpContext));
  m.impl(TORCH_SELECTIVE_NAME("prepacked::conv2d_transpose_clamp_prepack"), TORCH_FN(createConv2dTransposeClampPrePackOpContext));
  m.impl(TORCH_SELECTIVE_NAME("prepacked::conv2d_clamp_run"), TORCH_FN(internal::convolution2d::conv2d_clamp_run));
  m.impl(TORCH_SELECTIVE_NAME("prepacked::conv2d_transpose_clamp_run"), TORCH_FN(internal::convolution2d::conv2d_transpose_clamp_run));
}
94 
95 } // namespace at::native::xnnpack
96 
97 #endif /* USE_XNNPACK */
98