xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/int_mm_kernel.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/native/DispatchStub.h>
5 
6 namespace at::native {
7 
8 using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
9 using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);
10 using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
11 
12 DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
13 DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);
14 DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub);
15 
16 } // namespace at::native
17