1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/native/DispatchStub.h> 5 6 namespace at::native { 7 8 using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int); 9 using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int); 10 using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&); 11 12 DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub); 13 DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub); 14 DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub); 15 16 } // namespace at::native 17