xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/FakeQuantAffine.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/native/DispatchStub.h>
6 
7 namespace at {
8 
// Forward declaration only: the function-pointer aliases in this header take
// TensorIterator by reference, so the complete type is not required here.
9 struct TensorIterator;
10 
11 namespace native {
12 
// Function-pointer type used to declare fake_quant_tensor_cachemask_stub.
// The pointed-to function returns void. `output` and `mask` are the only
// non-const Tensor references, `input` is const.
// NOTE(review): results are presumably written into `output` and `mask` --
// confirm against the kernel implementations, which are not in this header.
// `sc` (float) and `z_point` (int64_t) are passed by value; presumably the
// quantization scale and zero point -- TODO confirm.
// `quant_min` / `quant_max` are int64_t; presumably the bounds of the
// quantized value range -- TODO confirm whether they are inclusive.
13 using fake_quant_tensor_cachemask_fn = void (*)(
14     Tensor& output,
15     Tensor& mask,
16     const Tensor& input,
17     float sc,
18     int64_t z_point,
19     int64_t quant_min,
20     int64_t quant_max);
21 
// Function-pointer type used to declare
// fake_quant_tensor_cachemask_tensor_qparams_stub.
// Same leading arguments as fake_quant_tensor_cachemask_fn (`output`, `mask`,
// `input`), but `sc` and `z_point` are passed as const Tensor references
// instead of scalars, and an extra const Tensor `fake_quant_enabled` is
// supplied before the int64_t `quant_min` / `quant_max` pair.
// NOTE(review): `fake_quant_enabled` presumably toggles fake quantization at
// run time; its expected dtype/shape is not visible here -- confirm with the
// kernel implementations.
22 using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)(
23     Tensor& output,
24     Tensor& mask,
25     const Tensor& input,
26     const Tensor& sc,
27     const Tensor& z_point,
28     const Tensor& fake_quant_enabled,
29     int64_t quant_min,
30     int64_t quant_max);
31 
// Function-pointer type used to declare fake_quant_grad_learnable_tensor_stub.
// Note the word order differs between the alias (`learnable_grad`) and the
// stub (`grad_learnable`); both names are part of the existing interface.
// The pointed-to function returns void and operates on a caller-provided
// TensorIterator passed by non-const reference.
// `scale` and `inv_scale` are separate float arguments.
// NOTE(review): `inv_scale` is presumably 1 / scale, precomputed by the
// caller -- confirm. `grad_factor` is a float whose meaning is not visible in
// this header; presumably a multiplier applied to the computed gradients.
32 using fake_quant_learnable_grad_tensor_fn = void (*)(
33     TensorIterator& iter,
34     float scale,
35     float inv_scale,
36     int64_t zero_point,
37     int64_t quant_min,
38     int64_t quant_max,
39     float grad_factor);
40 
// Declare one dispatch stub for each of the three per-tensor function-pointer
// types above. Each invocation pairs a function-pointer type (first argument)
// with the stub name (second argument).
// NOTE(review): DECLARE_DISPATCH is presumably provided by
// ATen/native/DispatchStub.h (included by this header); the corresponding
// stub definitions and kernel registrations are not visible in this file.
41 DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub);
42 DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub);
43 DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub);
44 
// Function-pointer type for a void function taking a TensorIterator by
// non-const reference plus the int64_t `quant_min` / `quant_max` pair.
// NOTE(review): no DECLARE_DISPATCH in this header uses this type; check
// whether it is still referenced elsewhere before changing or removing it.
45 using fake_quant_per_channel_fn = void (*)(
46     TensorIterator &iter,
47     int64_t quant_min,
48     int64_t quant_max);
49 
// Function-pointer type used to declare fake_quant_per_channel_cachemask_stub.
// The pointed-to function returns void and receives two TensorIterators by
// non-const reference (`iter` and `iter_mask`) followed by the int64_t
// `quant_min` / `quant_max` pair. Unlike the per-tensor signatures, no scale
// or zero-point argument appears in the parameter list.
// NOTE(review): the per-channel scale and zero point are presumably carried
// as operands of the iterators, and `iter_mask` presumably drives the mask
// computation -- confirm against the kernel implementations.
50 using fake_quant_per_channel_cachemask_fn = void (*)(
51     TensorIterator &iter,
52     TensorIterator &iter_mask,
53     int64_t quant_min,
54     int64_t quant_max);
55 
// Declare the dispatch stub whose function-pointer type is
// fake_quant_per_channel_cachemask_fn. The stub definition and kernel
// registrations are not visible in this header.
56 DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
57 
// Function-pointer type used to declare fake_quant_grad_learnable_channel_stub.
// The pointed-to function returns void and takes a TensorIterator by
// non-const reference, the int64_t `quant_min` / `quant_max` pair, and a
// float `grad_factor`. In contrast to fake_quant_learnable_grad_tensor_fn,
// there are no scalar scale / inv_scale / zero_point parameters.
// NOTE(review): the meaning of `grad_factor` is not visible in this header;
// presumably a multiplier applied to the computed gradients -- confirm.
58 using fake_quant_learnable_per_channel_fn = void (*)(
59     TensorIterator &iter,
60     int64_t quant_min,
61     int64_t quant_max,
62     float grad_factor);
63 
// Declare the dispatch stub whose function-pointer type is
// fake_quant_learnable_per_channel_fn. The stub definition and kernel
// registrations are not visible in this header.
64 DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub);
65 
66 } // namespace native
67 } // namespace at
68