xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8dwconv.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <stddef.h>
12 #include <stdint.h>
13 
14 #include <qnnpack/common.h>
15 #include <qnnpack/params.h>
16 
17 #ifdef __cplusplus
18 extern "C" {
19 #endif
20 
/*
 * Emits the prototype of a depthwise-convolution micro-kernel called
 * `fn_name`.
 *
 * The generated function returns void, is tagged PYTORCH_QNNP_INTERNAL and
 * takes, in this order:
 *   channels, output_width          size_t
 *   input                           const uint8_t** (array of input pointers)
 *   weights                         const void*
 *   output                          uint8_t*
 *   input_stride, output_increment  size_t
 *   quantization_params             pointer to a const
 *                                   union pytorch_qnnp_conv_quantization_params
 *
 * The expansion already ends with ';', which is why the invocations in this
 * header are written without a trailing semicolon.
 *
 * NOTE(review): "UP" presumably means uni-pass (no scratch accumulator, in
 * contrast to the MP variant that takes a buffer) - confirm.
 * NOTE(review): the layout behind `weights` and the units of `input_stride` /
 * `output_increment` (bytes vs. elements) are not visible in this header -
 * verify against a kernel implementation.
 */
21 #define DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(fn_name) \
22   PYTORCH_QNNP_INTERNAL void fn_name(                \
23       size_t channels,                               \
24       size_t output_width,                           \
25       const uint8_t** input,                         \
26       const void* weights,                           \
27       uint8_t* output,                               \
28       size_t input_stride,                           \
29       size_t output_increment,                       \
30       const union pytorch_qnnp_conv_quantization_params* quantization_params);
31 
/*
 * Prototypes for the up8x9 kernels, all sharing the signature produced by
 * DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION. Each of the three suffixes
 * (__neon, __aarch32_neon, __sse2) is declared twice: once plain and once as
 * a `per_channel` variant.
 *
 * NOTE(review): the suffix presumably names the instruction set / backend of
 * the implementation, and "8x9" presumably means 8 channels per step over a
 * 9-tap (3x3) kernel - confirm against the kernel sources.
 */
32 DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_up8x9__neon)
33 DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(
34     pytorch_q8dwconv_ukernel_up8x9_per_channel__neon)
35 DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon)
36 DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(
37     pytorch_q8dwconv_ukernel_up8x9_per_channel__aarch32_neon)
38 DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_up8x9__sse2)
39 DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(
40     pytorch_q8dwconv_ukernel_up8x9_per_channel__sse2)
41 
/*
 * Emits the prototype of a depthwise-convolution micro-kernel called
 * `fn_name` that additionally receives an int32_t scratch pointer.
 *
 * The generated function returns void, is tagged PYTORCH_QNNP_INTERNAL and
 * takes, in this order:
 *   channels, output_width          size_t
 *   input                           const uint8_t** (array of input pointers)
 *   weights                         const void*
 *   buffer                          int32_t* (non-const)
 *   output                          uint8_t*
 *   input_stride, output_increment  size_t
 *   quantization_params             pointer to a const
 *                                   union pytorch_qnnp_conv_quantization_params
 *
 * Compared with DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION the only
 * difference is the `buffer` parameter inserted between `weights` and
 * `output`. The expansion already ends with ';', so invocations carry no
 * trailing semicolon.
 *
 * NOTE(review): "MP" presumably means multi-pass, with `buffer` holding
 * intermediate 32-bit accumulators between passes - confirm; the required
 * size of `buffer` is not visible in this header.
 */
42 #define DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION(fn_name) \
43   PYTORCH_QNNP_INTERNAL void fn_name(                \
44       size_t channels,                               \
45       size_t output_width,                           \
46       const uint8_t** input,                         \
47       const void* weights,                           \
48       int32_t* buffer,                               \
49       uint8_t* output,                               \
50       size_t input_stride,                           \
51       size_t output_increment,                       \
52       const union pytorch_qnnp_conv_quantization_params* quantization_params);
53 
/*
 * Prototypes for the mp8x25 kernels, all sharing the signature produced by
 * DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION. The __neon and __sse2
 * suffixes are each declared in a plain and a `per_channel` variant; no
 * __aarch32_neon variant is declared here.
 *
 * NOTE(review): "8x25" presumably means 8 channels per step over a 25-tap
 * (5x5) kernel - confirm against the kernel sources.
 */
54 DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_mp8x25__neon)
55 DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION(
56     pytorch_q8dwconv_ukernel_mp8x25_per_channel__neon)
57 DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_mp8x25__sse2)
58 DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION(
59     pytorch_q8dwconv_ukernel_mp8x25_per_channel__sse2)
60 
/*
 * Emits the prototype of the "3D" flavour of the buffered
 * depthwise-convolution micro-kernel, called `fn_name`.
 *
 * The generated function returns void, is tagged PYTORCH_QNNP_INTERNAL and
 * takes, in this order:
 *   channels                            size_t
 *   output_height, output_width         size_t
 *   input                               const uint8_t** (array of input
 *                                       pointers)
 *   weights                             const void*
 *   buffer                              int32_t* (non-const)
 *   output                              uint8_t*
 *   input_row_stride, input_col_stride  size_t
 *   output_increment                    size_t
 *   quantization_params                 pointer to a const
 *                                       union pytorch_qnnp_conv_quantization_params
 *
 * Relative to DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION this signature adds
 * `output_height` and replaces the single `input_stride` with separate row
 * and column strides. The expansion already ends with ';', so invocations
 * carry no trailing semicolon.
 *
 * NOTE(review): the units of the stride/increment arguments and the required
 * size of `buffer` are not visible in this header - verify against a kernel
 * implementation.
 */
61 #define DECLARE_PYTORCH_Q8MPDWCONV_3D_UKERNEL_FUNCTION(fn_name) \
62   PYTORCH_QNNP_INTERNAL void fn_name(                           \
63       size_t channels,                                          \
64       size_t output_height,                                     \
65       size_t output_width,                                      \
66       const uint8_t** input,                                    \
67       const void* weights,                                      \
68       int32_t* buffer,                                          \
69       uint8_t* output,                                          \
70       size_t input_row_stride,                                  \
71       size_t input_col_stride,                                  \
72       size_t output_increment,                                  \
73       const union pytorch_qnnp_conv_quantization_params* quantization_params);
74 
/*
 * Prototypes for the mp8x27 kernels, using the signature produced by
 * DECLARE_PYTORCH_Q8MPDWCONV_3D_UKERNEL_FUNCTION. Only __neon and __sse2
 * variants are declared, and neither has a `per_channel` counterpart in this
 * header.
 *
 * NOTE(review): "8x27" presumably means 8 channels per step over a 27-tap
 * (3x3x3) kernel - confirm against the kernel sources.
 */
75 DECLARE_PYTORCH_Q8MPDWCONV_3D_UKERNEL_FUNCTION(
76     pytorch_q8dwconv_ukernel_mp8x27__neon)
77 DECLARE_PYTORCH_Q8MPDWCONV_3D_UKERNEL_FUNCTION(
78     pytorch_q8dwconv_ukernel_mp8x27__sse2)
79 
80 #ifdef __cplusplus
81 } /* extern "C" */
82 #endif
83