#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <c10/util/ArrayRef.h>

#include <ATen/ATen.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/native/quantized/cudnn/utils.h>
#include <ATen/native/quantized/ConvUtils.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/utils/ParamsHash.h>
#include <ATen/TensorUtils.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cudnn_frontend.h>
#include <torch/library.h>

#include <iostream>
#include <unordered_map>
#include <vector>

template <int kSpatialDim = 2>
int register_conv_params();

extern template int register_conv_params<2>();
extern template int register_conv_params<3>();

// TODO: there is a table mapping input dtype and weight dtype to operator qdtype;
// we can derive the operator dtype based on the input dtype
cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, c10::IntArrayRef padding, c10::IntArrayRef stride, c10::IntArrayRef dilation) {
  int64_t convDim = static_cast<int64_t>(stride.size());
  return cudnn_frontend::ConvDescBuilder()
    .setDataType(dataType)
    .setMathMode(CUDNN_CROSS_CORRELATION)
    .setNDims(convDim)
    .setStrides(convDim, stride.data())
    .setPrePadding(convDim, padding.data())
    .setPostPadding(convDim, padding.data())
    .setDilation(convDim, dilation.data())
    .build();
}

// FIXME: make this thread-safe by reusing the benchmark cache in Conv_v7.cpp
namespace {
struct CacheKey {
  at::native::ConvolutionParams params;
  uint8_t input_alignment;
  uint8_t weight_alignment;
  uint8_t output_alignment;
  // default to -1 when no bias
  int8_t bias_alignment;
  bool kReluFused;
};
std::unordered_map<CacheKey, cudnn_frontend::ExecutionPlan, at::native::ParamsHash<CacheKey>, at::native::ParamsEqual<CacheKey>> execution_plan_cache;
} // anonymous namespace
// TODO: we can use cudnn_frontend::ExecutionPlanCache when it supports caching
// multiple operators
// reference: https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/conv_sample.cpp#L293
//static cudnn_frontend::ExecutionPlanCache plan_cache("sample_cache");

// the parameter quantized_output is a quantized tensor
template <int kSpatialDim>
template <bool kReluFused>
void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& quantized_output, const at::Tensor& input, double output_scale) {
  auto act_scale = input.q_scale();
  auto weight_scale = maybe_padded_weight_.q_scale();
  auto requantize_multiplier = act_scale * weight_scale / output_scale;
  at::Tensor requantize_multiplier_tensor = cudnn_utils::getRequantMultiplierTensor(requantize_multiplier, kSpatialDim + 2);

  std::optional<at::Tensor> bias_multiplier_tensor;
  std::optional<at::Tensor> broadcasted_bias;
  if (bias_.has_value()) {
    // the input bias is a 1-D tensor whose size equals the size of the second (channel) dimension of quantized_output.
    // we need to add trailing singleton dimensions so that the bias broadcasts properly; otherwise broadcast_to will fail.
    // the number of trailing dimensions is quantized_output.dim() - 2, so the reshaped bias has
    // quantized_output.dim() - 2 + 1 dimensions, with the leading (channel) dimension keeping the bias size.
    std::vector<int64_t> new_size(quantized_output.dim() - 1, 1);
    new_size[0] = bias_.value().size(0);
    broadcasted_bias = bias_.value().reshape(new_size);
    broadcasted_bias.value() = broadcasted_bias.value().broadcast_to(quantized_output.sizes());
    broadcasted_bias.value() = broadcasted_bias.value().to(c10::MemoryFormat::ChannelsLast);
    bias_multiplier_tensor = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);
    auto bias_multiplier = 1.0 / (act_scale * weight_scale);
    bias_multiplier_tensor.value().fill_(bias_multiplier);
  }
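  // illustrative example of the reshaping above (2-D conv, so quantized_output is 4-D):
  // a bias of shape [C] is reshaped to [C, 1, 1] and then broadcast to the output shape [N, C, H, W];
  // bias_multiplier_tensor has the same [N, C, H, W] shape and is filled with 1 / (act_scale * weight_scale)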

  cudnnHandle_t handle = at::native::getCudnnHandle();
  CacheKey key{};
  // memset is needed here because there is implicit packing added for CacheKey, and this can result in
  // uninitialized padded values that are used for hashing (see how at::native::ParamsHash is defined).
  // without memset, we can potentially come across a situation where two CacheKey objects have the same
  // user defined parameters, but different padded values, resulting in different hash outputs.
  memset(&key, 0, sizeof(key));
  bool deterministic{true};
  bool allow_tf32{false};
  auto padding_vec = padding_.vec();
  auto stride_vec = stride_.vec();
  auto dilation_vec = dilation_.vec();
  setConvolutionParams(&key.params, input, maybe_padded_weight_, padding_vec, stride_vec, dilation_vec, groups_, deterministic, allow_tf32, input.suggest_memory_format());

  // operator datatype needs to be int32 for int8 convolution, but we can
  // set the datatype for output tensor to int32 or fp32
  key.params.dataType = CUDNN_DATA_INT32;
  key.input_alignment = cudnn_utils::getAlignment(input);
  key.output_alignment = cudnn_utils::getAlignment(quantized_output);
  key.weight_alignment = cudnn_utils::getAlignment(maybe_padded_weight_);
  if (bias_.has_value()) {
    key.bias_alignment = static_cast<int8_t>(cudnn_utils::getAlignment(broadcasted_bias.value()));
  } else {
    key.bias_alignment = -1;
  }
  key.kReluFused = kReluFused;

  auto run = [&](const cudnn_frontend::ExecutionPlan& plan_desc) {
    auto workspace_size = plan_desc.getWorkspaceSize();
    auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
    at::SmallVector<void *, 7> data_ptrs;
    at::SmallVector<int64_t, 7> uids;
    data_ptrs = {input.data_ptr<int8_t>(), maybe_padded_weight_.data_ptr<int8_t>(),
                 requantize_multiplier_tensor.data_ptr(), quantized_output.data_ptr<int8_t>()};
    uids = {'x', 'w', 's', 'r'};
    if (bias_.has_value()) {
      data_ptrs.insert(data_ptrs.end(), {broadcasted_bias.value().data_ptr(), bias_multiplier_tensor.value().data_ptr(),
                                         broadcasted_bias.value().data_ptr()});
      uids.insert(uids.end(), {'b', 'c', 'd'});
    }
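    // uid legend (these match the tensor descriptors built below):
    //   'x' = input activation, 'w' = weight, 's' = requantize multiplier, 'r' = quantized int8 output,
    //   'b' = broadcasted bias, 'c' = bias multiplier, 'd' = scaled bias (written in place into broadcasted_bias)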
    auto variantPack = cudnn_frontend::VariantPackBuilder()
      .setWorkspacePointer(workspace_size ? workspace_ptr.get() : nullptr)
      .setDataPointers(static_cast<int64_t>(uids.size()), data_ptrs.data())
      .setUids(static_cast<int64_t>(uids.size()), uids.data())
      .build();
    auto variant_pack_desc = variantPack.get_raw_desc();
    AT_CUDNN_CHECK(cudnnBackendExecute(handle, plan_desc.get_raw_desc(), variant_pack_desc));
  };

  auto search = execution_plan_cache.find(key);
  if (search != execution_plan_cache.end()) {
    cudnn_frontend::ExecutionPlan plan_desc = search->second;
    run(plan_desc);
    return;
  }
  // conv_op computes act_int8 * w_int8 (convolution),
  // where act_int8 and w_int8 are the input and weight tensors, resp.
  // output is a virtual fp32 tensor
  auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
      .setxDesc(cudnn_utils::getTensorDescriptor(input.sizes(), input.strides(), CUDNN_DATA_INT8, 'x', key.input_alignment))
      // for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
      .setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'y', key.output_alignment, true))
      .setwDesc(cudnn_utils::getTensorDescriptor(maybe_padded_weight_.sizes(), maybe_padded_weight_.strides(), CUDNN_DATA_INT8, 'w', key.weight_alignment))
      .setcDesc(getConvDescriptor(key.params.dataType, padding_vec, stride_vec, dilation_vec))
      .build();
  // std::cout << "operator:" << conv_op.describe() << std::endl;

  std::optional<cudnn_frontend::Operation> bias_mult_op;
  std::optional<cudnn_frontend::Operation> sum_conv_bias_op;
  if (bias_.has_value()) {
    // we can't directly assign bias_mult_op because operator= is deleted for cudnn_frontend::Operation;
    // alternatively, we could use std::unique_ptr and dynamically allocate these builder ops,
    // but here we chose to do it statically. std::optional<T>::emplace() enables this approach

    // bias_mult_op computes bias_fp32 / (act_scale * w_scale), i.e., bias_fp32 * (1 / (act_scale * w_scale)),
    // where bias_multiplier = (1 / (act_scale * w_scale))
    // output is a fp32 tensor
    // we use an in-place operation here where the output is written back into the input
    bias_mult_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
      .setxDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'b', cudnn_utils::getAlignment(broadcasted_bias.value())))
      .setbDesc(cudnn_utils::getTensorDescriptor(bias_multiplier_tensor.value(), 'c', cudnn_utils::getAlignment(bias_multiplier_tensor.value())))
      .setyDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'd', cudnn_utils::getAlignment(broadcasted_bias.value())))
      .setpwDesc(cudnn_utils::getPointWiseMulDescriptor(at::native::getCudnnDataType(bias_multiplier_tensor.value())))
      .build());

    // sum_conv_bias_op computes (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]),
    // where the 1st and 2nd summands are the outputs of conv_op and bias_mult_op, resp.
    // output is a fp32 tensor
    sum_conv_bias_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
      .setxDesc(conv_op.getOutputTensor())
      .setbDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'd', cudnn_utils::getAlignment(broadcasted_bias.value())))
      // for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
      .setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'e', key.output_alignment, true))
      .setpwDesc(cudnn_utils::getPointWiseAddDescriptor(at::native::getCudnnDataType(broadcasted_bias.value())))
      .build());
  }

  // relu_op computes relu(act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)])
  // or relu(act_int8 * w_int8) if bias is not present.
  // output is a fp32 tensor
  std::optional<cudnn_frontend::Operation> relu_op;
  std::shared_ptr<cudnn_frontend::OpaqueBackendPointer> tensor2requant_ptr = bias_.has_value() ? sum_conv_bias_op.value().getOutputTensor() : conv_op.getOutputTensor();
  if (kReluFused) {
    relu_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
      .setxDesc(tensor2requant_ptr)
      // for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
      .setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'f', key.output_alignment, true))
      .setpwDesc(cudnn_utils::getPointWiseReluDescriptor(CUDNN_DATA_FLOAT))
      .build());
  }

  // requant_op computes relu(act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) / (out_scale / (act_scale * w_scale)),
  // where the relu is applied only when kReluFused and the bias term is included only when bias is present.
  // output is an int8 tensor
  auto requant_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
    .setxDesc(kReluFused ? relu_op.value().getOutputTensor() : tensor2requant_ptr)
    .setbDesc(cudnn_utils::getTensorDescriptor(requantize_multiplier_tensor, 's', cudnn_utils::getAlignment(requantize_multiplier_tensor)))
    .setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_INT8, 'r', key.output_alignment))
    .setpwDesc(cudnn_utils::getPointWiseMulDescriptor(at::native::getCudnnDataType(requantize_multiplier_tensor)))
    .build();
  // std::cout << "operator:" << requant_op.describe() << std::endl;

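  // summary of the fused graph assembled below:
  //   conv_op ('x' * 'w' -> virtual fp32)
  //   [+ bias_mult_op ('b' * 'c' -> 'd') and sum_conv_bias_op (conv output + 'd' -> virtual fp32), if bias is present]
  //   [+ relu_op (virtual fp32 -> virtual fp32), if kReluFused]
  //   + requant_op (multiply by 's' -> int8 output 'r')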
  std::vector<cudnn_frontend::Operation const *> ops{&conv_op};
  if (bias_.has_value()) {
    ops.emplace_back(&(bias_mult_op.value()));
    ops.emplace_back(&(sum_conv_bias_op.value()));
  }
  if (kReluFused) {
    ops.emplace_back(&(relu_op.value()));
  }
  ops.emplace_back(&requant_op);

  auto opGraph = cudnn_frontend::OperationGraphBuilder()
      .setHandle(handle)
      .setOperationGraph(static_cast<int64_t>(ops.size()), ops.data())
      .build();
  // std::cout << "opGraph: " << opGraph.describe() << std::endl;

  auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
      .setOperationGraph(opGraph)
      .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
      .build();
  auto fallback = cudnn_frontend::EngineFallbackListBuilder()
                    .setOperationGraph(opGraph)
                    .setOperation(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
                    .build();

  auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
  auto& fallback_list = fallback.getFallbackList();

  cudnn_frontend::EngineConfigList filtered_configs;
  cudnn_utils::filterEngineConfigs(engine_configs, filtered_configs, deterministic, allow_tf32, at::kChar);
  cudnn_utils::filterEngineConfigs(fallback_list, filtered_configs, deterministic, allow_tf32, at::kChar);

  for (auto &cfg : engine_configs) {
    try {
      auto plan = cudnn_frontend::ExecutionPlanBuilder()
        .setHandle(handle)
        .setEngineConfig(cfg)
        .build();
      run(plan);
      execution_plan_cache.emplace(key, plan);
      return;
    } catch (cudnn_frontend::cudnnException& e) {
      std::cout << "cudnn error: " << e.what() << '\n';
    } catch (c10::CuDNNError& e) {
      std::cout << "other error: " << e.what() << '\n';
    }
  }

  TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn");
}

//
// output Tensor will be a clamped int8 Tensor
// both act and weight will be int8 Tensor
/*
Numerics:
out_fp32 = conv_fp32(act_fp32, w_fp32, …)
         = act_fp32 * w_fp32 + bias_fp32
act_int8 = act_fp32 / act_scale + act_zero_point
w_int8 = w_fp32 / w_scale + w_zero_point
out_int8 = out_fp32 / out_scale + out_zero_point
out_int8 = (act_fp32 * w_fp32 + [bias_fp32]) / out_scale + out_zero_point
         = (act_int8 - act_zero_point) * act_scale * (w_int8 - w_zero_point) * w_scale / out_scale + out_zero_point + [bias_fp32 / out_scale]
         = (act_int8 * w_int8 - act_int8 * w_zero_point - act_zero_point * w_int8 + act_zero_point * w_zero_point) * act_scale * w_scale / out_scale + out_zero_point + [bias_fp32 / out_scale]
         = (if both act and weight are symmetrically quantized int8, then act_zero_point = w_zero_point = 0; we also assume out_zero_point = 0)
         = (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) * act_scale * w_scale / out_scale
         = (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) / (out_scale / (act_scale * w_scale))
         = requantize((act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]), out_scale / (act_scale * w_scale))
*/
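/*
Illustrative example with made-up scales (not taken from any particular model):
  act_scale = 0.5, w_scale = 0.25, out_scale = 2.0
  requantize_multiplier = act_scale * w_scale / out_scale = 0.0625
  bias_multiplier = 1 / (act_scale * w_scale) = 8
  suppose the int8 accumulation act_int8 * w_int8 for one output element is 100 and bias_fp32 = 3.0:
    scaled bias = 3.0 * 8 = 24
    fused sum   = 100 + 24 = 124
    requantized = 124 * 0.0625 = 7.75, which is about 8 once converted to int8
*/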
template <int kSpatialDim>
template <bool kReluFused>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply_impl(
    const at::Tensor& act,
    double output_scale,
    int64_t output_zero_point) {
  const auto batch_size = kSpatialDim == 2 ? act.size(0) : 1;
  const auto num_input_channels = act.size(kSpatialDim - 1);
  const auto H = act.size(kSpatialDim);
  const auto W = act.size(kSpatialDim + 1);
  const auto num_output_channels = maybe_padded_weight_.size(0); // output channels
  std::vector<int64_t> kernel_size = {maybe_padded_weight_.size(2), maybe_padded_weight_.size(3)};
  auto output_shape = at::native::quantized::MakeConvOutputShape<kSpatialDim>(batch_size, num_output_channels, {H, W},
                                                                              kernel_size, stride_, padding_, dilation_);
  at::Tensor quantized_output = at::_empty_affine_quantized(
      output_shape,
      at::device(at::kCUDA).dtype(at::ScalarType::QInt8),
      output_scale,
      output_zero_point,
      at::MemoryFormat::ChannelsLast);

  // cudnn v8.4.0 expects conv2d's int8 activation tensor's number of input channels to be a multiple of 4. if it is not,
  // we need to explicitly pad it to a multiple of 4 ourselves, as cudnn does not currently do this padding for us.
  // TODO: when and if cudnn enables this padding in their operators, we can remove the padding on our end;
  // currently, limit padding support to groups=1 (ungrouped conv)
  // TODO: implement this for groups > 1; should be straightforward since we're only padding a single dimension
  auto act_maybe_padded = act;
  if (num_input_channels % 4 != 0) {
    int8_t num_slices = 4 - num_input_channels % 4; // number of slices we need to pad
    act_maybe_padded = at::pad(act, {0, 0, 0, 0, 0, num_slices, 0, 0}, "constant", 0);
  }
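  // e.g., an activation of shape [N, 3, H, W] gives num_slices = 1; the pad list pairs apply from the last
  // dimension backward, so {0,0, 0,0, 0,num_slices, 0,0} appends one zero-filled slice along the channel dim,
  // yielding [N, 4, H, W] (the weight side is handled at prepack time, hence maybe_padded_weight_)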
  apply_impl_helper<kReluFused>(
      quantized_output, act_maybe_padded.to(c10::MemoryFormat::ChannelsLast), output_scale);

  // need to return sliced tensor if output_channels was padded
  if (num_unpadded_output_channels_ != maybe_padded_weight_.size(0)) {
    return quantized_output.slice(1, 0, num_unpadded_output_channels_);
  }
  return quantized_output;
}

template <int kSpatialDim>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  return apply_impl<false>(input, output_scale, output_zero_point);
}

template <int kSpatialDim>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply_relu(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  return apply_impl<true>(input, output_scale, output_zero_point);
}

template at::Tensor PackedConvWeightCudnn<2>::apply(
    const at::Tensor& act,
    double output_scale,
    int64_t output_zero_point);

template at::Tensor PackedConvWeightCudnn<2>::apply_relu(
    const at::Tensor& act,
    double output_scale,
    int64_t output_zero_point);

namespace at::native {
namespace {

template <bool kReluFused>
class QConv1dInt8 final {
 public:
  static Tensor run(
      Tensor act,
      const c10::intrusive_ptr<ConvPackedParamsBase<2>>& packed_weight,
      double output_scale,
      int64_t output_zero_point) {
    at::Tensor output;
    // we currently use the conv2d kernel for conv1d by making the input and weight tensors
    // 4D rather than 3D. we add a dummy width dimension of size 1
    // N, C, L -> N, C, 1, L
    act = act.unsqueeze(-2);
    if (kReluFused) {
      output = packed_weight->apply_relu(act, output_scale, output_zero_point);
    } else {
      output = packed_weight->apply(act, output_scale, output_zero_point);
    }
    // N, C, 1, L -> N, C, L
    return output.squeeze_(-2);
  }
};

template <int kSpatialDim, bool kReluFused>
class QConvInt8 final {
 public:
  static at::Tensor run(
      at::Tensor act,
      const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& packed_weight,
      double output_scale,
      int64_t output_zero_point) {
    TORCH_CHECK(kSpatialDim == 1 || kSpatialDim == 2, "Error in quantized cudnn conv2d operator: "
                "Expected kSpatialDim == 1 || kSpatialDim == 2; received kSpatialDim=", kSpatialDim);
    // TODO: check all zero_points are zero/all tensors are symmetrically quantized
    if (kReluFused) {
      return packed_weight->apply_relu(act, output_scale, output_zero_point);
    } else {
      return packed_weight->apply(act, output_scale, output_zero_point);
    }
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
  // the cpu conv1d doesn't use the quantized::conv1d*.new variant for packed weights; instead it just uses
  // quantized::conv1d for packed weights (see quantized/library.cpp).
  // this is inconsistent with what has been done for conv2d, where the new variants use packed weights
  // and the old variant does not. we adopt this inconsistency for now to stay consistent with
  // QuantizedCPU's conv1d, and will eventually deprecate the old variants
  register_conv_params<2>();
  register_conv_params<3>();
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d"), QConv1dInt8<false>::run);
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_relu"), QConv1dInt8<true>::run);
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"), QConvInt8<2, false>::run);
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu.new"), QConvInt8<2, true>::run);
}
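// rough usage sketch from Python (hedged; assumes the usual prepack flow via quantized::conv2d_prepack,
// which is defined elsewhere and not in this file):
//   packed = torch.ops.quantized.conv2d_prepack(q_weight, bias, stride, padding, dilation, groups)
//   q_out  = torch.ops.quantized.conv2d.new(q_act_cuda, packed, output_scale, output_zero_point)
// where q_act_cuda and q_weight are symmetrically quantized int8 (qint8, zero_point 0) CUDA tensors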

} // anonymous namespace
} // namespace at::native


#endif  // AT_CUDNN_ENABLED
#endif  // USE_CUDA