#include <torch/csrc/jit/tensorexpr/external_functions.h>

#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/OpContext.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/BinaryOps.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/conv_serialization.h>
#include <ATen/native/xnnpack/OpContext.h>
#include <ATen/quantized/QTensorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
#include <utility>

namespace torch::jit::tensorexpr {

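// Heuristic layout detection for the raw buffers NNC hands us: a rank-4
// buffer whose channel stride is 1 and whose width stride equals the number
// of channels is treated as channels-last (NHWC); everything else falls back
// to Contiguous.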
static c10::MemoryFormat deduce_memory_format(
    c10::IntArrayRef strides,
    c10::IntArrayRef dims) {
  if (strides.size() == 4 && strides[3] == dims[1] && strides[1] == 1l) {
    return c10::MemoryFormat::ChannelsLast;
  }
  return c10::MemoryFormat::Contiguous;
}

static c10::MemoryFormat deduce_memory_format(
    const std::vector<int64_t>& strides,
    const std::vector<int64_t>& dims) {
  return deduce_memory_format(
      c10::IntArrayRef(strides), c10::IntArrayRef(dims));
}

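// Wraps caller-owned memory in an affine-quantized CPU tensor. The no-op
// deleter passed to ShareExternalPointer means the result aliases `data`
// without taking ownership, so the underlying buffer must outlive the tensor.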
static at::Tensor from_blob_quantized(
    void* data,
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    double qscale,
    int64_t qzero,
    at::ScalarType dtype) {
  auto memory_format = deduce_memory_format(strides, sizes);
  auto qx = at::_empty_affine_quantized(
      sizes,
      dtype,
      c10::kStrided,
      at::kCPU,
      false,
      qscale,
      qzero,
      memory_format);
  auto qtensor_impl = static_cast<at::QTensorImpl*>(qx.unsafeGetTensorImpl());
  auto typeMeta = c10::scalarTypeToTypeMeta(dtype);
  std::size_t size = 1;
  for (std::int64_t s : sizes) {
    size *= static_cast<std::size_t>(s);
  }
  qtensor_impl->ShareExternalPointer(
      c10::InefficientStdFunctionContext::makeDataPtr(
          data, [](void*) {}, at::kCPU),
      typeMeta,
      size * typeMeta.itemsize());
  qtensor_impl->set_sizes_and_strides(sizes, strides);
  return qx;
}

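// Builds at::Tensors that alias the raw NNC buffers described by the packed
// (rank, dims, strides, dtype) arrays. Buffers listed in `qdataArg` are
// wrapped as quantized tensors with the given scale/zero-point/dtype; all
// other buffers become plain strided CPU tensors via at::from_blob.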
std::vector<at::Tensor> constructTensors(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    std::optional<std::vector<std::pair<size_t, QIData>>> qdataArg) {
  std::vector<void*> buf_data_vec;
  std::vector<std::vector<int64_t>> buf_dims_vec;
  std::vector<std::vector<int64_t>> buf_strides_vec;
  std::vector<c10::ScalarType> buf_dtypes_vec;
  int64_t buf_dims_idx = 0;
  int64_t buf_strides_idx = 0;
  for (const auto i : c10::irange(bufs_num)) {
    buf_data_vec.push_back(buf_data[i]);
    buf_dims_vec.emplace_back();
    buf_strides_vec.emplace_back();
    for (const auto dim : c10::irange(buf_ranks[i])) {
      (void)dim;
      buf_dims_vec[i].push_back(buf_dims[buf_dims_idx++]);
      buf_strides_vec[i].push_back(buf_strides[buf_strides_idx++]);
    }
    buf_dtypes_vec.push_back(static_cast<c10::ScalarType>(buf_dtypes[i]));
  }

  std::vector<at::Tensor> tensors;
  if (!qdataArg.has_value()) {
    for (const auto i : c10::irange(buf_data_vec.size())) {
      auto options = at::TensorOptions()
                         // NOLINTNEXTLINE
                         .dtype(buf_dtypes_vec[i])
                         .layout(at::kStrided)
                         .device(at::kCPU) // TODO: support GPUs too
                         .memory_format(deduce_memory_format(
                             // NOLINTNEXTLINE
                             buf_strides_vec[i],
                             // NOLINTNEXTLINE
                             buf_dims_vec[i]))
                         .requires_grad(false);
      auto tensor = at::from_blob(
          // NOLINTNEXTLINE
          buf_data_vec[i],
          buf_dims_vec[i],
          buf_strides_vec[i],
          options);
      tensors.emplace_back(tensor);
    }
  } else {
    // handle quantized
    std::vector<std::optional<QIData>> qdata(bufs_num, std::nullopt);
    for (const auto& qd : *qdataArg) {
      qdata[qd.first] = qd.second;
    }
    for (const auto i : c10::irange(buf_data_vec.size())) {
      auto options = at::TensorOptions()
                         // NOLINTNEXTLINE
                         .dtype(buf_dtypes_vec[i])
                         .layout(at::kStrided)
                         .device(at::kCPU) // TODO: support GPUs too
                         .memory_format(deduce_memory_format(
                             // NOLINTNEXTLINE
                             buf_strides_vec[i],
                             // NOLINTNEXTLINE
                             buf_dims_vec[i]))
                         .requires_grad(false);
      if (auto qd = qdata[i]) {
        // quantized input: wrap the caller's buffer in place (no copy)
        auto tensor = from_blob_quantized(
            // NOLINTNEXTLINE
            buf_data_vec[i],
            buf_dims_vec[i],
            buf_strides_vec[i],
            qd->scale,
            qd->zero,
            qd->scalarType);
        tensors.emplace_back(tensor);
      } else {
        auto tensor = at::from_blob(
            // NOLINTNEXTLINE
            buf_data_vec[i],
            buf_dims_vec[i],
            buf_strides_vec[i],
            options);
        tensors.emplace_back(tensor);
      }
    }
  }
  return tensors;
}

static std::vector<at::Tensor> constructTensors(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    std::vector<std::pair<size_t, QIData>> qdata) {
  std::optional<std::vector<std::pair<size_t, QIData>>> opt = std::move(qdata);
  return constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes, opt);
}

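// Variant used by the *_out external functions below: buf_data holds
// `bufs_out_num` output slots first, followed by `bufs_in_num` inputs, while
// the rank/dims/strides/dtype arrays describe only the inputs. Entries in
// `qdataArg` are keyed by buf_data index (i.e. already offset by
// bufs_out_num). Output slots are filled with undefined tensors here; the
// caller of this helper allocates the real outputs and publishes them back
// through buf_data.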
std::vector<at::Tensor> constructTensors2(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    std::optional<std::vector<std::pair<size_t, QIData>>> qdataArg,
    size_t bufs_out_num) {
  std::vector<void*> buf_data_vec;
  std::vector<std::vector<int64_t>> buf_dims_vec;
  std::vector<std::vector<int64_t>> buf_strides_vec;
  std::vector<c10::ScalarType> buf_dtypes_vec;
  int64_t buf_dims_idx = 0;
  int64_t buf_strides_idx = 0;
  for (const auto i : c10::irange(bufs_in_num)) {
    buf_data_vec.push_back(buf_data[bufs_out_num + i]);
    buf_dims_vec.emplace_back();
    buf_strides_vec.emplace_back();
    for (const auto dim : c10::irange(buf_ranks[i])) {
      (void)dim;
      buf_dims_vec[i].push_back(buf_dims[buf_dims_idx++]);
      buf_strides_vec[i].push_back(buf_strides[buf_strides_idx++]);
    }
    buf_dtypes_vec.push_back(static_cast<c10::ScalarType>(buf_dtypes[i]));
  }

  std::vector<at::Tensor> tensors;
  at::Tensor und;
  for (const auto i : c10::irange(bufs_out_num)) {
    (void)i;
    tensors.emplace_back(und);
  }
  if (!qdataArg.has_value()) {
    for (const auto i : c10::irange(buf_data_vec.size())) {
      auto options = at::TensorOptions()
                         // NOLINTNEXTLINE
                         .dtype(buf_dtypes_vec[i])
                         .layout(at::kStrided)
                         .device(at::kCPU) // TODO: support GPUs too
                         .memory_format(deduce_memory_format(
                             // NOLINTNEXTLINE
                             buf_strides_vec[i],
                             // NOLINTNEXTLINE
                             buf_dims_vec[i]))
                         .requires_grad(false);
      auto tensor = at::from_blob(
          // NOLINTNEXTLINE
          buf_data_vec[i],
          buf_dims_vec[i],
          buf_strides_vec[i],
          options);
      tensors.emplace_back(tensor);
    }
  } else {
    // handle quantized
    std::vector<std::optional<QIData>> qdata(bufs_in_num, std::nullopt);
    for (const auto& qd : *qdataArg) {
      qdata[qd.first - bufs_out_num] = qd.second;
    }
    for (const auto i : c10::irange(buf_data_vec.size())) {
      auto options = at::TensorOptions()
                         // NOLINTNEXTLINE
                         .dtype(buf_dtypes_vec[i])
                         .layout(at::kStrided)
                         .device(at::kCPU) // TODO: support GPUs too
                         .memory_format(deduce_memory_format(
                             // NOLINTNEXTLINE
                             buf_strides_vec[i],
                             // NOLINTNEXTLINE
                             buf_dims_vec[i]))
                         .requires_grad(false);
      if (auto qd = qdata[i]) {
        // quantized input: wrap the caller's buffer in place (no copy)
        auto tensor = from_blob_quantized(
            // NOLINTNEXTLINE
            buf_data_vec[i],
            buf_dims_vec[i],
            buf_strides_vec[i],
            qd->scale,
            qd->zero,
            qd->scalarType);
        tensors.emplace_back(tensor);
      } else {
        auto tensor = at::from_blob(
            // NOLINTNEXTLINE
            buf_data_vec[i],
            buf_dims_vec[i],
            buf_strides_vec[i],
            options);
        tensors.emplace_back(tensor);
      }
    }
  }
  return tensors;
}

static std::vector<at::Tensor> constructTensors2(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    std::vector<std::pair<size_t, QIData>> qdata,
    size_t bufs_out_num = 0u) {
  std::optional<std::vector<std::pair<size_t, QIData>>> opt = std::move(qdata);
  return constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      opt,
      bufs_out_num);
}

#ifndef _WIN32
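// Thin wrappers that call the quantized ops through the c10 dispatcher by
// schema name rather than invoking kernels directly; quantized::cat is
// redispatched explicitly to the QuantizedCPU backend.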
static at::Tensor quantized_add(
    const at::Tensor& x1,
    const at::Tensor& x2,
    double scale,
    int64_t zero) {
  const auto qadd_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::add", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
  return qadd_op.call(x1, x2, scale, zero);
}

static at::Tensor quantized_mul(
    const at::Tensor& x1,
    const at::Tensor& x2,
    double scale,
    int64_t zero) {
  const auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::mul", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
  return op.call(x1, x2, scale, zero);
}

static at::Tensor quantized_mul_scalar(const at::Tensor& x, double scalar) {
  const auto op = c10::Dispatcher::singleton()
                      .findSchemaOrThrow("quantized::mul", "Scalar")
                      .typed<at::Tensor(at::Tensor, c10::Scalar const&)>();
  auto s = c10::Scalar(scalar);
  return op.call(x, s);
}

static at::Tensor quantized_cat(
    const c10::List<at::Tensor>& qxs,
    int64_t dim,
    std::optional<double> scale,
    std::optional<int64_t> zero) {
  const auto op = c10::Dispatcher::singleton()
                      .findSchemaOrThrow("quantized::cat", "")
                      .typed<at::Tensor(
                          c10::List<at::Tensor> const&,
                          int64_t,
                          std::optional<double>,
                          std::optional<int64_t>)>();
  return op.redispatch(
      c10::DispatchKeySet({c10::DispatchKey::QuantizedCPU}),
      qxs,
      dim,
      scale,
      zero);
}

#endif // _WIN32

#ifdef C10_MOBILE
extern "C" {
#endif

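// The nnc_aten_* entry points below all follow the calling convention NNC
// uses for external calls: buf_data[0] is the preallocated output buffer,
// buf_data[1..] are inputs (or prepacked-parameter contexts), the
// buf_ranks/buf_dims/buf_strides/buf_dtypes arrays describe the buffers, and
// extra_args carries scalar arguments (doubles are read by reinterpreting the
// int64_t array). A minimal sketch of a conforming function, assuming a
// hypothetical unary op `my_op` (for illustration only, not part of this
// file):
//
//   void nnc_aten_my_op(
//       int64_t bufs_num,
//       void** buf_data,
//       int64_t* buf_ranks,
//       int64_t* buf_dims,
//       int64_t* buf_strides,
//       int8_t* buf_dtypes,
//       int64_t args_num,
//       int64_t* extra_args) {
//     auto tensors = constructTensors(
//         bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
//     at::Tensor r = my_op(tensors[1]); // hypothetical op
//     memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
//   }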
void nnc_aten_conv2d(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  const at::Tensor& w = tensors[2];
  if (args_num > 0) {
    // Check that if the extra arguments are provided, then the bias tensor is
    // also present
    TORCH_INTERNAL_ASSERT(args_num == 7 && bufs_num == 4);
    const at::Tensor& b = tensors[3];

    int64_t strideH = extra_args[0];
    int64_t strideW = extra_args[1];
    int64_t paddingH = extra_args[2];
    int64_t paddingW = extra_args[3];
    int64_t dilationH = extra_args[4];
    int64_t dilationW = extra_args[5];
    int64_t groups = extra_args[6];

    try {
      r = at::conv2d(
          x,
          w,
          b,
          {strideH, strideW},
          {paddingH, paddingW},
          {dilationH, dilationW},
          groups);
    } catch (...) {
    }
  } else {
    try {
      r = at::conv2d(x, w);
    } catch (...) {
    }
  }

  // TODO: can i haz an out version of the conv2d?
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_conv1d(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
  auto convPackedParams =
      reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto qx = tensors[1].unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
  auto r = convPackedParams->apply(qx, out_qscale, out_qzero);
  r = r.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

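// The *_out variants do not copy into a preallocated output. They allocate
// the result themselves, publish its data pointer in buf_data[0], and stash
// the result's TensorImpl in buf_data[bufs_in_num + bufs_out_num] after
// taking an extra refcount (intrusive_ptr::incref) so it survives past this
// call; the caller is then expected to release that handle once it is done
// with the output.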
void nnc_aten_quantized_conv1d_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
      bufs_out_num);
  auto convPackedParams =
      reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto qx = tensors[1].unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
  auto r = convPackedParams->apply(qx, out_qscale, out_qzero);
  r = r.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_quantized_conv2d(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
  auto convPackedParams =
      reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto r = convPackedParams->apply(tensors[1], out_qscale, out_qzero);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_conv2d_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
      bufs_out_num);
  auto convPackedParams =
      reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto r = convPackedParams->apply(tensors[1], out_qscale, out_qzero);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_quantized_conv2d_relu(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
  auto convPackedParams =
      reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto r = convPackedParams->apply_relu(tensors[1], out_qscale, out_qzero);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_conv2d_relu_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
      bufs_out_num);
  auto convPackedParams =
      reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto r = convPackedParams->apply_relu(tensors[1], out_qscale, out_qzero);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_quantized_linear(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
  auto linearPackedParams =
      reinterpret_cast<LinearPackedParamsBase*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto r = linearPackedParams->apply(tensors[1], out_qscale, out_qzero);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_linear_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
      bufs_out_num);
  auto linearPackedParams =
      reinterpret_cast<LinearPackedParamsBase*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto r = linearPackedParams->apply(tensors[1], out_qscale, out_qzero);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_quantized_linear_relu(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
  auto linearPackedParams =
      reinterpret_cast<LinearPackedParamsBase*>(buf_data[2]);
  const double out_qscale = ((double*)extra_args)[3];
  const int64_t out_qzero = extra_args[4];
  // NOLINTNEXTLINE
  auto r = linearPackedParams->apply_relu(tensors[1], out_qscale, out_qzero);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

#ifndef _WIN32
void nnc_aten_quantized_add(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  // TORCH_INTERNAL_ASSERT(tensors.size() == 3);

  const double a_qscale = ((double*)extra_args)[0];
  const int64_t a_qzero = extra_args[1];
  const c10::ScalarType a_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  const double b_qscale = ((double*)extra_args)[3];
  const int64_t b_qzero = extra_args[4];
  const c10::ScalarType b_qdtype = static_cast<c10::ScalarType>(extra_args[5]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {a_qscale, a_qzero, toQIntType(a_qdtype)}},
       {2u, {b_qscale, b_qzero, toQIntType(b_qdtype)}}});

  const double out_qscale = ((double*)extra_args)[6];
  const int64_t out_qzero = extra_args[7];
  // NOLINTNEXTLINE
  auto r = quantized_add(tensors[1], tensors[2], out_qscale, out_qzero);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_mul(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double a_qscale = ((double*)extra_args)[0];
  const int64_t a_qzero = extra_args[1];
  const c10::ScalarType a_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  const double b_qscale = ((double*)extra_args)[3];
  const int64_t b_qzero = extra_args[4];
  const c10::ScalarType b_qdtype = static_cast<c10::ScalarType>(extra_args[5]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {a_qscale, a_qzero, toQIntType(a_qdtype)}},
       {2u, {b_qscale, b_qzero, toQIntType(b_qdtype)}}});
  const double out_qscale = ((double*)extra_args)[6];
  const int64_t out_qzero = extra_args[7];
  // NOLINTNEXTLINE
  auto r = quantized_mul(tensors[1], tensors[2], out_qscale, out_qzero);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_mul_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  const double a_qscale = ((double*)extra_args)[0];
  const int64_t a_qzero = extra_args[1];
  const c10::ScalarType a_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  const double b_qscale = ((double*)extra_args)[3];
  const int64_t b_qzero = extra_args[4];
  const c10::ScalarType b_qdtype = static_cast<c10::ScalarType>(extra_args[5]);
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {a_qscale, a_qzero, toQIntType(a_qdtype)}},
       {2u, {b_qscale, b_qzero, toQIntType(b_qdtype)}}},
      1u);
  const double out_qscale = ((double*)extra_args)[6];
  const int64_t out_qzero = extra_args[7];
  // NOLINTNEXTLINE
  auto r = quantized_mul(tensors[1], tensors[2], out_qscale, out_qzero);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_quantized_mul_scalar(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
  const double scalar = ((double*)extra_args)[3];
  // NOLINTNEXTLINE
  auto r = quantized_mul_scalar(tensors[1], scalar);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_mul_scalar_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
      bufs_out_num);
  const double scalar = ((double*)extra_args)[3];
  // NOLINTNEXTLINE
  auto r = quantized_mul_scalar(tensors[1], scalar);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_quantized_relu(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
  // NOLINTNEXTLINE
  auto r = at::relu(tensors[1]);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_sigmoid(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});

  // NOLINTNEXTLINE
  auto r = at::sigmoid(tensors[1]);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantized_sigmoid_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  const size_t bufs_out_num = 1u;
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
      bufs_out_num);

  // NOLINTNEXTLINE
  auto r = at::sigmoid(tensors[1]);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_quantized_cat(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  std::vector<std::pair<size_t, QIData>> qdata;
  const auto in_bufs_num = bufs_num - 1;
  const double out_qscale = ((double*)extra_args)[3 * in_bufs_num + 1];
  const int64_t out_qzero = extra_args[3 * in_bufs_num + 2];
  qdata.emplace_back(
      0u,
      QIData{
          out_qscale, out_qzero, static_cast<c10::ScalarType>(extra_args[2])});
  for (const size_t i : c10::irange(in_bufs_num)) {
    const double qscale = ((double*)extra_args)[3 * i + 0];
    const int64_t qzero = extra_args[3 * i + 1];
    const c10::ScalarType qdtype =
        static_cast<c10::ScalarType>(extra_args[3 * i + 2]);
    qdata.emplace_back(i + 1u, QIData{qscale, qzero, qdtype});
  }
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes, qdata);
  const int64_t dim = extra_args[3 * in_bufs_num + 0];
  auto qxs = c10::List<at::Tensor>(
      std::vector<at::Tensor>(tensors.begin() + 1, tensors.end()));
  auto r = quantized_cat(qxs, dim, out_qscale, out_qzero);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}
#endif // _WIN32

void nnc_aten_upsample_nearest2d(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const int64_t x_qdtype = extra_args[2];
  const auto is_quantized = x_qdtype != -1;
  std::optional<std::vector<std::pair<size_t, QIData>>> qdata;
  if (is_quantized) {
    qdata = {
        {1u,
         {x_qscale,
          x_qzero,
          at::toQIntType(static_cast<c10::ScalarType>(x_qdtype))}}};
  }
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes, qdata);
  auto x = tensors[1];

  int64_t output_size_h = extra_args[3];
  int64_t output_size_w = extra_args[4];
  double scale_factor_h = ((double*)extra_args)[5];
  double scale_factor_w = ((double*)extra_args)[6];

  auto r = at::upsample_nearest2d(
      x,
      (output_size_h != -1)
          ? std::optional<at::IntArrayRef>({output_size_h, output_size_w})
          : std::nullopt,
      (scale_factor_h != -1.f) ? std::optional<at::ArrayRef<double>>(
                                     {scale_factor_h, scale_factor_w})
                               : std::nullopt);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_upsample_nearest2d_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
  const double x_qscale = ((double*)extra_args)[0];
  const int64_t x_qzero = extra_args[1];
  const int64_t x_qdtype = extra_args[2];
  const auto is_quantized = x_qdtype != -1;
  std::optional<std::vector<std::pair<size_t, QIData>>> qdata;
  if (is_quantized) {
    qdata = {
        {1u,
         {x_qscale,
          x_qzero,
          at::toQIntType(static_cast<c10::ScalarType>(x_qdtype))}}};
  }
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      qdata,
      bufs_out_num);
  auto x = tensors[1];

  int64_t output_size_h = extra_args[3];
  int64_t output_size_w = extra_args[4];
  double scale_factor_h = ((double*)extra_args)[5];
  double scale_factor_w = ((double*)extra_args)[6];

  auto r = at::upsample_nearest2d(
      x,
      (output_size_h != -1)
          ? std::optional<at::IntArrayRef>({output_size_h, output_size_w})
          : std::nullopt,
      (scale_factor_h != -1.f) ? std::optional<at::ArrayRef<double>>(
                                     {scale_factor_h, scale_factor_w})
                               : std::nullopt);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_quantize_per_tensor(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
  // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
  at::Tensor x = tensors[1];
  const double qscale = ((double*)extra_args)[0];
  const int64_t qzero = extra_args[1];
  const c10::ScalarType qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto r = at::quantize_per_tensor(x, qscale, qzero, qdtype);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_quantize_per_tensor_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      std::nullopt,
      bufs_out_num);
  // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
  at::Tensor x = tensors[1];
  const double qscale = ((double*)extra_args)[0];
  const int64_t qzero = extra_args[1];
  const c10::ScalarType qdtype = static_cast<c10::ScalarType>(extra_args[2]);
  auto r = at::quantize_per_tensor(x, qscale, qzero, qdtype);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_dequantize(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const double qscale = ((double*)extra_args)[0];
  const int64_t qzero = extra_args[1];
  const int64_t qdtype = extra_args[2];
  auto tensors = constructTensors(
      bufs_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u,
        {qscale, qzero, toQIntType(static_cast<c10::ScalarType>(qdtype))}}});
  // NOLINTNEXTLINE
  auto r = at::dequantize(tensors[1]);
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_dequantize_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  const double qscale = ((double*)extra_args)[0];
  const int64_t qzero = extra_args[1];
  const int64_t qdtype = extra_args[2];
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      {{1u, {qscale, qzero, toQIntType(static_cast<c10::ScalarType>(qdtype))}}},
      bufs_out_num);
  // NOLINTNEXTLINE
  auto r = at::dequantize(tensors[1]);
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_conv1d(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  const at::Tensor& w = tensors[2];
  if (args_num > 0) {
    // Check that if the extra arguments are provided, then the bias tensor is
    // also present
    TORCH_INTERNAL_ASSERT(args_num == 4 && bufs_num == 4);
    const at::Tensor& b = tensors[3];

    int64_t stride = extra_args[0];
    int64_t padding = extra_args[1];
    int64_t dilation = extra_args[2];
    int64_t groups = extra_args[3];

    try {
      r = at::conv1d(x, w, b, {stride}, {padding}, {dilation}, groups);
    } catch (...) {
    }
  } else {
    try {
      r = at::conv1d(x, w);
    } catch (...) {
    }
  }

  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_conv1d_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  const size_t bufs_out_num = 1u;
  auto tensors = constructTensors2(
      bufs_in_num,
      buf_data,
      buf_ranks,
      buf_dims,
      buf_strides,
      buf_dtypes,
      std::nullopt,
      bufs_out_num);

  at::Tensor r;
  const at::Tensor& x = tensors[1];
  const at::Tensor& w = tensors[2];
  if (args_num > 0) {
    // Check that if the extra arguments are provided, then the bias tensor is
    // also present
    TORCH_INTERNAL_ASSERT(args_num == 4 && bufs_in_num == 3);
    const at::Tensor& b = tensors[3];

    int64_t stride = extra_args[0];
    int64_t padding = extra_args[1];
    int64_t dilation = extra_args[2];
    int64_t groups = extra_args[3];

    try {
      r = at::conv1d(x, w, b, {stride}, {padding}, {dilation}, groups);
    } catch (...) {
    }
  } else {
    try {
      r = at::conv1d(x, w);
    } catch (...) {
    }
  }

  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_adaptive_avg_pool2d(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  int64_t H = extra_args[0];
  int64_t W = H;
  if (args_num > 1) {
    W = extra_args[1];
  }
  try {
    r = at::adaptive_avg_pool2d(x, {H, W});
  } catch (...) {
  }
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_mean(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  std::vector<int64_t> mean_dims(args_num - 1);
  bool keepdim = (bool)extra_args[args_num - 1];
  if (args_num > 1) {
    memcpy(mean_dims.data(), extra_args, sizeof(int64_t) * (args_num - 1));
  }
  try {
    at::mean_out(r, x, mean_dims, keepdim);
  } catch (...) {
  }
}

void nnc_aten_max_red(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  int64_t max_dim = extra_args[0];
  bool keep_dim = extra_args[1];
  try {
    r = std::get<0>(at::max(x, max_dim, keep_dim));
  } catch (...) {
  }
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

void nnc_aten_max_red_out(
    int64_t bufs_in_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t,
    int64_t* extra_args) {
  size_t bufs_out_num = 1u;
  auto tensors = constructTensors2(
      bufs_in_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  at::Tensor r;
  // @lint-ignore CLANGTIDY
  const at::Tensor& x = tensors[1];
  int64_t max_dim = extra_args[0];
  bool keep_dim = extra_args[1];
  try {
    r = std::get<0>(at::max(x, max_dim, keep_dim));
  } catch (...) {
  }
  buf_data[0] = r.data_ptr();
  c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
  buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_addmm(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& x = tensors[1];
  const at::Tensor& y = tensors[2];
  const at::Tensor& z = tensors[3];
  // TODO: handle other alpha and beta dtypes, e.g. alpha=0.6, beta=0.2
  int64_t beta = extra_args[0], alpha = extra_args[1];

  try {
    at::addmm_out(r, x, y, z, beta, alpha);
  } catch (...) {
  }
}

// Only the first output is provided; the second output is just a copy of one
// of the inputs
void nnc_aten_triangular_solve(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
  at::Tensor& r = tensors[0];
  at::Tensor r2 = tensors[2].clone();
  const at::Tensor& input = tensors[1];
  const at::Tensor& A = tensors[2];
  try {
    at::triangular_solve_out(
        r, r2, input, A, extra_args[0], extra_args[2], extra_args[3]);
  } catch (...) {
  }
}

#if AT_MKLDNN_ENABLED()

void nnc_mkldnn_prepacked_conv_run(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  using namespace at::native::mkldnn;

  auto tensors = constructTensors(
      bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  const at::Tensor& x = tensors[1];
  auto context = reinterpret_cast<ConvOpContext*>(buf_data[2]);

  context->run(x, buf_data[0]);
}

#endif // AT_MKLDNN_ENABLED()

#ifdef USE_XNNPACK

void nnc_prepacked_linear_clamp_run(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  using namespace at::native::xnnpack;

  auto tensors = constructTensors(
      bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  const at::Tensor& x = tensors[1];
  auto context = reinterpret_cast<LinearOpContext*>(buf_data[2]);
  at::Tensor output = context->run(x);
  memcpy(
      buf_data[0],
      output.const_data_ptr(),
      output.element_size() * output.numel());
}

void nnc_prepacked_conv2d_clamp_run(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  using namespace at::native::xnnpack;

  auto tensors = constructTensors(
      bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  const at::Tensor& x = tensors[1];
  auto context = reinterpret_cast<Conv2dOpContext*>(buf_data[2]);
  at::Tensor output = context->run(x);
  memcpy(
      buf_data[0],
      output.const_data_ptr(),
      output.element_size() * output.numel());
}

#endif // USE_XNNPACK

void nnc_aten_embedding(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int64_t* buf_strides,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  auto tensors = constructTensors(
      bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& weight = tensors[1];
  const at::Tensor& indices = tensors[2];
  try {
    r = at::embedding(weight, indices);
  } catch (...) {
  }
  // TODO: have to copy output because at::embedding doesn't have an out
  // variant and NNC's external calls don't support allocations
  memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
}

#ifndef C10_MOBILE

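// Register the functions above in the NNC external-function registry under
// the string names referenced by NNC's ExternalCall lowering, so generated
// code can resolve these calls at runtime.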
const static RegisterNNCExternalFunction nnc_conv2d(
    "nnc_aten_conv2d",
    nnc_aten_conv2d);

const static RegisterNNCExternalFunction nnc_quantized_conv1d(
    "nnc_aten_quantized_conv1d",
    nnc_aten_quantized_conv1d);
const static RegisterNNCExternalFunction nnc_quantized_conv1d_out(
    "nnc_aten_quantized_conv1d_out",
    nnc_aten_quantized_conv1d_out);
const static RegisterNNCExternalFunction nnc_quantized_conv2d(
    "nnc_aten_quantized_conv2d",
    nnc_aten_quantized_conv2d);
const static RegisterNNCExternalFunction nnc_quantized_conv2d_out(
    "nnc_aten_quantized_conv2d_out",
    nnc_aten_quantized_conv2d_out);
const static RegisterNNCExternalFunction nnc_quantized_conv2d_relu(
    "nnc_aten_quantized_conv2d_relu",
    nnc_aten_quantized_conv2d_relu);
const static RegisterNNCExternalFunction nnc_quantized_conv2d_relu_out(
    "nnc_aten_quantized_conv2d_relu_out",
    nnc_aten_quantized_conv2d_relu_out);
const static RegisterNNCExternalFunction nnc_quantized_linear(
    "nnc_aten_quantized_linear",
    nnc_aten_quantized_linear);
const static RegisterNNCExternalFunction nnc_quantized_linear_out(
    "nnc_aten_quantized_linear_out",
    nnc_aten_quantized_linear_out);
#ifndef _WIN32
const static RegisterNNCExternalFunction nnc_quantized_add(
    "nnc_aten_quantized_add",
    nnc_aten_quantized_add);
const static RegisterNNCExternalFunction nnc_quantized_mul(
    "nnc_aten_quantized_mul",
    nnc_aten_quantized_mul);
const static RegisterNNCExternalFunction nnc_quantized_mul_out(
    "nnc_aten_quantized_mul_out",
    nnc_aten_quantized_mul_out);
const static RegisterNNCExternalFunction nnc_quantized_mul_scalar(
    "nnc_aten_quantized_mul_scalar",
    nnc_aten_quantized_mul_scalar);
const static RegisterNNCExternalFunction nnc_quantized_mul_scalar_out(
    "nnc_aten_quantized_mul_scalar_out",
    nnc_aten_quantized_mul_scalar_out);
const static RegisterNNCExternalFunction nnc_quantized_sigmoid(
    "nnc_aten_quantized_sigmoid",
    nnc_aten_quantized_sigmoid);
const static RegisterNNCExternalFunction nnc_quantized_sigmoid_out(
    "nnc_aten_quantized_sigmoid_out",
    nnc_aten_quantized_sigmoid_out);
const static RegisterNNCExternalFunction nnc_quantized_cat(
    "nnc_aten_quantized_cat",
    nnc_aten_quantized_cat);
const static RegisterNNCExternalFunction nnc_quantized_relu(
    "nnc_aten_quantized_relu",
    nnc_aten_quantized_relu);
#endif // _WIN32
const static RegisterNNCExternalFunction nnc_quantize_per_tensor(
    "nnc_aten_quantize_per_tensor",
    nnc_aten_quantize_per_tensor);
const static RegisterNNCExternalFunction nnc_quantize_per_tensor_out(
    "nnc_aten_quantize_per_tensor_out",
    nnc_aten_quantize_per_tensor_out);
const static RegisterNNCExternalFunction nnc_dequantize(
    "nnc_aten_dequantize",
    nnc_aten_dequantize);
const static RegisterNNCExternalFunction nnc_dequantize_out(
    "nnc_aten_dequantize_out",
    nnc_aten_dequantize_out);

const static RegisterNNCExternalFunction nnc_upsample_nearest2d(
    "nnc_aten_upsample_nearest2d",
    nnc_aten_upsample_nearest2d);
const static RegisterNNCExternalFunction nnc_upsample_nearest2d_out(
    "nnc_aten_upsample_nearest2d_out",
    nnc_aten_upsample_nearest2d_out);
const static RegisterNNCExternalFunction nnc_conv1d(
    "nnc_aten_conv1d",
    nnc_aten_conv1d);
const static RegisterNNCExternalFunction nnc_conv1d_out(
    "nnc_aten_conv1d_out",
    nnc_aten_conv1d_out);
const static RegisterNNCExternalFunction nnc_adaptive_avg_pool2d(
    "nnc_aten_adaptive_avg_pool2d",
    nnc_aten_adaptive_avg_pool2d);
const static RegisterNNCExternalFunction nnc_mean(
    "nnc_aten_mean",
    nnc_aten_mean);
const static RegisterNNCExternalFunction nnc_max_red(
    "nnc_aten_max_red",
    nnc_aten_max_red);
const static RegisterNNCExternalFunction nnc_max_red_out(
    "nnc_aten_max_red_out",
    nnc_aten_max_red_out);
const static RegisterNNCExternalFunction nnc_addmm(
    "nnc_aten_addmm",
    nnc_aten_addmm);

const static RegisterNNCExternalFunction nnc_triangular_solve(
    "nnc_aten_triangular_solve",
    nnc_aten_triangular_solve);

const static RegisterNNCExternalFunction nnc_embedding(
    "nnc_aten_embedding",
    nnc_aten_embedding);

#if AT_MKLDNN_ENABLED()
const static RegisterNNCExternalFunction reg_nnc_mkldnn_prepacked_conv_run(
    "nnc_mkldnn_prepacked_conv_run",
    nnc_mkldnn_prepacked_conv_run);
#endif // AT_MKLDNN_ENABLED()

#ifdef USE_XNNPACK
const static RegisterNNCExternalFunction reg_nnc_prepacked_linear_clamp_run(
    "nnc_prepacked_linear_clamp_run",
    nnc_prepacked_linear_clamp_run);
const static RegisterNNCExternalFunction reg_nnc_prepacked_conv2d_clamp_run(
    "nnc_prepacked_conv2d_clamp_run",
    nnc_prepacked_conv2d_clamp_run);
#endif // USE_XNNPACK

#endif // C10_MOBILE

#ifdef C10_MOBILE
} // extern "C"
#endif

} // namespace torch::jit::tensorexpr