1 #ifdef USE_XNNPACK
2
3 #include <torch/library.h>
4 #include <ATen/native/xnnpack/Convolution.h>
5 #include <ATen/native/xnnpack/Linear.h>
6 #include <ATen/native/xnnpack/OpContext.h>
7 #include <torch/custom_class.h>
8
9 namespace at::native::xnnpack {
10
11 using internal::linear::createLinearClampPrePackOpContext;
12 using internal::convolution2d::createConv2dClampPrePackOpContext;
13 using internal::convolution2d::createConv2dTransposeClampPrePackOpContext;
14
// Registers the XNNPACK op-context types as TorchScript custom classes
// (torch.classes.xnnpack.LinearOpContext / Conv2dOpContext /
// TransposeConv2dOpContext).
//
// Each class gets pickle support:
//   __getstate__ returns the tuple produced by the context's unpack(), and
//   __setstate__ rebuilds a context by forwarding every tuple element, in
//   order, to the matching create*ClampPrePackOpContext factory.
// Each class also exposes unpack() as a TorchScript-callable method.
TORCH_LIBRARY(xnnpack, m) {
  m.class_<LinearOpContext>(TORCH_SELECTIVE_CLASS("LinearOpContext"))
      .def_pickle(
          [](const c10::intrusive_ptr<LinearOpContext>& op_context)
              -> SerializationTypeLinearPrePack { // __getstate__
            return op_context->unpack();
          },
          [](SerializationTypeLinearPrePack state)
              -> c10::intrusive_ptr<LinearOpContext> { // __setstate__
            // 4 elements, in the order consumed by the factory.
            return createLinearClampPrePackOpContext(
                std::get<0>(state),
                std::get<1>(state),
                std::get<2>(state),
                std::get<3>(state));
          })
      .def("unpack", &LinearOpContext::unpack);

  m.class_<Conv2dOpContext>(TORCH_SELECTIVE_CLASS("Conv2dOpContext"))
      .def_pickle(
          [](const c10::intrusive_ptr<Conv2dOpContext>& op_context)
              -> SerializationTypeConv2dPrePack { // __getstate__
            return op_context->unpack();
          },
          [](SerializationTypeConv2dPrePack state)
              -> c10::intrusive_ptr<Conv2dOpContext> { // __setstate__
            // 8 elements, in the order consumed by the factory.
            return createConv2dClampPrePackOpContext(
                std::get<0>(state),
                std::get<1>(state),
                std::get<2>(state),
                std::get<3>(state),
                std::get<4>(state),
                std::get<5>(state),
                std::get<6>(state),
                std::get<7>(state));
          })
      .def("unpack", &Conv2dOpContext::unpack);

  m.class_<TransposeConv2dOpContext>(TORCH_SELECTIVE_CLASS("TransposeConv2dOpContext"))
      .def_pickle(
          [](const c10::intrusive_ptr<TransposeConv2dOpContext>& op_context)
              -> SerializationTypeTransposeConv2dPrePack { // __getstate__
            return op_context->unpack();
          },
          [](SerializationTypeTransposeConv2dPrePack state)
              -> c10::intrusive_ptr<TransposeConv2dOpContext> { // __setstate__
            // 9 elements, in the order consumed by the factory.
            return createConv2dTransposeClampPrePackOpContext(
                std::get<0>(state),
                std::get<1>(state),
                std::get<2>(state),
                std::get<3>(state),
                std::get<4>(state),
                std::get<5>(state),
                std::get<6>(state),
                std::get<7>(state),
                std::get<8>(state));
          })
      // Expose unpack() for parity with LinearOpContext and Conv2dOpContext;
      // the same method already backs __getstate__ above.
      .def("unpack", &TransposeConv2dOpContext::unpack);
}
73
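// The unpack_prepacked_sizes_* ops below take no Tensor input, so a
// backend-specific registration (e.g. in TORCH_LIBRARY_IMPL(prepacked, CPU, ...))
// gives dispatching errors for them; their kernels are passed directly to m.def().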
// Declares the prepacked:: operators.
//  * unpack_prepacked_sizes_{conv2d,linear}: take an `Any` prepacked value and
//    return `Any`; their kernels are inline lambdas forwarding to the internal
//    helpers.
//  * *_clamp_prepack: return an xnnpack op-context custom class object.
//  * *_clamp_run: take input X plus an op-context and return Tensor Y.
// The prepack/run kernels are attached in TORCH_LIBRARY_IMPL(prepacked, CPU, m).
TORCH_LIBRARY(prepacked, m) {
  // Schema + kernel together (no Tensor argument to dispatch on).
  m.def(
      TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_conv2d(Any W_prepack) -> (Any)"),
      [](const IValue& packed) {
        return internal::convolution2d::unpack_prepacked_sizes_conv2d(packed);
      });
  m.def(
      TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_linear(Any W_prepack) -> (Any)"),
      [](const IValue& packed) {
        return internal::linear::unpack_prepacked_sizes_linear(packed);
      });

  // Schema only; CPU kernels are registered separately.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "prepacked::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "prepacked::conv2d_transpose_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.TransposeConv2dOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "prepacked::conv2d_transpose_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.TransposeConv2dOpContext W_prepack) -> Tensor Y"));
}
85
// CPU kernels for the prepack/run operators declared under prepacked::.
// Prepack ops map to the op-context factories; run ops execute a context.
TORCH_LIBRARY_IMPL(prepacked, CPU, m) {
  // Linear.
  m.impl(
      TORCH_SELECTIVE_NAME("prepacked::linear_clamp_prepack"),
      TORCH_FN(internal::linear::createLinearClampPrePackOpContext));
  m.impl(
      TORCH_SELECTIVE_NAME("prepacked::linear_clamp_run"),
      TORCH_FN(internal::linear::linear_clamp_run));

  // Conv2d / transposed Conv2d factories.
  m.impl(
      TORCH_SELECTIVE_NAME("prepacked::conv2d_clamp_prepack"),
      TORCH_FN(internal::convolution2d::createConv2dClampPrePackOpContext));
  m.impl(
      TORCH_SELECTIVE_NAME("prepacked::conv2d_transpose_clamp_prepack"),
      TORCH_FN(internal::convolution2d::createConv2dTransposeClampPrePackOpContext));

  // Conv2d / transposed Conv2d execution.
  m.impl(
      TORCH_SELECTIVE_NAME("prepacked::conv2d_clamp_run"),
      TORCH_FN(internal::convolution2d::conv2d_clamp_run));
  m.impl(
      TORCH_SELECTIVE_NAME("prepacked::conv2d_transpose_clamp_run"),
      TORCH_FN(internal::convolution2d::conv2d_transpose_clamp_run));
}
94
95 } // namespace at::native::xnnpack
96
97 #endif /* USE_XNNPACK */
98