xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/xnnpack/Engine.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <limits>
5 
6 namespace at::native::xnnpack {
7 
8 //
9 // Convolution
10 //
11 
12 bool use_convolution2d(
13     const Tensor& input,
14     const Tensor& weight,
15     const at::OptionalIntArrayRef bias_sizes_opt,
16     const IntArrayRef padding,
17     const IntArrayRef stride,
18     const IntArrayRef dilation,
19     const int64_t groups,
20     const bool transposed);
21 
22 Tensor convolution2d(
23     const Tensor& input,
24     const Tensor& weight,
25     const Tensor& bias,
26     const IntArrayRef padding,
27     const IntArrayRef stride,
28     const IntArrayRef dilation,
29     const int64_t groups);
30 
31 //
32 // Linear
33 //
34 
35 bool use_linear(
36   const Tensor& input,
37   const Tensor& weight,
38   const Tensor& bias);
39 
40 Tensor linear(
41   const Tensor& input,
42   const Tensor& weight,
43   const Tensor& bias);
44 
45 //
46 // Max Pooling
47 //
48 
49 bool use_max_pool2d(
50     const Tensor& input,
51     const IntArrayRef kernel,
52     const IntArrayRef padding,
53     IntArrayRef stride,
54     const IntArrayRef dilation,
55     const bool ceil_mode,
56     const float output_min = -std::numeric_limits<float>::infinity(),
57     const float output_max = +std::numeric_limits<float>::infinity());
58 
59 Tensor max_pool2d(
60     const Tensor& input,
61     const IntArrayRef kernel,
62     const IntArrayRef padding,
63     IntArrayRef stride,
64     const IntArrayRef dilation,
65     const bool ceil_mode,
66     const float output_min = -std::numeric_limits<float>::infinity(),
67     const float output_max = +std::numeric_limits<float>::infinity());
68 
69 //
70 // Global Average Pooling
71 //
72 
73 bool use_global_average_pool(const Tensor& input);
74 Tensor global_average_pool(const Tensor& input);
75 
76 //
77 // Channel Shuffle
78 //
79 
80 bool use_channel_shuffle(
81     const Tensor& input,
82     const int64_t groups);
83 
84 Tensor channel_shuffle(
85     const Tensor& input,
86     const int64_t groups);
87 
88 //
89 // Activations
90 //
91 bool use_hardswish(const Tensor& input);
92 Tensor hardswish(const Tensor& input);
93 Tensor& hardswish_(Tensor& input);
94 
95 } // namespace at::native::xnnpack
96