xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/xnnpack/Pooling.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_XNNPACK
4 
5 #include <ATen/Tensor.h>
6 
7 namespace at::native::xnnpack::internal::pooling {
8 
9 struct Parameters final {
10 
11   std::array<int64_t, 2> kernel;
12   std::array<int64_t, 2> padding;
13   std::array<int64_t, 2> stride;
14   std::array<int64_t, 2> dilation;
15 
Parametersfinal16   explicit Parameters(
17       const IntArrayRef kernel_,
18       const IntArrayRef padding_,
19       const IntArrayRef stride_,
20       const IntArrayRef dilation_)
21   : kernel(normalize(kernel_)),
22     padding(normalize(padding_)),
23     stride(normalize(stride_)),
24     dilation(normalize(dilation_)) {
25   }
26 
27 private:
normalizefinal28   static std::array<int64_t, 2> normalize(const IntArrayRef parameter) {
29     TORCH_INTERNAL_ASSERT(
30         !parameter.empty(),
31         "Invalid usage!  Reason: normalize() was called on an empty parameter.");
32 
33     return std::array<int64_t, 2>{
34       parameter[0],
35       (2 == parameter.size()) ? parameter[1] : parameter[0],
36     };
37   }
38 };
39 
40 } // namespace at::native::xnnpack::internal::pooling
41 
42 #endif /* USE_XNNPACK */
43