#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/core/IListRef.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/Activation.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

// Function-pointer type aliases describing the kernel signatures of the
// quantized CPU operators declared in this header. Each alias is paired with
// a DECLARE_DISPATCH stub later in this file, so that per-architecture
// implementations can be registered through the DispatchStub mechanism
// (see <ATen/native/DispatchStub.h>).
//
// Convention used throughout: `qx` is the quantized input tensor and `qy`
// the quantized output tensor; the /*...*/ markers document the role of
// otherwise-unnamed parameters.

// Elementwise ReLU: reads `qx`, writes the result into `qy`.
using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
// Leaky ReLU with negative-slope scalar `negval_`; writes into `out`.
using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
                                const Scalar& /*negval_*/);
// GELU with selectable approximation mode (GeluType from ATen/native/Activation.h).
using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, GeluType /* approximate */);
// Sigmoid; the caller supplies the quantization parameters of the output.
using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point);
using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
// Clamp to the closed range [min, max].
using qclamp_fn = void (*)(
    const at::Tensor& /*qx*/,
    const Scalar& min,
    const Scalar& max,
    at::Tensor& /*qy*/);
// One-sided clamp: the single scalar is either the lower or the upper bound,
// depending on which stub (qclamp_min_stub / qclamp_max_stub) uses the alias.
using qclamp_minmax_fn = void (*)(
    const at::Tensor& /*qx*/,
    const Scalar& /*min or max*/,
    at::Tensor& /*qy*/);
// Threshold activation parameterized by `threshold` and replacement `value`.
using qthreshold_fn = void (*)(
    const at::Tensor& /*qx*/,
    const Scalar& threshold,
    const Scalar& value,
    at::Tensor& /*qy*/);
using qtanh_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
// ELU parameterized by alpha / scale / input_scale scalars.
using qelu_fn = void(*)(
    const at::Tensor& /*qx*/,
    const Scalar& /*alpha*/,
    const Scalar& /*scale*/,
    const Scalar& /*input_scale*/,
    at::Tensor& /*qy*/);
// Binary elementwise op; shared by the add/mul stubs and their
// fused-ReLU variants (see the DECLARE_DISPATCH list below).
using qbinary_fn =
    void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);
// Tensor + scalar addition (plain and fused-ReLU stubs share this alias).
using qadd_scalar_fn =
    void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Scalar& other /*other*/);
using qhardswish_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
// Dropout with probability `p`; `training` selects train/eval behavior.
using qdropout_fn = void(*)(
    const at::Tensor& /*qx*/,
    const Scalar& /*p*/,
    bool training /*training*/,
    at::Tensor& /*qy*/);
// 2-D max pooling over NHWC data; the full pooling geometry is passed
// explicitly rather than derived inside the kernel.
using qmaxpool_2d_fn = void (*)(
    const Tensor& qx,
    int64_t iC, // input/output channels
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sH,
    int64_t sW, // strides
    int64_t pH,
    int64_t pW, // padding
    int64_t dH,
    int64_t dW, // dilation
    Tensor& qy);
// 3-D max pooling; same scheme as qmaxpool_2d_fn with an extra T dimension.
using qmaxpool_3d_fn = void (*)(
    const Tensor& qx,
    int64_t iC, // input/output channels
    int64_t iT,
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oT,
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kT,
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sT,
    int64_t sH,
    int64_t sW, // strides
    int64_t pT,
    int64_t pH,
    int64_t pW, // padding
    int64_t dT,
    int64_t dH,
    int64_t dW, // dilation
    Tensor& qy);
// Adaptive 2-D average pooling; input strides are passed explicitly so the
// kernel can index the (NHWC, per the stub name) layout directly.
using qadaptive_avg_pool2d_fn = void (*)(
    const Tensor& qx,
    Tensor& qy,
    int64_t sizeB,
    int64_t sizeC,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeH,
    int64_t osizeW,
    int64_t istrideB,
    int64_t istrideC,
    int64_t istrideH,
    int64_t istrideW);
// Adaptive 3-D average pooling (adds the depth/D dimension).
using qadaptive_avg_pool3d_fn = void (*)(
    const Tensor& qx,
    Tensor& qy,
    int64_t sizeB,
    int64_t sizeC,
    int64_t isizeD,
    int64_t isizeH,
    int64_t isizeW,
    int64_t osizeD,
    int64_t osizeH,
    int64_t osizeW,
    int64_t istrideB,
    int64_t istrideC,
    int64_t istrideD,
    int64_t istrideH,
    int64_t istrideW);
// 2-D average pooling. `divisor_override`, when set, replaces the default
// pooling-window divisor (mirrors at::avg_pool2d's parameter of that name).
using qavg_pool2d_fn = void (*)(
    const Tensor& qx,
    Tensor& qy,
    int64_t nBatch,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t outputWidth,
    int64_t outputHeight,
    int kW,
    int kH,
    int dW,
    int dH,
    int padW,
    int padH,
    bool count_include_pad,
    std::optional<int64_t> divisor_override);

// 3-D average pooling (adds the depth/D dimension).
using qavg_pool3d_fn = void (*)(
    const Tensor& qx,
    Tensor& qy,
    int64_t nBatch,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t inputDepth,
    int64_t outputWidth,
    int64_t outputHeight,
    int64_t outputDepth,
    int kW,
    int kH,
    int kD,
    int dW,
    int dH,
    int dD,
    int padW,
    int padH,
    int padD,
    bool count_include_pad,
    std::optional<int64_t> divisor_override);

// Bilinear 2-D upsampling; scales_h/scales_w are optional explicit scale
// factors for the height/width dimensions.
using qupsample_bilinear2d_fn = void (*)(
    Tensor& output,
    const Tensor& input,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w);

// Concatenation of NHWC quantized tensors along `dim`, producing an output
// with the given scale/zero_point. NOTE: unlike most kernels here, this
// returns the output tensor by value instead of writing into a `qy` param.
using qcat_nhwc_fn = Tensor (*)(
    const MaterializedITensorListRef& qxs,
    int64_t dim,
    double scale,
    int64_t zero_point);
// Top-k. Parameters are unnamed here; presumably
// (values, indices, input, k, dim, largest, sorted) as in at::topk —
// confirm against the kernel implementation.
using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool);

// Batch norm. The five leading int64_t parameters carry problem geometry and
// the tensors carry input / affine parameters / output; the exact ordering is
// not documented here — see the kernel implementation.
using qbatch_norm_fn = void(*)(int64_t, int64_t, int64_t, int64_t, int64_t, const Tensor&, const Tensor&, const Tensor&, Tensor&);

// Normalization (layer/group — the stubs below use it for both) with affine
// gamma/beta; `Y` is an out-pointer rather than a reference.
using qnormalize_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    bool /* affine_per_channel */,
    int /* num_channels */,
    int /* num_groups */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */);

// Mean reduction over the given (optional) dims of a quantized tensor.
using qmean_inner_dim_fn = void (*)(
    const Tensor& /* X */,
    OptionalIntArrayRef /* opt_dim */,
    bool /* keepdim */,
    std::optional<ScalarType> /* opt_dtype */,
    Tensor& /* Y */);

// Standard-deviation reduction; `correction` matches at::std's parameter.
using qstd_inner_dim_fn = void (*)(
    const Tensor& /* X */,
    OptionalIntArrayRef /* dim */,
    const std::optional<Scalar>& /* correction */,
    bool /* keepdim */,
    Tensor& /* Y */);

// Same parameter list as qnormalize_fn; used by the channels-last (NHWC)
// group-norm stub below.
using qnormalize_nhwc_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    bool /* affine_per_channel */,
    int /* num_channels */,
    int /* num_groups */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */);

// PReLU with weight tensor `qw`.
using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
                           const Tensor& /*qw*/);
// Dispatch stub declarations — one stub per kernel-signature alias above.
// DECLARE_DISPATCH only declares each stub; the per-CPU-capability kernel
// implementations are registered elsewhere (in the quantized CPU kernel
// source files) via the DispatchStub machinery from
// <ATen/native/DispatchStub.h>. Several aliases back multiple stubs
// (e.g. qbinary_fn serves add/mul and their fused-ReLU variants).
DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub);
DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub);
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub);
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub);
DECLARE_DISPATCH(qavg_pool2d_fn, qavg_pool2d_nhwc_stub);
DECLARE_DISPATCH(qavg_pool3d_fn, qavg_pool3d_nhwc_stub);
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_relu_stub);
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub);
DECLARE_DISPATCH(qbinary_fn, qadd_relu_stub);
DECLARE_DISPATCH(qbinary_fn, qadd_stub);
DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub);
DECLARE_DISPATCH(qbinary_fn, qmul_stub);
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub);
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub);
DECLARE_DISPATCH(qclamp_fn, qclamp_stub);
DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub);
DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub);
DECLARE_DISPATCH(qelu_fn, qelu_stub);
DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub);
DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub);
DECLARE_DISPATCH(qdropout_fn, qdropout_stub);
DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub);
DECLARE_DISPATCH(qmaxpool_3d_fn, qmaxpool_3d_nthwc_stub);
DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub);
DECLARE_DISPATCH(qnormalize_nhwc_fn, quantized_groupnorm_nhwc_stub);
DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub);
DECLARE_DISPATCH(qgelu_fn, qgelu_stub);
DECLARE_DISPATCH(qsigmoid_fn, qsigmoid_stub);
DECLARE_DISPATCH(qtanh_fn, qtanh_stub);
DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub);
DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub);
DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub);
DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub);
DECLARE_DISPATCH(qprelu_fn, qprelu_stub);

} // namespace native
} // namespace at