xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/operators/conv2d.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
4 #include <torch/csrc/jit/tensorexpr/tensor.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace tensorexpr {
9 
10 // An API to compute 2D depthwise convolutions with bias.
11 TORCH_API Tensor conv2d_depthwise(
12     BufHandle input,
13     BufHandle weight,
14     BufHandle bias,
15     int stride,
16     int pad,
17     int groups);
18 
19 // An API to compute 2D depthwise convolutions without bias.
20 TORCH_API Tensor conv2d_depthwise(
21     BufHandle input,
22     BufHandle weight,
23     int stride,
24     int pad,
25     int groups);
26 
27 TORCH_API Tensor conv2d_depthwise(
28     BufHandle input,
29     BufHandle weight,
30     BufHandle bias,
31     ExprHandle N,
32     ExprHandle C,
33     ExprHandle H,
34     ExprHandle W,
35     ExprHandle K,
36     ExprHandle CperG,
37     ExprHandle R,
38     ExprHandle S,
39     ExprHandle stride,
40     ExprHandle pad,
41     ExprHandle groups);
42 
43 TORCH_API Tensor conv2d_depthwise(
44     BufHandle input,
45     BufHandle weight,
46     ExprHandle N,
47     ExprHandle C,
48     ExprHandle H,
49     ExprHandle W,
50     ExprHandle K,
51     ExprHandle CperG,
52     ExprHandle R,
53     ExprHandle S,
54     ExprHandle stride,
55     ExprHandle pad,
56     ExprHandle groups);
57 
58 bool conv2dIsSupported(
59     const TensorInfo& input,
60     const TensorInfo& weight,
61     const TensorInfo& bias,
62     const std::vector<int64_t>& stride,
63     const std::vector<int64_t>& pad,
64     const std::vector<int64_t>& dilation,
65     int64_t groups);
66 bool mkldnnPrepackedConvIsSupported(
67     const TensorInfo& input,
68     const TensorInfo& weight,
69     const std::vector<int64_t>& stride,
70     const std::vector<int64_t>& pad,
71     const std::vector<int64_t>& dilation,
72     int64_t groups);
73 Tensor computeConv2d(
74     const std::vector<ArgValue>& inputs,
75     const std::vector<ExprHandle>& outputShape,
76     const std::vector<ExprHandle>& outputStrides,
77     const std::optional<ScalarType>& outputType,
78     at::Device device);
79 Tensor computeConv1d(
80     const std::vector<ArgValue>& inputs,
81     const std::vector<ExprHandle>& outputShape,
82     const std::vector<ExprHandle>& outputStrides,
83     const std::optional<ScalarType>& outputType,
84     at::Device device);
85 Tensor computePrepackedConv2dClampRun(
86     const std::vector<ArgValue>& inputs,
87     const std::vector<ExprHandle>& outputShape,
88     const std::vector<ExprHandle>& outputStrides,
89     const std::optional<ScalarType>& outputType,
90     at::Device device);
91 Tensor computePrepackedLinearClampRun(
92     const std::vector<ArgValue>& inputs,
93     const std::vector<ExprHandle>& outputShape,
94     const std::vector<ExprHandle>& outputStrides,
95     const std::optional<ScalarType>& outputType,
96     at::Device device);
97 Tensor computeMkldnnPrepackedConvRun(
98     const std::vector<ArgValue>& inputs,
99     const std::vector<ExprHandle>& outputShape,
100     const std::vector<ExprHandle>& outputStrides,
101     const std::optional<ScalarType>& outputType,
102     at::Device device);
103 } // namespace tensorexpr
104 } // namespace jit
105 } // namespace torch
106