xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qthreshold.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <torch/library.h>
4 #include <ATen/native/quantized/cpu/QuantizedOps.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_empty_affine_quantized.h>
11 #include <ATen/ops/threshold_native.h>
12 #endif
13 
14 #include <algorithm>
15 
16 namespace at {
17 namespace native {
18 
19 DEFINE_DISPATCH(qthreshold_stub);
20 
21 // the underlying implementation for quantized threshold kernel
quantized_threshold_impl(const Tensor & qx,const Scalar & threshold,const Scalar & value)22 static Tensor quantized_threshold_impl(
23     const Tensor& qx,
24     const Scalar& threshold,
25     const Scalar& value) {
26   Tensor qy = at::_empty_affine_quantized(
27     qx.sizes(), qx.options(), qx.q_scale(), qx.q_zero_point());
28   qthreshold_stub(qx.device().type(), qx, threshold, value, qy);
29   return qy;
30 }
31 
32 // at::native functions for the native_functions.yaml
threshold_quantized_cpu(const Tensor & qx,const Scalar & threshold,const Scalar & value)33 Tensor threshold_quantized_cpu(
34     const Tensor& qx,
35     const Scalar& threshold,
36     const Scalar& value) {
37   Tensor qy;
38   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "threshold", [&]() {
39     qy = quantized_threshold_impl(qx, threshold, value);
40   });
41   return qy;
42 }
43 
TORCH_LIBRARY_IMPL(quantized,QuantizedCPU,m)44 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
45   m.impl(TORCH_SELECTIVE_NAME("quantized::threshold"), TORCH_FN(threshold_quantized_cpu));
46 }
47 
48 } // namespace native
49 } // namespace at
50