xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/clamp.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/operator.h>
18 
/*
 * Creates a Clamp operator that restricts uint8 values to the range
 * [output_min, output_max] across a given number of channels.
 *
 * channels   - number of channels per batch element; must be non-zero.
 * output_min - lower bound of the clamping range (inclusive).
 * output_max - upper bound of the clamping range (inclusive); must be
 *              >= output_min.
 * flags      - reserved for future use; currently ignored.
 * clamp_out  - receives the newly created operator on success. Ownership
 *              transfers to the caller, who must release it with
 *              pytorch_qnnp_delete_operator.
 *
 * Returns pytorch_qnnp_status_success on success, or an error status if
 * QNNPACK is uninitialized, a parameter is invalid, or allocation fails.
 *
 * NOTE: the "%" PRIu8 format macros below require <inttypes.h>; the
 * original file only included <stdint.h>, which does not define them.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_clamp_nc_u8(
    size_t channels,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    pytorch_qnnp_operator_t* clamp_out) {
  pytorch_qnnp_operator_t clamp_op = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_clamp_nc_u8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_invalid_parameter;

  if (channels == 0) {
    pytorch_qnnp_log_error(
        "failed to create Clamp operator with %zu channels: number of channels must be non-zero",
        channels);
    goto error;
  }

  if (output_min > output_max) {
    pytorch_qnnp_log_error(
        "failed to create Clamp operator with [%" PRIu8 ", %" PRIu8
        "] output range: range min must be below range max",
        output_min,
        output_max);
    goto error;
  }

  status = pytorch_qnnp_status_out_of_memory;

  /* calloc zero-initializes the operator struct, so any field not
   * explicitly set below starts out as 0/NULL. */
  clamp_op = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (clamp_op == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  clamp_op->channels = channels;
  /* Precompute the (possibly SIMD-friendly) clamping parameter block once
   * at creation time so per-run setup stays cheap. */
  clamp_op->u8_clamping_params =
      pytorch_qnnp_compute_u8_clamping_params(output_min, output_max);

  clamp_op->ukernel_type = pytorch_qnnp_ukernel_type_clamp;
  clamp_op->format = pytorch_qnnp_format_quint8;

  *clamp_out = clamp_op;
  return pytorch_qnnp_status_success;

error:
  /* Safe when clamp_op is still NULL: pytorch_qnnp_delete_operator
   * rejects NULL without dereferencing it. */
  pytorch_qnnp_delete_operator(clamp_op);
  return status;
}
76 
/*
 * Binds input/output buffers and batch geometry to a previously created
 * Clamp operator, preparing it for execution.
 *
 * clamp         - operator created by pytorch_qnnp_create_clamp_nc_u8.
 * batch_size    - number of batch elements; 0 is valid and makes the next
 *                 run a no-op.
 * input         - pointer to the first input element (borrowed, not owned).
 * input_stride  - distance in elements between consecutive batch rows of
 *                 the input.
 * output        - pointer to the first output element (borrowed, not owned).
 * output_stride - distance in elements between consecutive batch rows of
 *                 the output.
 *
 * Returns pytorch_qnnp_status_success, or
 * pytorch_qnnp_status_uninitialized if QNNPACK has not been initialized.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_clamp_nc_u8(
    pytorch_qnnp_operator_t clamp,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  /* Guard: the library must be initialized before any setup call. */
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_clamp_nc_u8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  /* An empty batch is legal: record it and skip the buffer bookkeeping so
   * a subsequent run does no work. */
  if (batch_size == 0) {
    clamp->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

  clamp->batch_size = batch_size;
  /* Record the output side first, then the input side; the assignments
   * are independent of one another. */
  clamp->output = output;
  clamp->output_pixel_stride = output_stride;
  clamp->input = input;
  clamp->input_pixel_stride = input_stride;

  return pytorch_qnnp_status_success;
}
103