xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/xnnpack/Convolution.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_XNNPACK
4 
5 #include <ATen/Tensor.h>
6 #include <ATen/native/xnnpack/Common.h>
7 #include <ATen/native/xnnpack/OpContext.h>
8 
9 namespace at::native::xnnpack {
10 namespace internal::convolution2d {
11 
12 c10::intrusive_ptr<xnnpack::Conv2dOpContext>
13     createConv2dClampPrePackOpContext(
14         Tensor weight,
15         std::optional<Tensor> bias,
16         std::vector<int64_t> stride,
17         std::vector<int64_t> padding,
18         std::vector<int64_t> dilation,
19         int64_t groups,
20         const std::optional<Scalar>& output_min,
21         const std::optional<Scalar>& output_max);
22 
23 c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>
24     createConv2dTransposeClampPrePackOpContext(
25         Tensor weight,
26         std::optional<Tensor> bias,
27         std::vector<int64_t> stride,
28         std::vector<int64_t> padding,
29         std::vector<int64_t> output_padding,
30         std::vector<int64_t> dilation,
31         int64_t groups,
32         const std::optional<Scalar>& output_min,
33         const std::optional<Scalar>& output_max);
34 
35 Tensor conv2d_clamp_run(
36     const Tensor& input,
37     const c10::intrusive_ptr<xnnpack::Conv2dOpContext>& op_context);
38 
39 IValue
40 unpack_prepacked_sizes_conv2d(const IValue& ivalue);
41 
42 Tensor conv2d_transpose_clamp_run(
43     const Tensor& input,
44     const c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>& op_context);
45 
46 ContextConv2D create(
47     const Tensor& weight,
48     const std::optional<Tensor>& bias,
49     const IntArrayRef padding,
50     const IntArrayRef output_padding,
51     const IntArrayRef stride,
52     const IntArrayRef dilation,
53     const int64_t groups,
54     const bool transposed,
55     const float output_min,
56     const float output_max);
57 
58 Tensor run(ContextConv2D& context, const Tensor& input);
59 
60 } // namespace internal::convolution2d
61 
62 Tensor convolution2d(
63     const Tensor& input,
64     const Tensor& weight,
65     const Tensor& bias,
66     const IntArrayRef padding,
67     const IntArrayRef stride,
68     const IntArrayRef dilation,
69     const int64_t groups);
70 } // namespace at::native::xnnpack
71 
72 #endif /* USE_XNNPACK */
73