1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/Dispatch.h> 5 #include <ATen/native/DispatchStub.h> 6 7 namespace at { 8 9 struct TensorIterator; 10 11 namespace native { 12 13 using fake_quant_tensor_cachemask_fn = void (*)( 14 Tensor& output, 15 Tensor& mask, 16 const Tensor& input, 17 float sc, 18 int64_t z_point, 19 int64_t quant_min, 20 int64_t quant_max); 21 22 using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)( 23 Tensor& output, 24 Tensor& mask, 25 const Tensor& input, 26 const Tensor& sc, 27 const Tensor& z_point, 28 const Tensor& fake_quant_enabled, 29 int64_t quant_min, 30 int64_t quant_max); 31 32 using fake_quant_learnable_grad_tensor_fn = void (*)( 33 TensorIterator& iter, 34 float scale, 35 float inv_scale, 36 int64_t zero_point, 37 int64_t quant_min, 38 int64_t quant_max, 39 float grad_factor); 40 41 DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub); 42 DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub); 43 DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub); 44 45 using fake_quant_per_channel_fn = void (*)( 46 TensorIterator &iter, 47 int64_t quant_min, 48 int64_t quant_max); 49 50 using fake_quant_per_channel_cachemask_fn = void (*)( 51 TensorIterator &iter, 52 TensorIterator &iter_mask, 53 int64_t quant_min, 54 int64_t quant_max); 55 56 DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub); 57 58 using fake_quant_learnable_per_channel_fn = void (*)( 59 TensorIterator &iter, 60 int64_t quant_min, 61 int64_t quant_max, 62 float grad_factor); 63 64 DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub); 65 66 } // namespace native 67 } // namespace at 68