#ifdef USE_XNNPACK

#include <ATen/native/utils/Factory.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/xnnpack/Engine.h>
#include <ATen/native/xnnpack/Pooling.h>

namespace at::native::xnnpack {

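// Gate for the XNNPACK path: requires XNNPACK to be available and the input
// to be a CPU float tensor that does not require gradients.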
bool use_global_average_pool(const Tensor& input) {
  return xnnpack::available() && (1 <= input.ndimension()) &&
      (input.device().is_cpu()) && (kFloat == input.scalar_type()) &&
      !input.requires_grad() && true;
}

Tensor global_average_pool(const Tensor& input) {
  using namespace internal;

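  // Ensure the input is padded and contiguous in NHWC (channels-last)
  // memory format, copying only if necessary.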
  const Tensor input_padded_contig_nhwc =
      mobile::allocate_padded_contiguous_if_needed(
          input, MemoryFormat::ChannelsLast);

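  // Allocate an N x C x 1 x 1 channels-last output tensor with tail padding.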
  Tensor output = mobile::empty_with_tail_padding(
      {
          input_padded_contig_nhwc.size(Layout::Activation4D::batch),
          input_padded_contig_nhwc.size(Layout::Activation4D::channels),
          1,
          1,
      },
      input_padded_contig_nhwc.options().dtype(),
      MemoryFormat::ChannelsLast,
      input_padded_contig_nhwc.opt_names());

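  // Create the XNNPACK global average pooling operator with no output
  // clamping: the output range is (-inf, +inf).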
  xnn_operator_t global_average_pooling_op{};
  const xnn_status create_status = xnn_create_global_average_pooling_nwc_f32(
      -std::numeric_limits<float>::infinity(),
      std::numeric_limits<float>::infinity(),
      0 /* flags */,
      &global_average_pooling_op);

  TORCH_CHECK(
      xnn_status_success == create_status,
      "xnn_create_global_average_pooling_nwc_f32 failed!");

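  // Wrap the raw operator in an RAII handle so it is deleted on scope exit.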
  Operator global_avg_pool_scoped_op(global_average_pooling_op);

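  // Reshape the operator for this input: the padded NHWC tensor is treated as
  // NWC with width = H * W. XNNPACK reports the scratch workspace size and
  // alignment it needs through the two output parameters below.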
  size_t workspace_size = 0;
  size_t workspace_alignment = 0;

  const xnn_status reshape_status = xnn_reshape_global_average_pooling_nwc_f32(
      global_average_pooling_op,
      input_padded_contig_nhwc.size(Layout::Activation4D::batch), // batch_size
      input_padded_contig_nhwc.size(Layout::Activation4D::width) *
          input_padded_contig_nhwc.size(Layout::Activation4D::height), // width
      input_padded_contig_nhwc.size(Layout::Activation4D::channels), // channels
      input_padded_contig_nhwc.size(
          Layout::Activation4D::channels), // input stride
      input_padded_contig_nhwc.size(
          Layout::Activation4D::channels), // output stride
      &workspace_size, // workspace_size
      &workspace_alignment, // workspace_alignment
      caffe2::pthreadpool_());

  TORCH_CHECK(
      xnn_status_success == reshape_status,
      "xnn_reshape_global_average_pooling_nwc_f32 failed!");

  // Allocate the workspace, over-allocating by the required alignment plus
  // 16 bytes of padding, then advance the pointer so that it is aligned to
  // workspace_alignment.
  size_t xnnpack_buffer_padding = 16;
  std::vector<char> workspace_vector(
      workspace_size + workspace_alignment + xnnpack_buffer_padding);
  void* maybe_aligned_workspace = workspace_vector.data();
  void* aligned_workspace = (void*)(
      (intptr_t)maybe_aligned_workspace + workspace_alignment -
      (intptr_t)maybe_aligned_workspace % workspace_alignment);

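  // Bind the workspace, input, and output pointers to the operator.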
  const xnn_status setup_status = xnn_setup_global_average_pooling_nwc_f32(
      global_average_pooling_op,
      aligned_workspace,
      input_padded_contig_nhwc.data_ptr<float>(),
      output.data_ptr<float>());

  TORCH_CHECK(
      xnn_status_success == setup_status,
      "xnn_setup_global_average_pooling_nwc_f32 failed!");

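  // Run the operator on the shared caffe2 pthreadpool.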
  const xnn_status run_status =
      xnn_run_operator(global_average_pooling_op, caffe2::pthreadpool_());

  TORCH_CHECK(
      xnn_status_success == run_status,
      "xnn_run_operator failed!");

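  // Convert the channels-last result back to the memory format suggested by
  // the original input.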
  return output.to(input.suggest_memory_format());
}

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */