// aten/src/ATen/native/quantized/cpu/qelu.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>
#include <torch/library.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#endif

namespace at {
namespace native {

DEFINE_DISPATCH(qelu_stub);

// quantized::elu: allocate an affine-quantized output tensor with the
// requested scale and zero point, then dispatch the elementwise ELU kernel.
static Tensor quantized_elu(
    const Tensor& qx, double output_scale, int64_t output_zero_point,
    const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {
  Tensor qy = at::_empty_affine_quantized(qx.sizes(), qx.options(), output_scale, output_zero_point);
  qelu_stub(qx.device().type(), qx, alpha, scale, input_scale, qy);
  return qy;
}

// quantized::celu, expressed via quantized::elu:
// CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)),
// i.e. ELU with scale = 1 and input_scale = 1 / alpha.
static Tensor quantized_celu(const Tensor& qx, double output_scale, int64_t output_zero_point, const Scalar& alpha) {
  TORCH_CHECK(alpha.to<double>() != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");
  double inv_alpha = 1. / alpha.to<double>();
  return quantized_elu(qx, output_scale, output_zero_point, alpha, Scalar(1.0), Scalar(inv_alpha));
}

// Register the QuantizedCPU implementations of the quantized::elu and
// quantized::celu operators.
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::elu"), quantized_elu);
  m.impl(TORCH_SELECTIVE_NAME("quantized::celu"), quantized_celu);
}

}}  // namespace at::native
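// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of qelu.cpp): how a caller might reach
// the kernels registered above through the dispatcher. The argument list
// mirrors the quantized_elu signature; the exact "quantized::elu" schema is
// defined elsewhere in the quantized library, so treat the string and the
// parameter values below as assumptions.
// ---------------------------------------------------------------------------
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

static at::Tensor call_quantized_elu_example() {
  // Quantize a small float tensor with an affine (scale/zero-point) scheme.
  at::Tensor x = at::randn({4});
  at::Tensor qx =
      at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/0, at::kQUInt8);

  // Resolve the op, then call it: output_scale/output_zero_point pick the
  // output quantization, and (alpha, scale, input_scale) parameterize ELU.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::elu", "")
          .typed<at::Tensor(const at::Tensor&, double, int64_t,
                            const at::Scalar&, const at::Scalar&,
                            const at::Scalar&)>();
  return op.call(qx, /*output_scale=*/0.1, /*output_zero_point=*/0,
                 /*alpha=*/1.0, /*scale=*/1.0, /*input_scale=*/1.0);
}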