1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <torch/library.h>
3
4 #include <torch/custom_class.h>
5 #include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
6 #include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
7
8 // Register operators
TORCH_LIBRARY(sparse,m)9 TORCH_LIBRARY(sparse, m) {
10 ao::sparse::register_linear_params();
11
12 m.def(TORCH_SELECTIVE_SCHEMA(
13 "sparse::qlinear(Tensor X, __torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"));
14 m.def(TORCH_SELECTIVE_SCHEMA(
15 "sparse::qlinear_relu(Tensor X, __torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"));
16
17 m.def(TORCH_SELECTIVE_SCHEMA(
18 "sparse::qlinear_dynamic(Tensor X, __torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack) -> Tensor Y"));
19 m.def(TORCH_SELECTIVE_SCHEMA(
20 "sparse::qlinear_relu_dynamic(Tensor X, __torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack) -> Tensor Y"));
21
22 m.def(TORCH_SELECTIVE_SCHEMA(
23 "sparse::qlinear_prepack(Tensor W, Tensor? B, int out_features_block_size, int in_features_block_size) -> __torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack"));
24
25 m.def(TORCH_SELECTIVE_SCHEMA(
26 "sparse::qlinear_unpack(__torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack) -> (Tensor W_origin, Tensor? B_origin, int[] block_pattern)"));
27 }
28