1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/ivalue.h>
4 #include <torch/library.h>
5 #include <ATen/native/quantized/cpu/QuantizedOps.h>
6
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #else
10 #include <ATen/ops/_empty_affine_quantized.h>
11 #endif
12
13 namespace at {
14 namespace native {
15
// Defines the `qelu_stub` dispatch stub that quantized_elu invokes with the input tensor's device type. NOTE(review): the matching declaration presumably lives in QuantizedOps.h -- confirm.
16 DEFINE_DISPATCH(qelu_stub);
17
// Quantized ELU entry point.
//
// Allocates an affine-quantized output tensor with the same sizes and tensor
// options as `qx`, using the caller-supplied `output_scale` and
// `output_zero_point`, then hands the computation to `qelu_stub` selected by
// the device type of `qx`.
//
// Parameters:
//   qx                 - input tensor; its sizes, options and device drive
//                        the output allocation and the stub dispatch.
//   output_scale       - quantization scale of the returned tensor.
//   output_zero_point  - quantization zero point of the returned tensor.
//   alpha, scale,
//   input_scale        - forwarded unchanged to `qelu_stub`.
//
// Returns the freshly allocated output tensor after the stub has run on it.
static Tensor quantized_elu(
    const Tensor& qx,
    double output_scale,
    int64_t output_zero_point,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale) {
  // The output mirrors the input's shape/options but carries its own qparams.
  Tensor result = at::_empty_affine_quantized(
      qx.sizes(), qx.options(), output_scale, output_zero_point);

  // The stub receives the output tensor as its last argument.
  const auto device_type = qx.device().type();
  qelu_stub(device_type, qx, alpha, scale, input_scale, result);
  return result;
}
24
// Quantized CELU entry point, implemented on top of quantized_elu.
//
// Rejects alpha == 0 (the reciprocal of alpha is needed below), then calls
// quantized_elu with `scale` fixed to 1.0 and `input_scale` set to 1 / alpha.
//
// Parameters:
//   qx                 - input tensor, forwarded to quantized_elu.
//   output_scale       - quantization scale for the result.
//   output_zero_point  - quantization zero point for the result.
//   alpha              - CELU alpha; must be non-zero when read as a double.
//
// Fails via TORCH_CHECK when alpha converts to 0.
static Tensor quantized_celu(const Tensor& qx, double output_scale, int64_t output_zero_point, const Scalar& alpha) {
  const double alpha_value = alpha.to<double>();
  TORCH_CHECK(alpha_value != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");

  const Scalar unit_scale(1.0);
  const Scalar reciprocal_alpha(1. / alpha_value);
  return quantized_elu(
      qx, output_scale, output_zero_point, alpha, unit_scale, reciprocal_alpha);
}
31
// Registers the implementations defined in this file for the `quantized`
// library under the QuantizedCPU dispatch key.
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  // quantized::elu -> quantized_elu
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::elu"),
      quantized_elu);

  // quantized::celu -> quantized_celu
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::celu"),
      quantized_celu);
}
36
37 }} // namespace at::native
38