// aten/src/ATen/native/xnnpack/AveragePooling.cpp
#ifdef USE_XNNPACK

#include <ATen/native/utils/Factory.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/Engine.h>
#include <ATen/native/xnnpack/Pooling.h>

#include <cstdint>
#include <limits>
#include <vector>

namespace at::native::xnnpack {

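// Gate for the XNNPACK fast path: the backend must be available and the input
// must be a CPU float tensor with at least one dimension that does not
// require gradients.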
bool use_global_average_pool(const Tensor& input) {
  return xnnpack::available() && (1 <= input.ndimension()) &&
      (input.device().is_cpu()) && (kFloat == input.scalar_type()) &&
      !input.requires_grad() && true;
}

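// Computes a global average pool (output spatial size 1x1) by copying the
// input to a padded ChannelsLast (NHWC) buffer and running XNNPACK's f32 NWC
// global average pooling operator, with H and W folded into the single
// spatial dimension that operator expects.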
Tensor global_average_pool(const Tensor& input) {
  using namespace internal;

  const Tensor input_padded_contig_nhwc =
      mobile::allocate_padded_contiguous_if_needed(
          input, MemoryFormat::ChannelsLast);

  Tensor output = mobile::empty_with_tail_padding(
      {
          input_padded_contig_nhwc.size(Layout::Activation4D::batch),
          input_padded_contig_nhwc.size(Layout::Activation4D::channels),
          1,
          1,
      },
      input_padded_contig_nhwc.options().dtype(),
      MemoryFormat::ChannelsLast,
      input_padded_contig_nhwc.opt_names());

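  // XNNPACK operators follow a create -> reshape -> setup -> run lifecycle.
  // The -/+infinity bounds below leave the output unclamped.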
  xnn_operator_t global_average_pooling_op{};
  const xnn_status create_status = xnn_create_global_average_pooling_nwc_f32(
      -std::numeric_limits<float>::infinity(),
      std::numeric_limits<float>::infinity(),
      0 /* flags */,
      &global_average_pooling_op);

  TORCH_CHECK(
      xnn_status_success == create_status,
      "xnn_create_global_average_pooling_nwc_f32 failed!");

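  // Operator (Common.h) is an RAII handle that releases the underlying
  // xnn_operator_t once it goes out of scope.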
  Operator global_avg_pool_scoped_op(global_average_pooling_op);

  size_t workspace_size = 0;
  size_t workspace_alignment = 0;

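  // The NWC operator has a single spatial dimension, so H and W are collapsed
  // into one "width" of height * width elements per batch. Reshape also
  // reports the scratch workspace size and alignment needed by setup.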
  const xnn_status reshape_status = xnn_reshape_global_average_pooling_nwc_f32(
      global_average_pooling_op,
      input_padded_contig_nhwc.size(Layout::Activation4D::batch), // batch_size
      input_padded_contig_nhwc.size(Layout::Activation4D::width) *
          input_padded_contig_nhwc.size(Layout::Activation4D::height), // width
      input_padded_contig_nhwc.size(Layout::Activation4D::channels), // channels
      input_padded_contig_nhwc.size(
          Layout::Activation4D::channels), // input stride
      input_padded_contig_nhwc.size(
          Layout::Activation4D::channels), // output stride
      &workspace_size, // workspace_size
      &workspace_alignment, // workspace_alignment
      caffe2::pthreadpool_());

  TORCH_CHECK(
      xnn_status_success == reshape_status,
      "xnn_reshape_global_average_pooling_nwc_f32 failed!");

  // Allocate a scratch workspace with extra room for alignment slack plus 16
  // bytes of padding, then bump the pointer up to the alignment reported by
  // reshape.
  size_t xnnpack_buffer_padding = 16;
  std::vector<char> workspace_vector(
      workspace_size + workspace_alignment + xnnpack_buffer_padding);
  void* maybe_aligned_workspace = workspace_vector.data();
  void* aligned_workspace = (void*)(
      (intptr_t)maybe_aligned_workspace + workspace_alignment -
      (intptr_t)maybe_aligned_workspace % workspace_alignment);

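  // Bind the workspace and the input/output pointers to the reshaped operator.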
  const xnn_status setup_status = xnn_setup_global_average_pooling_nwc_f32(
      global_average_pooling_op,
      aligned_workspace,
      input_padded_contig_nhwc.data_ptr<float>(),
      output.data_ptr<float>());

  TORCH_CHECK(
      xnn_status_success == setup_status,
      "xnn_setup_global_average_pooling_nwc_f32 failed!");

  const xnn_status run_status =
      xnn_run_operator(global_average_pooling_op, caffe2::pthreadpool_());

  TORCH_CHECK(
      xnn_status_success == run_status,
      "xnn_run_operator failed!");

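  // The computation ran on a ChannelsLast (NHWC) copy, so convert the result
  // back to the memory format the caller expects for this input.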
  return output.to(input.suggest_memory_format());
}
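// Typical call site (sketch, not part of this file): the adaptive average
// pooling dispatch is expected to gate on the predicate before calling in,
// e.g.
//
//   if (xnnpack::use_global_average_pool(input)) {
//     return xnnpack::global_average_pool(input);
//   }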

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */