#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <c10/util/ArrayRef.h>

#include <ATen/ATen.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/native/quantized/cudnn/utils.h>
#include <ATen/native/quantized/ConvUtils.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/utils/ParamsHash.h>
#include <ATen/TensorUtils.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cudnn_frontend.h>
#include <torch/library.h>

#include <iostream>
#include <unordered_map>
#include <vector>

template <int kSpatialDim = 2>
int register_conv_params();

extern template int register_conv_params<2>();
extern template int register_conv_params<3>();

// TODO: there is a table mapping input dtype and weight dtype to the operator qdtype;
// we can derive the operator dtype based on the input dtype
cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, c10::IntArrayRef padding, c10::IntArrayRef stride, c10::IntArrayRef dilation) {
  int64_t convDim = static_cast<int64_t>(stride.size());
  return cudnn_frontend::ConvDescBuilder()
    .setDataType(dataType)
    .setMathMode(CUDNN_CROSS_CORRELATION)
    .setNDims(convDim)
    .setStrides(convDim, stride.data())
    .setPrePadding(convDim, padding.data())
    .setPostPadding(convDim, padding.data())
    .setDilation(convDim, dilation.data())
    .build();
}

// FIXME: make this thread-safe by reusing the benchmark cache in Conv_v7.cpp
namespace {
struct CacheKey {
  at::native::ConvolutionParams params;
  uint8_t input_alignment;
  uint8_t weight_alignment;
  uint8_t output_alignment;
  // default to -1 when no bias
  int8_t bias_alignment;
  bool kReluFused;
};
std::unordered_map<CacheKey, cudnn_frontend::ExecutionPlan, at::native::ParamsHash<CacheKey>, at::native::ParamsEqual<CacheKey>> execution_plan_cache;
} // anonymous namespace
// TODO: we can use cudnn_frontend::ExecutionPlanCache when it supports caching
// multiple operators
// reference: https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/conv_sample.cpp#L293
// static cudnn_frontend::ExecutionPlanCache plan_cache("sample_cache");

// the parameter quantized_output is a quantized tensor
template <int kSpatialDim>
template <bool kReluFused>
void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& quantized_output, const at::Tensor& input, double output_scale) {
  auto act_scale = input.q_scale();
  auto weight_scale = maybe_padded_weight_.q_scale();
  auto requantize_multiplier = act_scale * weight_scale / output_scale;
  at::Tensor requantize_multiplier_tensor = cudnn_utils::getRequantMultiplierTensor(requantize_multiplier, kSpatialDim + 2);

  std::optional<at::Tensor> bias_multiplier_tensor;
  std::optional<at::Tensor> broadcasted_bias;
  if (bias_.has_value()) {
    // the input bias is a 1-D tensor whose size is the same as the size of the second dimension of quantized_output.
    // we need to add trailing dimensions in order to properly broadcast bias, otherwise broadcast_to will fail.
    // the number of trailing dimensions is quantized_output.dim() - 2, so the new size of broadcasted_bias
    // becomes quantized_output.dim() - 2 + 1. nothing needs to be done for the leading dimensions
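    // e.g., for a 4-D output of shape [N, C, H, W], a bias of shape [C] is reshaped to
    // [C, 1, 1] below and then broadcast (and converted to channels-last) to match [N, C, H, W]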
    std::vector<int64_t> new_size(quantized_output.dim() - 1, 1);
    new_size[0] = bias_.value().size(0);
    broadcasted_bias = bias_.value().reshape(new_size);
    broadcasted_bias.value() = broadcasted_bias.value().broadcast_to(quantized_output.sizes());
    broadcasted_bias.value() = broadcasted_bias.value().to(c10::MemoryFormat::ChannelsLast);
    bias_multiplier_tensor = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);
    auto bias_multiplier = 1.0 / (act_scale * weight_scale);
    bias_multiplier_tensor.value().fill_(bias_multiplier);
  }

  cudnnHandle_t handle = at::native::getCudnnHandle();
  CacheKey key{};
  // memset is needed here because there is implicit packing added for CacheKey, and this can result in
  // uninitialized padded values that are used for hashing (see how at::native::ParamsHash is defined).
  // without memset, two CacheKey objects with the same user-defined parameters could have different
  // padded values, resulting in different hash outputs.
  memset(&key, 0, sizeof(key));
  bool deterministic{true};
  bool allow_tf32{false};
  auto padding_vec = padding_.vec();
  auto stride_vec = stride_.vec();
  auto dilation_vec = dilation_.vec();
  setConvolutionParams(&key.params, input, maybe_padded_weight_, padding_vec, stride_vec, dilation_vec, groups_, deterministic, allow_tf32, input.suggest_memory_format());

  // operator datatype needs to be int32 for int8 convolution, but we can
  // set the datatype for output tensor to int32 or fp32
  key.params.dataType = CUDNN_DATA_INT32;
  key.input_alignment = cudnn_utils::getAlignment(input);
  key.output_alignment = cudnn_utils::getAlignment(quantized_output);
  key.weight_alignment = cudnn_utils::getAlignment(maybe_padded_weight_);
  if (bias_.has_value()) {
    key.bias_alignment = static_cast<int8_t>(cudnn_utils::getAlignment(broadcasted_bias.value()));
  } else {
    key.bias_alignment = -1;
  }
  key.kReluFused = kReluFused;

  auto run = [&](const cudnn_frontend::ExecutionPlan& plan_desc) {
    auto workspace_size = plan_desc.getWorkspaceSize();
    auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
    at::SmallVector<void *, 7> data_ptrs;
    at::SmallVector<int64_t, 7> uids;
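    // uid mapping used below and in the tensor descriptors: 'x' = input, 'w' = weight,
    // 's' = requantize multiplier, 'r' = quantized output; when bias is present,
    // 'b' = broadcasted bias (input), 'c' = bias multiplier, 'd' = scaled bias (written in place over 'b')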
    data_ptrs = {input.data_ptr<int8_t>(), maybe_padded_weight_.data_ptr<int8_t>(),
                 requantize_multiplier_tensor.data_ptr(), quantized_output.data_ptr<int8_t>()};
    uids = {'x', 'w', 's', 'r'};
    if (bias_.has_value()) {
      data_ptrs.insert(data_ptrs.end(), {broadcasted_bias.value().data_ptr(), bias_multiplier_tensor.value().data_ptr(),
                                         broadcasted_bias.value().data_ptr()});
      uids.insert(uids.end(), {'b', 'c', 'd'});
    }
    auto variantPack = cudnn_frontend::VariantPackBuilder()
      .setWorkspacePointer(workspace_size ? workspace_ptr.get() : nullptr)
      .setDataPointers(static_cast<int64_t>(uids.size()), data_ptrs.data())
      .setUids(static_cast<int64_t>(uids.size()), uids.data())
      .build();
    auto variant_pack_desc = variantPack.get_raw_desc();
    AT_CUDNN_CHECK(cudnnBackendExecute(handle, plan_desc.get_raw_desc(), variant_pack_desc));
  };

  auto search = execution_plan_cache.find(key);
  if (search != execution_plan_cache.end()) {
    cudnn_frontend::ExecutionPlan plan_desc = search->second;
    run(plan_desc);
    return;
  }
  // conv_op computes act_fp32 * w_fp32 (matrix multiplication)
  // where act_fp32 and w_fp32 are the input and weight variables, resp.
  // output is a fp32 tensor
  auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
    .setxDesc(cudnn_utils::getTensorDescriptor(input.sizes(), input.strides(), CUDNN_DATA_INT8, 'x', key.input_alignment))
    // for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
    .setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'y', key.output_alignment, true))
    .setwDesc(cudnn_utils::getTensorDescriptor(maybe_padded_weight_.sizes(), maybe_padded_weight_.strides(), CUDNN_DATA_INT8, 'w', key.weight_alignment))
    .setcDesc(getConvDescriptor(key.params.dataType, padding_vec, stride_vec, dilation_vec))
    .build();
  // std::cout << "operator:" << conv_op.describe() << std::endl;

  std::optional<cudnn_frontend::Operation> bias_mult_op;
  std::optional<cudnn_frontend::Operation> sum_conv_bias_op;
  if (bias_.has_value()) {
    // we can't directly assign bias_mult_op because operator= is deleted for cudnn_frontend::Operation;
    // alternatively, we could use std::unique_ptr and dynamically allocate these builder ops,
    // but here we chose to do it statically. std::optional<T>::emplace() enables this approach

    // bias_mult_op computes bias_fp32 / (act_scale * w_scale), i.e. bias_fp32 * (1 / (act_scale * w_scale)),
    // where bias_multiplier = (1 / (act_scale * w_scale))
    // output is a fp32 tensor
    // we use an inplace operation here where the output is assigned to the input
    bias_mult_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
      .setxDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'b', cudnn_utils::getAlignment(broadcasted_bias.value())))
      .setbDesc(cudnn_utils::getTensorDescriptor(bias_multiplier_tensor.value(), 'c', cudnn_utils::getAlignment(bias_multiplier_tensor.value())))
      .setyDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'd', cudnn_utils::getAlignment(broadcasted_bias.value())))
      .setpwDesc(cudnn_utils::getPointWiseMulDescriptor(at::native::getCudnnDataType(bias_multiplier_tensor.value())))
      .build());

    // sum_conv_bias_op computes (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)])
    // where the 1st and 2nd summands are the outputs of conv_op and bias_mult_op (the scaled broadcasted_bias), resp.
    // output is a fp32 virtual (intermediate) tensor
    sum_conv_bias_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
      .setxDesc(conv_op.getOutputTensor())
      .setbDesc(cudnn_utils::getTensorDescriptor(broadcasted_bias.value(), 'd', cudnn_utils::getAlignment(broadcasted_bias.value())))
      // for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
      .setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'e', key.output_alignment, true))
      .setpwDesc(cudnn_utils::getPointWiseAddDescriptor(at::native::getCudnnDataType(broadcasted_bias.value())))
      .build());
  }

  // relu_op computes relu(act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)])
  // or relu(act_int8 * w_int8) if bias is not present.
  // output is a fp32 virtual (intermediate) tensor
  std::optional<cudnn_frontend::Operation> relu_op;
  std::shared_ptr<cudnn_frontend::OpaqueBackendPointer> tensor2requant_ptr = bias_.has_value() ? sum_conv_bias_op.value().getOutputTensor() : conv_op.getOutputTensor();
  if (kReluFused) {
    relu_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
      .setxDesc(tensor2requant_ptr)
      // for virtual tensors, the alignment is not used, so we can just put an arbitrary value here, e.g., key.output_alignment
      .setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_FLOAT, 'f', key.output_alignment, true))
      .setpwDesc(cudnn_utils::getPointWiseReluDescriptor(CUDNN_DATA_FLOAT))
      .build());
  }

  // requant_op computes relu(act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) / (out_scale / (act_scale * w_scale))
  // or (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) / (out_scale / (act_scale * w_scale)) if relu is not fused.
  // output is an int8 tensor
  auto requant_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
    .setxDesc(kReluFused ? relu_op.value().getOutputTensor() : tensor2requant_ptr)
    .setbDesc(cudnn_utils::getTensorDescriptor(requantize_multiplier_tensor, 's', cudnn_utils::getAlignment(requantize_multiplier_tensor)))
    .setyDesc(cudnn_utils::getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_INT8, 'r', key.output_alignment))
    .setpwDesc(cudnn_utils::getPointWiseMulDescriptor(at::native::getCudnnDataType(requantize_multiplier_tensor)))
    .build();
  // std::cout << "operator:" << requant_op.describe() << std::endl;

  std::vector<cudnn_frontend::Operation const *> ops{&conv_op};
  if (bias_.has_value()) {
    ops.emplace_back(&(bias_mult_op.value()));
    ops.emplace_back(&(sum_conv_bias_op.value()));
  }
  if (kReluFused) {
    ops.emplace_back(&(relu_op.value()));
  }
  ops.emplace_back(&requant_op);

  auto opGraph = cudnn_frontend::OperationGraphBuilder()
    .setHandle(handle)
    .setOperationGraph(static_cast<int64_t>(ops.size()), ops.data())
    .build();
  // std::cout << "opGraph: " << opGraph.describe() << std::endl;

  auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
    .setOperationGraph(opGraph)
    .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
    .build();
  auto fallback = cudnn_frontend::EngineFallbackListBuilder()
    .setOperationGraph(opGraph)
    .setOperation(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
    .build();

  auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
  auto& fallback_list = fallback.getFallbackList();

  cudnn_frontend::EngineConfigList filtered_configs;
  cudnn_utils::filterEngineConfigs(engine_configs, filtered_configs, deterministic, allow_tf32, at::kChar);
  cudnn_utils::filterEngineConfigs(fallback_list, filtered_configs, deterministic, allow_tf32, at::kChar);

  // try the filtered configs (heuristics first, then fallbacks) until one builds and runs successfully
  for (auto &cfg : filtered_configs) {
    try {
      auto plan = cudnn_frontend::ExecutionPlanBuilder()
        .setHandle(handle)
        .setEngineConfig(cfg)
        .build();
      run(plan);
      execution_plan_cache.emplace(key, plan);
      return;
    } catch (cudnn_frontend::cudnnException &e) {
      std::cout << "cudnn error: " << e.what() << '\n';
    } catch (c10::CuDNNError &e) {
      std::cout << "other error: " << e.what() << '\n';
    }
  }

  TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn");
}

// output Tensor will be a clamped int8 Tensor
// both act and weight will be int8 Tensors
/*
Numerics:
out_fp32 = conv_fp32(act_fp32, w_fp32, …)
         = act_fp32 * w_fp32 + bias_fp32
act_int8 = act_fp32 / act_scale + act_zero_point
w_int8 = w_fp32 / w_scale + w_zero_point
out_int8 = out_fp32 / out_scale + out_zero_point
out_int8 = (act_fp32 * w_fp32 + [bias_fp32]) / out_scale + out_zero_point
         = (act_int8 - act_zero_point) * act_scale * (w_int8 - w_zero_point) * w_scale / out_scale + out_zero_point + [bias_fp32 / out_scale]
         = (act_int8 * w_int8 - act_int8 * w_zero_point - act_zero_point * w_int8 + act_zero_point * w_zero_point) * act_scale * w_scale / out_scale + out_zero_point + [bias_fp32 / out_scale]
(if both act and weight are symmetrically quantized int8, then act_zero_point = w_zero_point = 0;
 the out_zero_point term is likewise dropped below assuming symmetric output quantization)
         = (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) * act_scale * w_scale / out_scale
         = (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) / (out_scale / (act_scale * w_scale))
         = requantize((act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]), out_scale / (act_scale * w_scale))
*/
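// Example (illustrative values): with act_scale = 0.1, w_scale = 0.05, and out_scale = 0.25,
// requantize_multiplier = act_scale * w_scale / out_scale = 0.02, i.e. the int32 accumulator
// (plus the scaled bias, if any) is multiplied by 0.02 and then converted to int8.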
template <int kSpatialDim>
template <bool kReluFused>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply_impl(
    const at::Tensor& act,
    double output_scale,
    int64_t output_zero_point) {
  const auto batch_size = kSpatialDim == 2 ? act.size(0) : 1;
  const auto num_input_channels = act.size(kSpatialDim - 1);
  const auto H = act.size(kSpatialDim);
  const auto W = act.size(kSpatialDim + 1);
  const auto num_output_channels = maybe_padded_weight_.size(0); // output channels
  std::vector<int64_t> kernel_size = {maybe_padded_weight_.size(2), maybe_padded_weight_.size(3)};
  auto output_shape = at::native::quantized::MakeConvOutputShape<kSpatialDim>(batch_size, num_output_channels, {H, W},
                                                                              kernel_size, stride_, padding_, dilation_);
  at::Tensor quantized_output = at::_empty_affine_quantized(
      output_shape,
      at::device(at::kCUDA).dtype(at::ScalarType::QInt8),
      output_scale,
      output_zero_point,
      at::MemoryFormat::ChannelsLast);

  // cudnn v8.4.0 expects conv2d's int8 activation tensor's input channels to be a multiple of 4. if it is not,
  // we need to explicitly pad it to a multiple of 4 ourselves, as cudnn does not currently support padding.
  // TODO: when and if cudnn enables padding in their operators, we can remove padding on our end;
  // currently, we limit padding support to groups=1 (ungrouped conv)
  // TODO: implement this for groups > 1; should be straightforward since we're only padding a single dimension
  auto act_maybe_padded = act;
  if (num_input_channels % 4 != 0) {
    int8_t num_slices = 4 - num_input_channels % 4; // number of slices we need to pad
    act_maybe_padded = at::pad(act, {0, 0, 0, 0, 0, num_slices, 0, 0}, "constant", 0);
  }
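  // e.g., an activation with 3 input channels gets one zero-filled channel appended (num_slices = 1)
  // so that the padded tensor has 4 input channels; the pad spec above pads only the channel dimension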
  apply_impl_helper<kReluFused>(
      quantized_output, act_maybe_padded.to(c10::MemoryFormat::ChannelsLast), output_scale);

  // need to return sliced tensor if output_channels was padded
  if (num_unpadded_output_channels_ != maybe_padded_weight_.size(0)) {
    return quantized_output.slice(1, 0, num_unpadded_output_channels_);
  }
  return quantized_output;
}

template <int kSpatialDim>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  return apply_impl<false>(input, output_scale, output_zero_point);
}

template <int kSpatialDim>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply_relu(
    const at::Tensor& input,
    double output_scale,
    int64_t output_zero_point) {
  return apply_impl<true>(input, output_scale, output_zero_point);
}

template at::Tensor PackedConvWeightCudnn<2>::apply(
    const at::Tensor& act,
    double output_scale,
    int64_t output_zero_point);

template at::Tensor PackedConvWeightCudnn<2>::apply_relu(
    const at::Tensor& act,
    double output_scale,
    int64_t output_zero_point);

namespace at::native {
namespace {

template <bool kReluFused>
class QConv1dInt8 final {
 public:
  static Tensor run(
      Tensor act,
      const c10::intrusive_ptr<ConvPackedParamsBase<2>>& packed_weight,
      double output_scale,
      int64_t output_zero_point) {
    at::Tensor output;
    // we currently use the conv2d kernel for conv1d by making the input and weight tensors
    // 4D rather than 3D. we add a dummy width dimension of size 1
    // N, C, L -> N, C, 1, L
    act = act.unsqueeze(-2);
    if (kReluFused) {
      output = packed_weight->apply_relu(act, output_scale, output_zero_point);
    } else {
      output = packed_weight->apply(act, output_scale, output_zero_point);
    }
    // N, C, 1, L -> N, C, L
    return output.squeeze_(-2);
  }
};

template <int kSpatialDim, bool kReluFused>
class QConvInt8 final {
 public:
  static at::Tensor run(
      at::Tensor act,
      const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& packed_weight,
      double output_scale,
      int64_t output_zero_point) {
    TORCH_CHECK(kSpatialDim == 1 || kSpatialDim == 2, "Error in quantized cudnn conv2d operator: "
                "Expected kSpatialDim == 1 || kSpatialDim == 2; received kSpatialDim=", kSpatialDim);
    // TODO: check that all zero_points are zero / all tensors are symmetrically quantized
    if (kReluFused) {
      return packed_weight->apply_relu(act, output_scale, output_zero_point);
    } else {
      return packed_weight->apply(act, output_scale, output_zero_point);
    }
  }
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
  // the cpu conv1d doesn't use the quantized::conv1d*.new variant for packed weights; instead it just uses
  // quantized::conv1d for packed weights (see quantized/library.cpp).
  // this is inconsistent with what has been done for conv2d, where the new variants use packed weights and
  // the old variant does not. we adopt this inconsistency for now to be consistent with QuantizedCPU's conv1d
  // and will eventually deprecate the old variants
  register_conv_params<2>();
  register_conv_params<3>();
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d"), QConv1dInt8<false>::run);
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_relu"), QConv1dInt8<true>::run);
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"), QConvInt8<2, false>::run);
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu.new"), QConvInt8<2, true>::run);
}

} // anonymous namespace
} // namespace at::native

#endif // AT_CUDNN_ENABLED
#endif // USE_CUDA