xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ao_sparse/library.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <torch/library.h>
3 
4 #include <torch/custom_class.h>
5 #include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
6 #include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
7 
8 // Register operators
// Declares the operator schemas of the `sparse` library namespace.
// Only schemas are defined here (m.def); no kernel is bound in this block.
// NOTE(review): implementations are presumably registered elsewhere
// (e.g. via TORCH_LIBRARY_IMPL) — confirm.
//
// Each schema literal is split into one piece per argument purely for
// readability. Adjacent string literals are concatenated by the compiler,
// so every schema string is byte-identical to the equivalent one-line literal.
TORCH_LIBRARY(sparse, m) {
  // Called first, before any schema below that names the custom class
  // __torch__.torch.classes.sparse.LinearPackedParamsBase.
  // NOTE(review): presumably this registers that class; confirm in
  // packed_params.h.
  ao::sparse::register_linear_params();

  // Linear ops taking an explicit Y_scale_i / Y_zero_point_i pair
  // (by name, apparently the output's quantization parameters).
  m.def(TORCH_SELECTIVE_SCHEMA(
      "sparse::qlinear("
      "Tensor X, "
      "__torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack, "
      "float Y_scale_i, "
      "int Y_zero_point_i"
      ") -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "sparse::qlinear_relu("
      "Tensor X, "
      "__torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack, "
      "float Y_scale_i, "
      "int Y_zero_point_i"
      ") -> Tensor Y"));

  // "_dynamic" linear ops: only the input and the packed weight, with no
  // output scale / zero-point arguments.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "sparse::qlinear_dynamic("
      "Tensor X, "
      "__torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack"
      ") -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "sparse::qlinear_relu_dynamic("
      "Tensor X, "
      "__torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack"
      ") -> Tensor Y"));

  // Packing: weight W, optional bias B and the two block sizes in,
  // packed-params object out.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "sparse::qlinear_prepack("
      "Tensor W, "
      "Tensor? B, "
      "int out_features_block_size, "
      "int in_features_block_size"
      ") -> __torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack"));

  // Unpacking: packed-params object in; weight, optional bias and the
  // block pattern out.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "sparse::qlinear_unpack("
      "__torch__.torch.classes.sparse.LinearPackedParamsBase W_prepack"
      ") -> ("
      "Tensor W_origin, "
      "Tensor? B_origin, "
      "int[] block_pattern"
      ")"));
}
27 }
28