#pragma once
#include <ATen/core/Tensor.h>

#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/native/ConvUtils.h>

namespace at {
namespace native {

// ---------------------------------------------------------------------
//
// Helper classes
//
// ---------------------------------------------------------------------

// This POD struct is used to let us easily compute hashes of the
// parameters
struct ConvolutionParams {
  c10::DeviceIndex device_id;
  cudnnDataType_t dataType;
  int input_size[2 + max_dim];
  uint8_t input_dim;
  at::MemoryFormat memory_format;
  int weight_size[2 + max_dim];
  int padding[max_dim];
  int stride[max_dim];
  int dilation[max_dim];
  int64_t groups;
  bool deterministic;
  bool allow_tf32;
  // NB: transposed purposely omitted: transposed just swaps
  // forward and backward, so you can reuse the benchmark entry.
};
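
// Because ConvolutionParams is a POD, it can be hashed and compared byte-wise
// (assuming the struct is zero-initialized before being filled in, so that any
// padding bytes are deterministic), e.g. as the key of a benchmark cache.
// A minimal illustrative sketch; the hashing scheme below is hypothetical,
// not necessarily the one the cache actually uses:
//
//   struct ConvolutionParamsHash {
//     size_t operator()(const ConvolutionParams& params) const {
//       const auto* bytes = reinterpret_cast<const uint8_t*>(&params);
//       size_t hash = 0;
//       for (size_t i = 0; i < sizeof(ConvolutionParams); ++i) {
//         hash = hash * 31 + bytes[i];  // simple polynomial byte hash
//       }
//       return hash;
//     }
//   };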

std::ostream& operator<<(std::ostream& out, const ConvolutionParams& params);

// NB: This can't be a constructor, because then ConvolutionParams
// would not be a POD anymore.
// TODO: Use TensorGeometry here instead of the entire Tensor, which we
// don't actually need.  (OTOH: We can always pass in
// grad_input/grad_output, so this is not very pressing)
void setConvolutionParams(
    ConvolutionParams* params,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool deterministic,
    bool allow_tf32,
    at::MemoryFormat memory_format);
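
// Illustrative use of setConvolutionParams (a sketch; benchmark_cache is a
// hypothetical container, not something declared in this header):
//
//   ConvolutionParams params;
//   setConvolutionParams(
//       &params, input, weight, padding, stride, dilation, groups,
//       /*deterministic=*/false, /*allow_tf32=*/true,
//       input.suggest_memory_format());
//   if (benchmark_cache.count(params) == 0) {
//     // run algorithm search and cache the result keyed on `params`;
//     // repro_from_args(params) can be embedded in error messages to give
//     // users a self-contained reproduction snippet.
//   }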

std::string repro_from_args(const ConvolutionParams& args);

// ---------------------------------------------------------------------
//
// Raw functions
//
// ---------------------------------------------------------------------

void raw_cudnn_convolution_forward_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);
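
// The raw_* functions follow the "out" convention: they write into an output
// tensor the caller has already allocated with the right shape and memory
// format, and they do no shape inference of their own.  A sketch of a forward
// call (conv_output_size comes from ConvUtils.h; the rest of the setup is
// illustrative):
//
//   auto output_size = conv_output_size(
//       input.sizes(), weight.sizes(), padding, stride, dilation);
//   Tensor output = at::empty(
//       output_size, input.options(), input.suggest_memory_format());
//   raw_cudnn_convolution_forward_out(
//       output, input, weight, padding, stride, dilation, groups,
//       /*benchmark=*/false, /*deterministic=*/false, /*allow_tf32=*/true);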

void raw_cudnn_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_backward_weight_out(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_add_relu_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);
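
// A sketch of the intended semantics of the fused call above (exact epilogue
// handling is up to the implementation):
//
//   output = relu(conv(input, weight) + alpha * z + bias)
//
// The *_fallback_out variant below is expected to produce the same result by
// composing an ordinary convolution with separate add and relu steps, for
// configurations where the fused cuDNN path is unavailable.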

void raw_cudnn_convolution_add_relu_fallback_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

#if AT_CUDNN_ENABLED()

// v7 functions are preserved here to allow for runtime switching to v7
// (e.g., TORCH_CUDNN_V8_API_DISABLED=1).
// Note that v7 forward/backward out can have different behavior from the v8
// versions, as v7 explicitly splits large tensors as a 32-bit indexing
// workaround whereas v8 expects cuDNN to handle large tensors.
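
// A sketch of the v7-style splitting mentioned above (illustrative only;
// run_single_conv is a hypothetical helper standing in for one unsplit cuDNN
// call, and the real chunking policy lives in the v7 implementation):
//
//   constexpr int64_t int_max = std::numeric_limits<int32_t>::max();
//   if (input.numel() <= int_max && output.numel() <= int_max) {
//     run_single_conv(output, input, weight, ...);
//   } else {
//     const int64_t n = input.size(0);
//     // pick a batch chunk small enough that every slice is 32-bit addressable
//     const int64_t chunk =
//         std::max<int64_t>(1, int_max / (input.numel() / n));
//     for (int64_t start = 0; start < n; start += chunk) {
//       const int64_t len = std::min(chunk, n - start);
//       run_single_conv(output.narrow(0, start, len),
//                       input.narrow(0, start, len), weight, ...);
//     }
//   }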
void raw_cudnn_convolution_forward_out_v7(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_backward_input_out_v7(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_backward_weight_out_v7(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_add_relu_out_v7(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);
#endif
} // namespace native
} // namespace at