1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <pytorch_qnnpack.h>
16 #include <qnnpack/log.h>
17 #include <qnnpack/operator.h>
18
/*
 * Creates a Clamp operator that limits uint8 values to [output_min, output_max].
 *
 * channels   - number of elements per row; must be non-zero.
 * output_min - lower bound of the output range.
 * output_max - upper bound of the output range; must not be less than
 *              output_min (equal bounds are accepted).
 * flags      - not read by this function.
 * clamp_out  - receives the newly allocated operator on success; it is not
 *              written on any failure path.
 *
 * Returns pytorch_qnnp_status_success, or the status describing the first
 * failed check: uninitialized (library not initialized), invalid_parameter
 * (bad channel count or range), out_of_memory (allocation failed).
 */
enum pytorch_qnnp_status pytorch_qnnp_create_clamp_nc_u8(
    size_t channels,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    pytorch_qnnp_operator_t* clamp_out) {
  pytorch_qnnp_operator_t op = NULL;
  enum pytorch_qnnp_status result;

  if (!pytorch_qnnp_params.initialized) {
    result = pytorch_qnnp_status_uninitialized;
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_clamp_nc_u8 failed because QNNPACK is not properly initialized");
    goto fail;
  }

  if (channels == 0) {
    result = pytorch_qnnp_status_invalid_parameter;
    pytorch_qnnp_log_error(
        "failed to create Clamp operator with %zu channels: number of channels must be non-zero",
        channels);
    goto fail;
  }

  if (output_max < output_min) {
    result = pytorch_qnnp_status_invalid_parameter;
    pytorch_qnnp_log_error(
        "failed to create Clamp operator with [%" PRIu8 ", %" PRIu8
        "] output range: range min must be below range max",
        output_min,
        output_max);
    goto fail;
  }

  /* Zero-filled so every field not set below starts out cleared. */
  op = calloc(1, sizeof(*op));
  if (op == NULL) {
    result = pytorch_qnnp_status_out_of_memory;
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(*op));
    goto fail;
  }

  op->ukernel_type = pytorch_qnnp_ukernel_type_clamp;
  op->format = pytorch_qnnp_format_quint8;
  op->channels = channels;
  op->u8_clamping_params =
      pytorch_qnnp_compute_u8_clamping_params(output_min, output_max);

  *clamp_out = op;
  return pytorch_qnnp_status_success;

fail:
  /* op is still NULL on every path that reaches this label. */
  pytorch_qnnp_delete_operator(op);
  return result;
}
76
/*
 * Binds batch geometry and input/output buffers to a Clamp operator created
 * by pytorch_qnnp_create_clamp_nc_u8.
 *
 * clamp         - operator to configure.
 * batch_size    - number of rows. 0 records an empty batch and returns success
 *                 without touching the buffer/stride fields.
 * input         - start of the first input row.
 * input_stride  - distance between consecutive input rows, presumably in
 *                 uint8_t elements (confirm against the run path).
 * output        - start of the first output row.
 * output_stride - distance between consecutive output rows (same unit).
 *
 * When batch_size > 1, both strides must be at least clamp->channels.
 * Otherwise consecutive rows would overlap.
 *
 * Returns pytorch_qnnp_status_success, pytorch_qnnp_status_uninitialized if
 * QNNPACK is not initialized, or pytorch_qnnp_status_invalid_parameter if a
 * stride is too small. On failure the operator's fields are left unmodified.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_clamp_nc_u8(
    pytorch_qnnp_operator_t clamp,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_clamp_nc_u8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  if (batch_size == 0) {
    clamp->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

  /*
   * With more than one row, a stride below the row width makes consecutive
   * rows overlap, so writing one output row would clobber part of another.
   * A single row never advances by the stride, so it is not checked there.
   */
  if (batch_size > 1) {
    const size_t channels = (size_t)clamp->channels;
    if (input_stride < channels) {
      pytorch_qnnp_log_error(
          "failed to setup Clamp operator with input stride of %zu: "
          "stride must be at least as large as the number of channels (%zu)",
          input_stride,
          channels);
      return pytorch_qnnp_status_invalid_parameter;
    }
    if (output_stride < channels) {
      pytorch_qnnp_log_error(
          "failed to setup Clamp operator with output stride of %zu: "
          "stride must be at least as large as the number of channels (%zu)",
          output_stride,
          channels);
      return pytorch_qnnp_status_invalid_parameter;
    }
  }

  clamp->batch_size = batch_size;
  clamp->input = input;
  clamp->input_pixel_stride = input_stride;
  clamp->output = output;
  clamp->output_pixel_stride = output_stride;

  return pytorch_qnnp_status_success;
}
103