#pragma once
/*
This file contains some of the auxiliary functions used by both Conv.cpp & Linear.cpp (introduced in a later PR)
*/

#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/cudnn/Types.h>
#include <ATen/Tensor.h>
#include <ATen/native/quantized/PackedParams.h>
#include <c10/core/QScheme.h>
#include <c10/util/ArrayRef.h>

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
#include <cudnn_frontend.h>
C10_DIAGNOSTIC_POP()

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

struct PackedLinearWeightCudnn : public LinearPackedParamsBase {
  PackedLinearWeightCudnn(
      at::Tensor orig_weight,
      std::optional<at::Tensor> bias,
      c10::QScheme q_scheme)
      : orig_weight(std::move(orig_weight)),
        bias_(std::move(bias)),
        q_scheme(std::move(q_scheme)) {}

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) override {
    throw std::runtime_error(
    "apply_dynamic is not implemented for this packed "
    "parameter type");
  }
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) override {
    throw std::runtime_error(
    "apply_dynamic_relu is not implemented for this packed "
    "parameter type");
  }

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  std::optional<at::Tensor> bias() override {
    return bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias);

 private:
  at::Tensor orig_weight;
  std::optional<at::Tensor> bias_;
  c10::QScheme q_scheme;

  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);

  template <bool ReluFused>
  void apply_impl_helper(
      const at::Tensor& quantized_output,
      const at::Tensor& input,
      double output_scale);
};
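
// Editor note: a rough usage sketch, not part of the original header. Assuming a
// per-tensor affine qint8 CUDA weight `q_weight`, an optional fp32 `bias`, and a
// quantized CUDA input `q_input` (all names here are illustrative only):
//
//   auto packed = PackedLinearWeightCudnn::prepack(q_weight, bias);
//   at::Tensor out = packed->apply(q_input, /*output_scale=*/0.5, /*output_zero_point=*/0);
//   at::Tensor out_relu = packed->apply_relu(q_input, 0.5, 0);  // same linear op with fused ReLU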

template <int kSpatialDim = 2>
struct PackedConvWeightCudnn : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightCudnn(
      at::Tensor orig_weight,
      std::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose,
      c10::QScheme q_scheme,
      int64_t output_channels)
      : maybe_padded_weight_(std::move(orig_weight)),
        bias_(std::move(bias)),
        stride_(stride),
        padding_(padding),
        output_padding_(output_padding),
        dilation_(dilation),
        groups_(groups),
        transpose_(transpose),
        q_scheme_(q_scheme),
        num_unpadded_output_channels_(output_channels) {} // the output channel count needs to be stored when we have to pad this dimension

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
    const at::Tensor& input,
    bool reduce_range) override {
    TORCH_CHECK(false, "apply_dynamic is currently not supported");
  }

  at::Tensor apply_dynamic_relu(
    const at::Tensor& input,
    bool reduce_range) {
    TORCH_CHECK(false, "apply_dynamic_relu is currently not supported");
  }

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  const float* GetBiasData(at::Tensor* bias);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return transpose_;
  }

 private:
  // cudnn v8.4.0 expects conv2d's int8 weight tensor's input and output channels to be a multiple of 4. If they are not,
  // we need to explicitly pad the weight to a multiple of 4 ourselves, as cudnn does not currently support padding; hence the
  // naming convention "maybe"_padded_weight.
  // TODO: when and if cudnn enables padding in their operators, we can remove padding on our end and rename this to orig_weight_
  at::Tensor maybe_padded_weight_;
  std::optional<at::Tensor> bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  bool transpose_;
  c10::QScheme q_scheme_;
  int64_t num_unpadded_output_channels_;

  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);

  template <bool ReluFused>
  void apply_impl_helper(
      const at::Tensor& quantized_output,
      const at::Tensor& input,
      double output_scale);
};
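
// Editor note: a rough usage sketch for the 2-D case, not part of the original header.
// Assuming an NHWC qint8 CUDA weight `q_weight`, an optional fp32 `bias`, and a quantized
// NHWC CUDA input `q_input` (names and values are illustrative only):
//
//   auto packed = PackedConvWeightCudnn<2>::prepack(
//       q_weight, bias,
//       /*stride=*/{1, 1}, /*padding=*/{0, 0}, /*output_padding=*/{0, 0},
//       /*dilation=*/{1, 1}, /*groups=*/1, /*transpose=*/false);
//   at::Tensor out = packed->apply(q_input, /*output_scale=*/0.5, /*output_zero_point=*/0);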

namespace cudnn_utils {

// TODO: we can remove this function when cuDNN enables pass by value support for
// pointwise multiplication operations. The only reason we need this right now is that
// we use broadcasting scalar multiplication in conv, linear, and add ops, and cuDNN requires
// the scalar to be a scalar tensor with the same number of dimensions (num_dim) as the tensor we're multiplying with
inline at::Tensor getRequantMultiplierTensor(double requant_multiplier, uint8_t num_dim) {
  at::SmallVector<int64_t, 4> requantize_multiplier_tensor_size(num_dim, 1);
  at::Tensor requantize_multiplier_tensor = at::empty(requantize_multiplier_tensor_size, at::device(at::kCUDA).dtype(at::kFloat));
  requantize_multiplier_tensor.fill_(requant_multiplier);
  return requantize_multiplier_tensor;
}
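
// For example, getRequantMultiplierTensor(0.25, 4) returns a CUDA float tensor of shape
// [1, 1, 1, 1] filled with 0.25, which can then be broadcast-multiplied against a 4-D
// activation tensor during requantization.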

inline uint8_t getAlignment(const at::Tensor &t) {
  // alignment is in bytes
  uint8_t alignment = 1;
  uintptr_t address = reinterpret_cast<uintptr_t>(t.data_ptr());
  for (; alignment < 16; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}
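
// getAlignment effectively returns the largest power-of-two (capped at 16 bytes) that divides the
// tensor's data address: an address ending in 0x8 yields 8, while an address that is a multiple of
// 16 yields 16.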

// For the two getTensorDescriptor functions, there is an is_virtual parameter. This parameter is used to set the cudnn
// tensor as virtual or not. Setting the tensor as virtual is expected to have some performance benefits, as the cudnn
// backend will no longer directly save to the tensor, allowing us to omit this tensor from the variant pack.
// See third_party/cudnn_frontend/samples/fusion_sample.cpp for other examples

inline cudnn_frontend::Tensor getTensorDescriptor(const at::Tensor &t, int64_t id, uint8_t alignment, bool is_virtual = false) {
  auto shape = t.sizes();
  auto strides = t.strides();
  if (is_virtual) {
    return cudnn_frontend::TensorBuilder()
      .setDim(shape.size(), shape.data())
      .setStrides(strides.size(), strides.data())
      .setId(id)
      .setAlignment(alignment)
      .setVirtual()
      .setDataType(at::native::getCudnnDataType(t))
      .build();
  }
  return cudnn_frontend::TensorBuilder()
    .setDim(shape.size(), shape.data())
    .setStrides(strides.size(), strides.data())
    .setId(id)
    .setAlignment(alignment)
    .setDataType(at::native::getCudnnDataType(t))
    .build();
}

inline cudnn_frontend::Tensor getTensorDescriptor(const c10::IntArrayRef& shape, const c10::IntArrayRef& strides, cudnnDataType_t cudnn_dtype, int64_t id, uint8_t alignment, bool is_virtual = false) {
  if (is_virtual) {
    return cudnn_frontend::TensorBuilder()
      .setDim(shape.size(), shape.data())
      .setStrides(strides.size(), strides.data())
      .setId(id)
      .setAlignment(alignment)
      .setVirtual()
      .setDataType(cudnn_dtype)
      .build();
  }
  return cudnn_frontend::TensorBuilder()
    .setDim(shape.size(), shape.data())
    .setStrides(strides.size(), strides.data())
    .setId(id)
    .setAlignment(alignment)
    .setDataType(cudnn_dtype)
    .build();
}
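
// Editor note: an illustrative sketch, not part of the original header. A non-virtual descriptor
// is built for a tensor whose storage appears in the variant pack, while a virtual descriptor
// describes an intermediate that exists only inside a fused graph:
//
//   auto x_desc = getTensorDescriptor(input, 'x', getAlignment(input));
//   auto tmp_desc = getTensorDescriptor(input.sizes(), input.strides(), CUDNN_DATA_FLOAT,
//                                       'a', getAlignment(input), /*is_virtual=*/true);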

// TODO: there is a table from input dtype to operator dtype, we can derive
// the operator dtype based on input dtype
inline cudnn_frontend::PointWiseDesc_v8 getPointWiseMulDescriptor(cudnnDataType_t dataType) {
  return cudnn_frontend::PointWiseDescBuilder()
    .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_MUL)
    .setMathPrecision(dataType)
    .build();
}

// TODO: there is a table from input dtype to operator dtype, we can derive
// the operator dtype based on input dtype
inline cudnn_frontend::PointWiseDesc_v8 getPointWiseAddDescriptor(cudnnDataType_t dataType) {
  return cudnn_frontend::PointWiseDescBuilder()
    .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_ADD)
    .setMathPrecision(dataType)
    .build();
}

// TODO: there is a table from input dtype to operator dtype, we can derive
// the operator dtype based on input dtype
inline cudnn_frontend::PointWiseDesc_v8 getPointWiseReluDescriptor(cudnnDataType_t dataType) {
  return cudnn_frontend::PointWiseDescBuilder()
    .setMode(cudnnPointwiseMode_t::CUDNN_POINTWISE_RELU_FWD)
    .setMathPrecision(dataType)
    .build();
}
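
// Editor note: a rough sketch of how these pointwise descriptors are consumed, not part of the
// original header; the exact builder calls may differ by cudnn_frontend version. A descriptor is
// attached to a pointwise operation node, e.g. for the requantization multiply:
//
//   auto mul_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
//       .setxDesc(conv_output_desc)
//       .setbDesc(requant_multiplier_desc)
//       .setyDesc(quantized_output_desc)
//       .setpwDesc(getPointWiseMulDescriptor(CUDNN_DATA_FLOAT))
//       .build();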


inline void filterEngineConfigs(
  cudnn_frontend::EngineConfigList &from,
  cudnn_frontend::EngineConfigList &to,
  bool deterministic, bool allow_tf32, c10::ScalarType scalar_type)
{
  auto filter = [=](cudnnBackendDescriptor_t c) {
    if (deterministic) {
      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) return true;
    }
    if (scalar_type == at::kFloat || scalar_type == at::kChar || !allow_tf32) {
      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) return true;
      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) return true;
    }
    return false;
  };
  cudnn_frontend::filter(from, to, filter);
}
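
// Editor note: a rough sketch of typical use, not part of the original header; builder names
// follow the cudnn_frontend heuristics API and may differ by version. Engine configs produced by
// the heuristics are filtered before building an execution plan:
//
//   auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
//       .setOperationGraph(op_graph)
//       .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
//       .build();
//   auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
//   cudnn_frontend::EngineConfigList filtered_configs;
//   filterEngineConfigs(engine_configs, filtered_configs, /*deterministic=*/false,
//                       /*allow_tf32=*/false, at::kChar);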

} // namespace cudnn_utils

#endif  // AT_CUDNN_ENABLED
#endif  // USE_CUDA