#pragma once
/*
This file contains some of the auxiliary functions used by both Conv.cpp & Linear.cpp (introduced in a later PR)
*/

#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/cudnn/Types.h>
#include <ATen/Tensor.h>
#include <ATen/native/quantized/PackedParams.h>
#include <c10/core/QScheme.h>
#include <c10/util/ArrayRef.h>

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
#include <cudnn_frontend.h>
C10_DIAGNOSTIC_POP()

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

struct PackedLinearWeightCudnn : public LinearPackedParamsBase {
  PackedLinearWeightCudnn(
      at::Tensor orig_weight,
      std::optional<at::Tensor> bias,
      c10::QScheme q_scheme)
      : orig_weight(std::move(orig_weight)),
        bias_(std::move(bias)),
        q_scheme(std::move(q_scheme)) {}

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) override {
    throw std::runtime_error(
        "apply_dynamic is not implemented for this packed "
        "parameter type");
  }
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) override {
    throw std::runtime_error(
        "apply_dynamic_relu is not implemented for this packed "
        "parameter type");
  }

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  std::optional<at::Tensor> bias() override {
    return bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias);

 private:
  at::Tensor orig_weight;
  std::optional<at::Tensor> bias_;
  c10::QScheme q_scheme;

  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);

  template <bool ReluFused>
  void apply_impl_helper(
      const at::Tensor& quantized_output,
      const at::Tensor& input,
      double output_scale);
};
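
// Illustrative sketch (not part of this header): the expected call flow for the
// quantized cuDNN linear path. Names such as `qweight`, `qinput`, and the
// scale / zero-point literals below are hypothetical placeholders.
//
//   auto packed = PackedLinearWeightCudnn::prepack(qweight, bias);
//   // `qinput` is a quantized CUDA tensor; output_scale / output_zero_point
//   // describe the requested quantization of the result.
//   at::Tensor out = packed->apply(qinput, /*output_scale=*/0.5, /*output_zero_point=*/0);
//   at::Tensor out_relu = packed->apply_relu(qinput, 0.5, 0);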

template <int kSpatialDim = 2>
struct PackedConvWeightCudnn : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightCudnn(
      at::Tensor orig_weight,
      std::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose,
      c10::QScheme q_scheme,
      int64_t output_channels)
      : maybe_padded_weight_(std::move(orig_weight)),
        bias_(std::move(bias)),
        stride_(stride),
        padding_(padding),
        output_padding_(output_padding),
        dilation_(dilation),
        groups_(groups),
        transpose_(transpose),
        q_scheme_(q_scheme),
        num_unpadded_output_channels_(output_channels) {} // the original number of output channels needs to be stored when we have to pad this dimension

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) override {
    TORCH_CHECK(false, "apply_dynamic is currently not supported");
  }

  at::Tensor apply_dynamic_relu(
      const at::Tensor& input,
      bool reduce_range) {
    TORCH_CHECK(false, "apply_dynamic_relu is currently not supported");
  }

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  const float* GetBiasData(at::Tensor* bias);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return transpose_;
  }

 private:
  // cuDNN v8.4.0 expects conv2d's int8 weight tensor's input and output channels to be multiples of 4. If they are not,
  // we need to explicitly pad the weight to a multiple of 4 ourselves, as cuDNN does not currently support padding; hence
  // the naming convention "maybe"_padded_weight. See the illustrative padding sketch after this struct.
  // TODO: when and if cuDNN enables padding in their operators, we can remove padding on our end and rename this to orig_weight_
  at::Tensor maybe_padded_weight_;
  std::optional<at::Tensor> bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  bool transpose_;
  c10::QScheme q_scheme_;
  int64_t num_unpadded_output_channels_;

  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);

  template <bool ReluFused>
  void apply_impl_helper(
      const at::Tensor& quantized_output,
      const at::Tensor& input,
      double output_scale);
};
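
// Illustrative sketch of the channel padding referred to in the comment above
// (a simplified, hypothetical prepack step; the real logic lives in the cuDNN
// quantized conv prepack implementation):
//
//   // Weight layout is (out_channels, in_channels/groups, kH, kW); cuDNN's
//   // int8 kernels want both channel dims to be multiples of 4.
//   int64_t out_pad = (4 - weight.size(0) % 4) % 4;
//   int64_t in_pad = (4 - weight.size(1) % 4) % 4;
//   if (out_pad != 0 || in_pad != 0) {
//     // at::pad takes (before, after) pairs starting from the last dimension.
//     weight = at::pad(weight, {0, 0, 0, 0, 0, in_pad, 0, out_pad}, "constant", 0);
//   }
//   // num_unpadded_output_channels_ remembers the original out_channels so the
//   // padded rows can be sliced away from the conv output afterwards.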

namespace cudnn_utils {

// TODO: we can remove this function when cuDNN enables pass-by-value support for
// pointwise multiplication operations. The only reason we need it right now is that
// we use broadcasting scalar multiplication in conv, linear, and add ops, and cuDNN requires
// the scalar to be a scalar tensor with the same number of dimensions (num_dim) as the tensor it is multiplied with.
inline at::Tensor getRequantMultiplierTensor(double requant_multiplier, uint8_t num_dim) {
  at::SmallVector<int64_t, 4> requantize_multiplier_tensor_size(num_dim, 1);
  at::Tensor requantize_multiplier_tensor = at::empty(requantize_multiplier_tensor_size, at::device(at::kCUDA).dtype(at::kFloat));
  requantize_multiplier_tensor.fill_(requant_multiplier);
  return requantize_multiplier_tensor;
}
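
// Example usage (illustrative; the variable names below are placeholders):
// the requantization multiplier typically combines the input, weight, and
// output scales, and is broadcast against a 4-D activation in the conv case.
//
//   double requant_multiplier = input_scale * weight_scale / output_scale;
//   at::Tensor requant = getRequantMultiplierTensor(requant_multiplier, /*num_dim=*/4);
//   // `requant` has shape {1, 1, 1, 1} on CUDA and broadcasts against an NCHW tensor.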

inline uint8_t getAlignment(const at::Tensor &t) {
  // alignment is in bytes
  uint8_t alignment = 1;
  uintptr_t address = reinterpret_cast<uintptr_t>(t.data_ptr());
  for (; alignment < 16; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}
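
// For example (illustrative): the loop above returns the largest power of two,
// capped at 16, that divides the data pointer's address. An address divisible
// by 8 but not 16 yields 8; a 16-byte-aligned address yields the maximum
// reported alignment of 16.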

// For the two getTensorDescriptor functions, there is an is_virtual parameter. This parameter is used to set the cudnn
// tensor as virtual or not. Setting the tensor as virtual is expected to have some performance benefits as the cudnn
// backend will no longer directly save to the tensor, allowing us to omit this tensor from the variant pack.
// See third_party/cudnn_frontend/samples/fusion_sample.cpp for other examples

inline cudnn_frontend::Tensor getTensorDescriptor(const at::Tensor &t, int64_t id, uint8_t alignment, bool is_virtual = false) {
  auto shape = t.sizes();
  auto strides = t.strides();
  if (is_virtual) {
    return cudnn_frontend::TensorBuilder()
      .setDim(shape.size(), shape.data())
      .setStrides(strides.size(), strides.data())
      .setId(id)
      .setAlignment(alignment)
      .setVirtual()
      .setDataType(at::native::getCudnnDataType(t))
      .build();
  }
  return cudnn_frontend::TensorBuilder()
    .setDim(shape.size(), shape.data())
    .setStrides(strides.size(), strides.data())
    .setId(id)
    .setAlignment(alignment)
    .setDataType(at::native::getCudnnDataType(t))
    .build();
}

inline cudnn_frontend::Tensor getTensorDescriptor(const c10::IntArrayRef& shape, const c10::IntArrayRef& strides, cudnnDataType_t cudnn_dtype, int64_t id, uint8_t alignment, bool is_virtual = false) {
  if (is_virtual) {
    return cudnn_frontend::TensorBuilder()
      .setDim(shape.size(), shape.data())
      .setStrides(strides.size(), strides.data())
      .setId(id)
      .setAlignment(alignment)
      .setVirtual()
      .setDataType(cudnn_dtype)
      .build();
  }
  return cudnn_frontend::TensorBuilder()
    .setDim(shape.size(), shape.data())
    .setStrides(strides.size(), strides.data())
    .setId(id)
    .setAlignment(alignment)
    .setDataType(cudnn_dtype)
    .build();
}
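
// Example usage (illustrative; tensor names, ids, and dtypes are placeholders):
// in a fused conv -> requantize graph, the intermediate conv output is never
// materialized, so it can be described as a virtual tensor and omitted from the
// variant pack, while real input/output tensors use their actual data pointers
// for the alignment query.
//
//   auto input_desc = getTensorDescriptor(input, 'x', getAlignment(input));
//   auto conv_out_desc = getTensorDescriptor(
//       quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'c',
//       getAlignment(quantized_output), /*is_virtual=*/true);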

// TODO: there is a table from input dtype to operator dtype, we can derive
// the operator dtype based on input dtype
inline cudnn_frontend::PointWiseDesc_v8 getPointWiseMulDescriptor(cudnnDataType_t dataType) {
  return cudnn_frontend::PointWiseDescBuilder()
    .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_MUL)
    .setMathPrecision(dataType)
    .build();
}

// TODO: there is a table from input dtype to operator dtype, we can derive
// the operator dtype based on input dtype
inline cudnn_frontend::PointWiseDesc_v8 getPointWiseAddDescriptor(cudnnDataType_t dataType) {
  return cudnn_frontend::PointWiseDescBuilder()
    .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_ADD)
    .setMathPrecision(dataType)
    .build();
}

// TODO: there is a table from input dtype to operator dtype, we can derive
// the operator dtype based on input dtype
inline cudnn_frontend::PointWiseDesc_v8 getPointWiseReluDescriptor(cudnnDataType_t dataType) {
  return cudnn_frontend::PointWiseDescBuilder()
    .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_RELU_FWD)
    .setMathPrecision(dataType)
    .build();
}
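
// Example usage (illustrative; the descriptor variables are placeholders): a
// pointwise descriptor is consumed by a cudnn_frontend operation that names its
// input(s) and output, e.g. the broadcast multiply used for requantization:
//
//   auto mul_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
//       .setxDesc(conv_out_desc)
//       .setbDesc(requant_multiplier_desc)
//       .setyDesc(requantized_desc)
//       .setpwDesc(getPointWiseMulDescriptor(CUDNN_DATA_FLOAT))
//       .build();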


inline void filterEngineConfigs(
    cudnn_frontend::EngineConfigList &from,
    cudnn_frontend::EngineConfigList &to,
    bool deterministic, bool allow_tf32, c10::ScalarType scalar_type)
{
  auto filter = [=](cudnnBackendDescriptor_t c) {
    if (deterministic) {
      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) return true;
    }
    if (scalar_type == at::kFloat || scalar_type == at::kChar || !allow_tf32) {
      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) return true;
      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) return true;
    }
    return false;
  };
  cudnn_frontend::filter(from, to, filter);
}
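
// Example usage (illustrative; `candidate_configs` stands for an engine config
// list obtained from cudnn_frontend heuristics for the fused op graph):
//
//   cudnn_frontend::EngineConfigList filtered_configs;
//   filterEngineConfigs(candidate_configs, filtered_configs,
//                       /*deterministic=*/false, /*allow_tf32=*/false, at::kChar);
//   // Execution plans are then built only from `filtered_configs`.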

} // namespace cudnn_utils

#endif  // AT_CUDNN_ENABLED
#endif  // USE_CUDA