/aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/external_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/external_functions.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/Functions.h>
5 #include <ATen/NativeFunctions.h>
6 #include <ATen/Parallel.h>
7 #include <ATen/core/Tensor.h>
8 #include <ATen/native/mkldnn/OpContext.h>
9 #include <ATen/native/quantized/PackedParams.h>
10 #include <ATen/native/quantized/cpu/BinaryOps.h>
11 #include <ATen/native/quantized/cpu/QuantUtils.h>
12 #include <ATen/native/quantized/cpu/QuantizedOps.h>
13 #include <ATen/native/quantized/cpu/conv_serialization.h>
14 #include <ATen/native/xnnpack/OpContext.h>
15 #include <ATen/quantized/QTensorImpl.h>
16 #include <c10/core/TensorImpl.h>
17 #include <c10/core/TensorOptions.h>
18 #include <c10/util/ArrayRef.h>
19 #include <c10/util/irange.h>
20 #include <torch/csrc/jit/serialization/import_source.h>
21 #include <torch/csrc/jit/serialization/pickle.h>
22 #include <torch/csrc/jit/tensorexpr/exceptions.h>
23 #include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
24 #include <utility>
25 
26 namespace torch::jit::tensorexpr {
27 
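// Stride-based layout heuristic: a rank-4 buffer whose channel stride is 1
// and whose width stride equals the channel count is treated as
// channels-last (NHWC); anything else is assumed contiguous (NCHW).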
28 static c10::MemoryFormat deduce_memory_format(
29     c10::IntArrayRef strides,
30     c10::IntArrayRef dims) {
31   if (strides.size() == 4 && strides[3] == dims[1] && strides[1] == 1l) {
32     return c10::MemoryFormat::ChannelsLast;
33   }
34   return c10::MemoryFormat::Contiguous;
35 }
36 
37 static c10::MemoryFormat deduce_memory_format(
38     const std::vector<int64_t>& strides,
39     const std::vector<int64_t>& dims) {
40   return deduce_memory_format(
41       c10::IntArrayRef(strides), c10::IntArrayRef(dims));
42 }
43 
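// Wraps an externally owned buffer as an affine-quantized tensor without
// copying: an empty quantized tensor with the requested scale/zero-point is
// created, its storage is re-pointed at `data` with a no-op deleter, and the
// caller's sizes/strides are installed.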
44 static at::Tensor from_blob_quantized(
45     void* data,
46     at::IntArrayRef sizes,
47     at::IntArrayRef strides,
48     double qscale,
49     int64_t qzero,
50     at::ScalarType dtype) {
51   auto memory_format = deduce_memory_format(strides, sizes);
52   auto qx = at::_empty_affine_quantized(
53       sizes,
54       dtype,
55       c10::kStrided,
56       at::kCPU,
57       false,
58       qscale,
59       qzero,
60       memory_format);
61   auto qtensor_impl = static_cast<at::QTensorImpl*>(qx.unsafeGetTensorImpl());
62   auto typeMeta = c10::scalarTypeToTypeMeta(dtype);
63   std::size_t size = 1;
64   for (std::int64_t s : sizes) {
65     size *= static_cast<std::size_t>(s);
66   }
67   qtensor_impl->ShareExternalPointer(
68       c10::InefficientStdFunctionContext::makeDataPtr(
69           data, [](void*) {}, at::kCPU),
70       typeMeta,
71       size * typeMeta.itemsize());
72   qtensor_impl->set_sizes_and_strides(sizes, strides);
73   return qx;
74 }
75 
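// Rebuilds at::Tensor views over the raw buffer pointers handed in by NNC's
// generated code. Ranks, dims, strides and dtypes arrive as flat arrays; the
// optional qdata entries identify buffers that must be reconstructed as
// quantized tensors with a given scale, zero point and dtype.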
76 std::vector<at::Tensor> constructTensors(
77     int64_t bufs_num,
78     void** buf_data,
79     int64_t* buf_ranks,
80     int64_t* buf_dims,
81     int64_t* buf_strides,
82     int8_t* buf_dtypes,
83     std::optional<std::vector<std::pair<size_t, QIData>>> qdataArg) {
84   std::vector<void*> buf_data_vec;
85   std::vector<std::vector<int64_t>> buf_dims_vec;
86   std::vector<std::vector<int64_t>> buf_strides_vec;
87   std::vector<c10::ScalarType> buf_dtypes_vec;
88   int64_t buf_dims_idx = 0;
89   int64_t buf_strides_idx = 0;
90   for (const auto i : c10::irange(bufs_num)) {
91     buf_data_vec.push_back(buf_data[i]);
92     buf_dims_vec.emplace_back();
93     buf_strides_vec.emplace_back();
94     for (const auto dim : c10::irange(buf_ranks[i])) {
95       (void)dim;
96       buf_dims_vec[i].push_back(buf_dims[buf_dims_idx++]);
97       buf_strides_vec[i].push_back(buf_strides[buf_strides_idx++]);
98     }
99     buf_dtypes_vec.push_back(static_cast<c10::ScalarType>(buf_dtypes[i]));
100   }
101 
102   std::vector<at::Tensor> tensors;
103   if (!qdataArg.has_value()) {
104     for (const auto i : c10::irange(buf_data_vec.size())) {
105       auto options = at::TensorOptions()
106                          // NOLINTNEXTLINE
107                          .dtype(buf_dtypes_vec[i])
108                          .layout(at::kStrided)
109                          .device(at::kCPU) // TODO: support GPUs too
110                          .memory_format(deduce_memory_format(
111                              // NOLINTNEXTLINE
112                              buf_strides_vec[i],
113                              // NOLINTNEXTLINE
114                              buf_dims_vec[i]))
115                          .requires_grad(false);
116       auto tensor = at::from_blob(
117           // NOLINTNEXTLINE
118           buf_data_vec[i],
119           buf_dims_vec[i],
120           buf_strides_vec[i],
121           options);
122       tensors.emplace_back(tensor);
123     }
124   } else {
125     // handle quantized
126     std::vector<std::optional<QIData>> qdata(bufs_num, std::nullopt);
127     for (const auto& qd : *qdataArg) {
128       qdata[qd.first] = qd.second;
129     }
130     for (const auto i : c10::irange(buf_data_vec.size())) {
131       auto options = at::TensorOptions()
132                          // NOLINTNEXTLINE
133                          .dtype(buf_dtypes_vec[i])
134                          .layout(at::kStrided)
135                          .device(at::kCPU) // TODO: support GPUs too
136                          .memory_format(deduce_memory_format(
137                              // NOLINTNEXTLINE
138                              buf_strides_vec[i],
139                              // NOLINTNEXTLINE
140                              buf_dims_vec[i]))
141                          .requires_grad(false);
142       if (auto qd = qdata[i]) {
143         // inplace tensor
144         auto tensor = from_blob_quantized(
145             // NOLINTNEXTLINE
146             buf_data_vec[i],
147             buf_dims_vec[i],
148             buf_strides_vec[i],
149             qd->scale,
150             qd->zero,
151             qd->scalarType);
152         tensors.emplace_back(tensor);
153       } else {
154         auto tensor = at::from_blob(
155             // NOLINTNEXTLINE
156             buf_data_vec[i],
157             buf_dims_vec[i],
158             buf_strides_vec[i],
159             options);
160         tensors.emplace_back(tensor);
161       }
162     }
163   }
164   return tensors;
165 }
166 
167 static std::vector<at::Tensor> constructTensors(
168     int64_t bufs_num,
169     void** buf_data,
170     int64_t* buf_ranks,
171     int64_t* buf_dims,
172     int64_t* buf_strides,
173     int8_t* buf_dtypes,
174     std::vector<std::pair<size_t, QIData>> qdata) {
175   std::optional<std::vector<std::pair<size_t, QIData>>> opt = std::move(qdata);
176   return constructTensors(
177       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes, opt);
178 }
179 
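// Variant of constructTensors for kernels with separately allocated outputs:
// the first bufs_out_num entries of buf_data are output slots, so undefined
// placeholder tensors are emitted for them and only the trailing input
// buffers are wrapped.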
180 std::vector<at::Tensor> constructTensors2(
181     int64_t bufs_in_num,
182     void** buf_data,
183     int64_t* buf_ranks,
184     int64_t* buf_dims,
185     int64_t* buf_strides,
186     int8_t* buf_dtypes,
187     std::optional<std::vector<std::pair<size_t, QIData>>> qdataArg,
188     size_t bufs_out_num) {
189   std::vector<void*> buf_data_vec;
190   std::vector<std::vector<int64_t>> buf_dims_vec;
191   std::vector<std::vector<int64_t>> buf_strides_vec;
192   std::vector<c10::ScalarType> buf_dtypes_vec;
193   int64_t buf_dims_idx = 0;
194   int64_t buf_strides_idx = 0;
195   for (const auto i : c10::irange(bufs_in_num)) {
196     buf_data_vec.push_back(buf_data[bufs_out_num + i]);
197     buf_dims_vec.emplace_back();
198     buf_strides_vec.emplace_back();
199     for (const auto dim : c10::irange(buf_ranks[i])) {
200       (void)dim;
201       buf_dims_vec[i].push_back(buf_dims[buf_dims_idx++]);
202       buf_strides_vec[i].push_back(buf_strides[buf_strides_idx++]);
203     }
204     buf_dtypes_vec.push_back(static_cast<c10::ScalarType>(buf_dtypes[i]));
205   }
206 
207   std::vector<at::Tensor> tensors;
208   at::Tensor und;
209   for (const auto i : c10::irange(bufs_out_num)) {
210     (void)i;
211     tensors.emplace_back(und);
212   }
213   if (!qdataArg.has_value()) {
214     for (const auto i : c10::irange(buf_data_vec.size())) {
215       auto options = at::TensorOptions()
216                          // NOLINTNEXTLINE
217                          .dtype(buf_dtypes_vec[i])
218                          .layout(at::kStrided)
219                          .device(at::kCPU) // TODO: support GPUs too
220                          .memory_format(deduce_memory_format(
221                              // NOLINTNEXTLINE
222                              buf_strides_vec[i],
223                              // NOLINTNEXTLINE
224                              buf_dims_vec[i]))
225                          .requires_grad(false);
226       auto tensor = at::from_blob(
227           // NOLINTNEXTLINE
228           buf_data_vec[i],
229           buf_dims_vec[i],
230           buf_strides_vec[i],
231           options);
232       tensors.emplace_back(tensor);
233     }
234   } else {
235     // handle quantized
236     std::vector<std::optional<QIData>> qdata(bufs_in_num, std::nullopt);
237     for (const auto& qd : *qdataArg) {
238       qdata[qd.first - bufs_out_num] = qd.second;
239     }
240     for (const auto i : c10::irange(buf_data_vec.size())) {
241       auto options = at::TensorOptions()
242                          // NOLINTNEXTLINE
243                          .dtype(buf_dtypes_vec[i])
244                          .layout(at::kStrided)
245                          .device(at::kCPU) // TODO: support GPUs too
246                          .memory_format(deduce_memory_format(
247                              // NOLINTNEXTLINE
248                              buf_strides_vec[i],
249                              // NOLINTNEXTLINE
250                              buf_dims_vec[i]))
251                          .requires_grad(false);
252       if (auto qd = qdata[i]) {
253         // inplace tensor
254         auto tensor = from_blob_quantized(
255             // NOLINTNEXTLINE
256             buf_data_vec[i],
257             buf_dims_vec[i],
258             buf_strides_vec[i],
259             qd->scale,
260             qd->zero,
261             qd->scalarType);
262         tensors.emplace_back(tensor);
263       } else {
264         auto tensor = at::from_blob(
265             // NOLINTNEXTLINE
266             buf_data_vec[i],
267             buf_dims_vec[i],
268             buf_strides_vec[i],
269             options);
270         tensors.emplace_back(tensor);
271       }
272     }
273   }
274   return tensors;
275 }
276 
277 static std::vector<at::Tensor> constructTensors2(
278     int64_t bufs_in_num,
279     void** buf_data,
280     int64_t* buf_ranks,
281     int64_t* buf_dims,
282     int64_t* buf_strides,
283     int8_t* buf_dtypes,
284     std::vector<std::pair<size_t, QIData>> qdata,
285     size_t bufs_out_num = 0u) {
286   std::optional<std::vector<std::pair<size_t, QIData>>> opt = std::move(qdata);
287   return constructTensors2(
288       bufs_in_num,
289       buf_data,
290       buf_ranks,
291       buf_dims,
292       buf_strides,
293       buf_dtypes,
294       opt,
295       bufs_out_num);
296 }
297 
298 #ifndef _WIN32
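// Thin wrappers that call the quantized::add / quantized::mul /
// quantized::cat operators through the c10 dispatcher.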
299 static at::Tensor quantized_add(
300     const at::Tensor& x1,
301     const at::Tensor& x2,
302     double scale,
303     int64_t zero) {
304   const auto qadd_op =
305       c10::Dispatcher::singleton()
306           .findSchemaOrThrow("quantized::add", "")
307           .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
308   return qadd_op.call(x1, x2, scale, zero);
309 }
310 
311 static at::Tensor quantized_mul(
312     const at::Tensor& x1,
313     const at::Tensor& x2,
314     double scale,
315     int64_t zero) {
316   const auto op =
317       c10::Dispatcher::singleton()
318           .findSchemaOrThrow("quantized::mul", "")
319           .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
320   return op.call(x1, x2, scale, zero);
321 }
322 
323 static at::Tensor quantized_mul_scalar(const at::Tensor& x, double scalar) {
324   const auto op = c10::Dispatcher::singleton()
325                       .findSchemaOrThrow("quantized::mul", "Scalar")
326                       .typed<at::Tensor(at::Tensor, c10::Scalar const&)>();
327   auto s = c10::Scalar(scalar);
328   return op.call(x, s);
329 }
330 
331 static at::Tensor quantized_cat(
332     const c10::List<at::Tensor>& qxs,
333     int64_t dim,
334     std::optional<double> scale,
335     std::optional<int64_t> zero) {
336   const auto op = c10::Dispatcher::singleton()
337                       .findSchemaOrThrow("quantized::cat", "")
338                       .typed<at::Tensor(
339                           c10::List<at::Tensor> const&,
340                           int64_t,
341                           std::optional<double>,
342                           std::optional<int64_t>)>();
343   return op.redispatch(
344       c10::DispatchKeySet({c10::DispatchKey::QuantizedCPU}),
345       qxs,
346       dim,
347       scale,
348       zero);
349 }
350 
351 #endif // _WIN32
352 
353 #ifdef C10_MOBILE
354 extern "C" {
355 #endif
356 
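// External-call entry points. By convention buf_data[0] is the output buffer
// and buf_data[1..] are the inputs; ranks, dims, strides and dtypes come in
// as flat arrays and extra_args carries scalar parameters. Results are
// produced either through an ATen *_out call into the tensor wrapping
// buf_data[0] or by memcpy into it, while the nnc_*_out entry points instead
// return a freshly allocated tensor: its data pointer goes into buf_data[0]
// and it is kept alive via an extra intrusive_ptr slot at
// buf_data[bufs_in_num + bufs_out_num]. (Each function is expected to be
// registered with NNC's external-function registry elsewhere.)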
357 void nnc_aten_conv2d(
358     int64_t bufs_num,
359     void** buf_data,
360     int64_t* buf_ranks,
361     int64_t* buf_dims,
362     int64_t* buf_strides,
363     int8_t* buf_dtypes,
364     int64_t args_num,
365     int64_t* extra_args) {
366   auto tensors = constructTensors(
367       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
368 
369   at::Tensor& r = tensors[0];
370   const at::Tensor& x = tensors[1];
371   const at::Tensor& w = tensors[2];
372   if (args_num > 0) {
373     // Check that if the extra arguments are provided, then the bias tensor is
374     // also present
375     TORCH_INTERNAL_ASSERT(args_num == 7 && bufs_num == 4);
376     const at::Tensor& b = tensors[3];
377 
378     int64_t strideH = extra_args[0];
379     int64_t strideW = extra_args[1];
380     int64_t paddingH = extra_args[2];
381     int64_t paddingW = extra_args[3];
382     int64_t dilationH = extra_args[4];
383     int64_t dilationW = extra_args[5];
384     int64_t groups = extra_args[6];
385 
386     try {
387       r = at::conv2d(
388           x,
389           w,
390           b,
391           {strideH, strideW},
392           {paddingH, paddingW},
393           {dilationH, dilationW},
394           groups);
395     } catch (...) {
396     }
397   } else {
398     try {
399       r = at::conv2d(x, w);
400     } catch (...) {
401     }
402   }
403 
404   // TODO: can we get an out variant of conv2d?
405   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
406 }
407 
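// Quantized conv/linear entry points decode extra_args as:
//   extra_args[0..2]: input scale (stored as double), zero point, dtype
//   extra_args[3..4]: output scale (stored as double), zero point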
408 void nnc_aten_quantized_conv1d(
409     int64_t bufs_num,
410     void** buf_data,
411     int64_t* buf_ranks,
412     int64_t* buf_dims,
413     int64_t* buf_strides,
414     int8_t* buf_dtypes,
415     int64_t,
416     int64_t* extra_args) {
417   const double x_qscale = ((double*)extra_args)[0];
418   const int64_t x_qzero = extra_args[1];
419   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
420   auto tensors = constructTensors(
421       bufs_num,
422       buf_data,
423       buf_ranks,
424       buf_dims,
425       buf_strides,
426       buf_dtypes,
427       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
428   auto convPackedParams =
429       reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
430   const double out_qscale = ((double*)extra_args)[3];
431   const int64_t out_qzero = extra_args[4];
432   // NOLINTNEXTLINE
433   auto qx = tensors[1].unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
434   auto r = convPackedParams->apply(qx, out_qscale, out_qzero);
435   r = r.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
436   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
437 }
438 
439 void nnc_aten_quantized_conv1d_out(
440     int64_t bufs_in_num,
441     void** buf_data,
442     int64_t* buf_ranks,
443     int64_t* buf_dims,
444     int64_t* buf_strides,
445     int8_t* buf_dtypes,
446     int64_t,
447     int64_t* extra_args) {
448   const size_t bufs_out_num = 1u;
449   const double x_qscale = ((double*)extra_args)[0];
450   const int64_t x_qzero = extra_args[1];
451   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
452   auto tensors = constructTensors2(
453       bufs_in_num,
454       buf_data,
455       buf_ranks,
456       buf_dims,
457       buf_strides,
458       buf_dtypes,
459       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
460       bufs_out_num);
461   auto convPackedParams =
462       reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
463   const double out_qscale = ((double*)extra_args)[3];
464   const int64_t out_qzero = extra_args[4];
465   // NOLINTNEXTLINE
466   auto qx = tensors[1].unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
467   auto r = convPackedParams->apply(qx, out_qscale, out_qzero);
468   r = r.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
469   buf_data[0] = r.data_ptr();
470   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
471   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
472 }
473 
474 void nnc_aten_quantized_conv2d(
475     int64_t bufs_num,
476     void** buf_data,
477     int64_t* buf_ranks,
478     int64_t* buf_dims,
479     int64_t* buf_strides,
480     int8_t* buf_dtypes,
481     int64_t,
482     int64_t* extra_args) {
483   const double x_qscale = ((double*)extra_args)[0];
484   const int64_t x_qzero = extra_args[1];
485   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
486   auto tensors = constructTensors(
487       bufs_num,
488       buf_data,
489       buf_ranks,
490       buf_dims,
491       buf_strides,
492       buf_dtypes,
493       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
494   auto convPackedParams =
495       reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
496   const double out_qscale = ((double*)extra_args)[3];
497   const int64_t out_qzero = extra_args[4];
498   // NOLINTNEXTLINE
499   auto r = convPackedParams->apply(tensors[1], out_qscale, out_qzero);
500   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
501 }
502 
503 void nnc_aten_quantized_conv2d_out(
504     int64_t bufs_in_num,
505     void** buf_data,
506     int64_t* buf_ranks,
507     int64_t* buf_dims,
508     int64_t* buf_strides,
509     int8_t* buf_dtypes,
510     int64_t,
511     int64_t* extra_args) {
512   const size_t bufs_out_num = 1u;
513   const double x_qscale = ((double*)extra_args)[0];
514   const int64_t x_qzero = extra_args[1];
515   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
516   auto tensors = constructTensors2(
517       bufs_in_num,
518       buf_data,
519       buf_ranks,
520       buf_dims,
521       buf_strides,
522       buf_dtypes,
523       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
524       bufs_out_num);
525   auto convPackedParams =
526       reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
527   const double out_qscale = ((double*)extra_args)[3];
528   const int64_t out_qzero = extra_args[4];
529   // NOLINTNEXTLINE
530   auto r = convPackedParams->apply(tensors[1], out_qscale, out_qzero);
531   buf_data[0] = r.data_ptr();
532   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
533   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
534 }
535 
536 void nnc_aten_quantized_conv2d_relu(
537     int64_t bufs_num,
538     void** buf_data,
539     int64_t* buf_ranks,
540     int64_t* buf_dims,
541     int64_t* buf_strides,
542     int8_t* buf_dtypes,
543     int64_t,
544     int64_t* extra_args) {
545   const double x_qscale = ((double*)extra_args)[0];
546   const int64_t x_qzero = extra_args[1];
547   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
548   auto tensors = constructTensors(
549       bufs_num,
550       buf_data,
551       buf_ranks,
552       buf_dims,
553       buf_strides,
554       buf_dtypes,
555       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
556   auto convPackedParams =
557       reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
558   const double out_qscale = ((double*)extra_args)[3];
559   const int64_t out_qzero = extra_args[4];
560   // NOLINTNEXTLINE
561   auto r = convPackedParams->apply_relu(tensors[1], out_qscale, out_qzero);
562   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
563 }
564 
565 void nnc_aten_quantized_conv2d_relu_out(
566     int64_t bufs_in_num,
567     void** buf_data,
568     int64_t* buf_ranks,
569     int64_t* buf_dims,
570     int64_t* buf_strides,
571     int8_t* buf_dtypes,
572     int64_t,
573     int64_t* extra_args) {
574   const size_t bufs_out_num = 1u;
575   const double x_qscale = ((double*)extra_args)[0];
576   const int64_t x_qzero = extra_args[1];
577   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
578   auto tensors = constructTensors2(
579       bufs_in_num,
580       buf_data,
581       buf_ranks,
582       buf_dims,
583       buf_strides,
584       buf_dtypes,
585       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
586       bufs_out_num);
587   auto convPackedParams =
588       reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
589   const double out_qscale = ((double*)extra_args)[3];
590   const int64_t out_qzero = extra_args[4];
591   // NOLINTNEXTLINE
592   auto r = convPackedParams->apply_relu(tensors[1], out_qscale, out_qzero);
593   buf_data[0] = r.data_ptr();
594   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
595   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
596 }
597 
598 void nnc_aten_quantized_linear(
599     int64_t bufs_num,
600     void** buf_data,
601     int64_t* buf_ranks,
602     int64_t* buf_dims,
603     int64_t* buf_strides,
604     int8_t* buf_dtypes,
605     int64_t,
606     int64_t* extra_args) {
607   const double x_qscale = ((double*)extra_args)[0];
608   const int64_t x_qzero = extra_args[1];
609   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
610   auto tensors = constructTensors(
611       bufs_num,
612       buf_data,
613       buf_ranks,
614       buf_dims,
615       buf_strides,
616       buf_dtypes,
617       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
618   auto linearPackedParams =
619       reinterpret_cast<LinearPackedParamsBase*>(buf_data[2]);
620   const double out_qscale = ((double*)extra_args)[3];
621   const int64_t out_qzero = extra_args[4];
622   // NOLINTNEXTLINE
623   auto r = linearPackedParams->apply(tensors[1], out_qscale, out_qzero);
624   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
625 }
626 
627 void nnc_aten_quantized_linear_out(
628     int64_t bufs_in_num,
629     void** buf_data,
630     int64_t* buf_ranks,
631     int64_t* buf_dims,
632     int64_t* buf_strides,
633     int8_t* buf_dtypes,
634     int64_t,
635     int64_t* extra_args) {
636   const size_t bufs_out_num = 1u;
637   const double x_qscale = ((double*)extra_args)[0];
638   const int64_t x_qzero = extra_args[1];
639   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
640   auto tensors = constructTensors2(
641       bufs_in_num,
642       buf_data,
643       buf_ranks,
644       buf_dims,
645       buf_strides,
646       buf_dtypes,
647       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
648       bufs_out_num);
649   auto linearPackedParams =
650       reinterpret_cast<LinearPackedParamsBase*>(buf_data[2]);
651   const double out_qscale = ((double*)extra_args)[3];
652   const int64_t out_qzero = extra_args[4];
653   // NOLINTNEXTLINE
654   auto r = linearPackedParams->apply(tensors[1], out_qscale, out_qzero);
655   buf_data[0] = r.data_ptr();
656   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
657   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
658 }
659 
660 void nnc_aten_quantized_linear_relu(
661     int64_t bufs_num,
662     void** buf_data,
663     int64_t* buf_ranks,
664     int64_t* buf_dims,
665     int64_t* buf_strides,
666     int8_t* buf_dtypes,
667     int64_t,
668     int64_t* extra_args) {
669   const double x_qscale = ((double*)extra_args)[0];
670   const int64_t x_qzero = extra_args[1];
671   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
672   auto tensors = constructTensors(
673       bufs_num,
674       buf_data,
675       buf_ranks,
676       buf_dims,
677       buf_strides,
678       buf_dtypes,
679       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
680   auto linearPackedParams =
681       reinterpret_cast<LinearPackedParamsBase*>(buf_data[2]);
682   const double out_qscale = ((double*)extra_args)[3];
683   const int64_t out_qzero = extra_args[4];
684   // NOLINTNEXTLINE
685   auto r = linearPackedParams->apply_relu(tensors[1], out_qscale, out_qzero);
686   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
687 }
688 
689 #ifndef _WIN32
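// Binary quantized ops: extra_args[0..2] and extra_args[3..5] hold the
// quantization parameters of the two inputs, extra_args[6..7] the output
// scale and zero point.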
690 void nnc_aten_quantized_add(
691     int64_t bufs_num,
692     void** buf_data,
693     int64_t* buf_ranks,
694     int64_t* buf_dims,
695     int64_t* buf_strides,
696     int8_t* buf_dtypes,
697     int64_t,
698     int64_t* extra_args) {
699   // TORCH_INTERNAL_ASSERT(tensors.size() == 3);
700 
701   const double a_qscale = ((double*)extra_args)[0];
702   const int64_t a_qzero = extra_args[1];
703   const c10::ScalarType a_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
704   const double b_qscale = ((double*)extra_args)[3];
705   const int64_t b_qzero = extra_args[4];
706   const c10::ScalarType b_qdtype = static_cast<c10::ScalarType>(extra_args[5]);
707   auto tensors = constructTensors(
708       bufs_num,
709       buf_data,
710       buf_ranks,
711       buf_dims,
712       buf_strides,
713       buf_dtypes,
714       {{1u, {a_qscale, a_qzero, toQIntType(a_qdtype)}},
715        {2u, {b_qscale, b_qzero, toQIntType(b_qdtype)}}});
716 
717   const double out_qscale = ((double*)extra_args)[6];
718   const int64_t out_qzero = extra_args[7];
719   // NOLINTNEXTLINE
720   auto r = quantized_add(tensors[1], tensors[2], out_qscale, out_qzero);
721   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
722 }
723 
724 void nnc_aten_quantized_mul(
725     int64_t bufs_num,
726     void** buf_data,
727     int64_t* buf_ranks,
728     int64_t* buf_dims,
729     int64_t* buf_strides,
730     int8_t* buf_dtypes,
731     int64_t,
732     int64_t* extra_args) {
733   const double a_qscale = ((double*)extra_args)[0];
734   const int64_t a_qzero = extra_args[1];
735   const c10::ScalarType a_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
736   const double b_qscale = ((double*)extra_args)[3];
737   const int64_t b_qzero = extra_args[4];
738   const c10::ScalarType b_qdtype = static_cast<c10::ScalarType>(extra_args[5]);
739   auto tensors = constructTensors(
740       bufs_num,
741       buf_data,
742       buf_ranks,
743       buf_dims,
744       buf_strides,
745       buf_dtypes,
746       {{1u, {a_qscale, a_qzero, toQIntType(a_qdtype)}},
747        {2u, {b_qscale, b_qzero, toQIntType(b_qdtype)}}});
748   const double out_qscale = ((double*)extra_args)[6];
749   const int64_t out_qzero = extra_args[7];
750   // NOLINTNEXTLINE
751   auto r = quantized_mul(tensors[1], tensors[2], out_qscale, out_qzero);
752   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
753 }
754 
755 void nnc_aten_quantized_mul_out(
756     int64_t bufs_in_num,
757     void** buf_data,
758     int64_t* buf_ranks,
759     int64_t* buf_dims,
760     int64_t* buf_strides,
761     int8_t* buf_dtypes,
762     int64_t,
763     int64_t* extra_args) {
764   const size_t bufs_out_num = 1u;
765   const double a_qscale = ((double*)extra_args)[0];
766   const int64_t a_qzero = extra_args[1];
767   const c10::ScalarType a_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
768   const double b_qscale = ((double*)extra_args)[3];
769   const int64_t b_qzero = extra_args[4];
770   const c10::ScalarType b_qdtype = static_cast<c10::ScalarType>(extra_args[5]);
771   auto tensors = constructTensors2(
772       bufs_in_num,
773       buf_data,
774       buf_ranks,
775       buf_dims,
776       buf_strides,
777       buf_dtypes,
778       {{1u, {a_qscale, a_qzero, toQIntType(a_qdtype)}},
779        {2u, {b_qscale, b_qzero, toQIntType(b_qdtype)}}},
780       1u);
781   const double out_qscale = ((double*)extra_args)[6];
782   const int64_t out_qzero = extra_args[7];
783   // NOLINTNEXTLINE
784   auto r = quantized_mul(tensors[1], tensors[2], out_qscale, out_qzero);
785   buf_data[0] = r.data_ptr();
786   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
787   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
788 }
789 
790 void nnc_aten_quantized_mul_scalar(
791     int64_t bufs_num,
792     void** buf_data,
793     int64_t* buf_ranks,
794     int64_t* buf_dims,
795     int64_t* buf_strides,
796     int8_t* buf_dtypes,
797     int64_t,
798     int64_t* extra_args) {
799   const double x_qscale = ((double*)extra_args)[0];
800   const int64_t x_qzero = extra_args[1];
801   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
802   auto tensors = constructTensors(
803       bufs_num,
804       buf_data,
805       buf_ranks,
806       buf_dims,
807       buf_strides,
808       buf_dtypes,
809       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
810   const double scalar = ((double*)extra_args)[3];
811   // NOLINTNEXTLINE
812   auto r = quantized_mul_scalar(tensors[1], scalar);
813   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
814 }
815 
816 void nnc_aten_quantized_mul_scalar_out(
817     int64_t bufs_in_num,
818     void** buf_data,
819     int64_t* buf_ranks,
820     int64_t* buf_dims,
821     int64_t* buf_strides,
822     int8_t* buf_dtypes,
823     int64_t,
824     int64_t* extra_args) {
825   const size_t bufs_out_num = 1u;
826   const double x_qscale = ((double*)extra_args)[0];
827   const int64_t x_qzero = extra_args[1];
828   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
829   auto tensors = constructTensors2(
830       bufs_in_num,
831       buf_data,
832       buf_ranks,
833       buf_dims,
834       buf_strides,
835       buf_dtypes,
836       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
837       bufs_out_num);
838   const double scalar = ((double*)extra_args)[3];
839   // NOLINTNEXTLINE
840   auto r = quantized_mul_scalar(tensors[1], scalar);
841   buf_data[0] = r.data_ptr();
842   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
843   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
844 }
845 
846 void nnc_aten_quantized_relu(
847     int64_t bufs_num,
848     void** buf_data,
849     int64_t* buf_ranks,
850     int64_t* buf_dims,
851     int64_t* buf_strides,
852     int8_t* buf_dtypes,
853     int64_t,
854     int64_t* extra_args) {
855   const double x_qscale = ((double*)extra_args)[0];
856   const int64_t x_qzero = extra_args[1];
857   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
858   auto tensors = constructTensors(
859       bufs_num,
860       buf_data,
861       buf_ranks,
862       buf_dims,
863       buf_strides,
864       buf_dtypes,
865       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
866   // NOLINTNEXTLINE
867   auto r = at::relu(tensors[1]);
868   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
869 }
870 
871 void nnc_aten_quantized_sigmoid(
872     int64_t bufs_num,
873     void** buf_data,
874     int64_t* buf_ranks,
875     int64_t* buf_dims,
876     int64_t* buf_strides,
877     int8_t* buf_dtypes,
878     int64_t,
879     int64_t* extra_args) {
880   const double x_qscale = ((double*)extra_args)[0];
881   const int64_t x_qzero = extra_args[1];
882   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
883   auto tensors = constructTensors(
884       bufs_num,
885       buf_data,
886       buf_ranks,
887       buf_dims,
888       buf_strides,
889       buf_dtypes,
890       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}});
891 
892   // NOLINTNEXTLINE
893   auto r = at::sigmoid(tensors[1]);
894   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
895 }
896 
897 void nnc_aten_quantized_sigmoid_out(
898     int64_t bufs_in_num,
899     void** buf_data,
900     int64_t* buf_ranks,
901     int64_t* buf_dims,
902     int64_t* buf_strides,
903     int8_t* buf_dtypes,
904     int64_t,
905     int64_t* extra_args) {
906   const double x_qscale = ((double*)extra_args)[0];
907   const int64_t x_qzero = extra_args[1];
908   const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
909   const size_t bufs_out_num = 1u;
910   auto tensors = constructTensors2(
911       bufs_in_num,
912       buf_data,
913       buf_ranks,
914       buf_dims,
915       buf_strides,
916       buf_dtypes,
917       {{1u, {x_qscale, x_qzero, toQIntType(x_qdtype)}}},
918       bufs_out_num);
919 
920   // NOLINTNEXTLINE
921   auto r = at::sigmoid(tensors[1]);
922   buf_data[0] = r.data_ptr();
923   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
924   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
925 }
926 
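// quantized::cat packs one (scale, zero point, dtype) triple per input into
// extra_args, followed by the concatenation dim and the output scale and
// zero point.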
927 void nnc_aten_quantized_cat(
928     int64_t bufs_num,
929     void** buf_data,
930     int64_t* buf_ranks,
931     int64_t* buf_dims,
932     int64_t* buf_strides,
933     int8_t* buf_dtypes,
934     int64_t,
935     int64_t* extra_args) {
936   std::vector<std::pair<size_t, QIData>> qdata;
937   const auto in_bufs_num = bufs_num - 1;
938   const double out_qscale = ((double*)extra_args)[3 * in_bufs_num + 1];
939   const int64_t out_qzero = extra_args[3 * in_bufs_num + 2];
940   qdata.emplace_back(
941       0u,
942       QIData{
943           out_qscale, out_qzero, static_cast<c10::ScalarType>(extra_args[2])});
944   for (const size_t i : c10::irange(in_bufs_num)) {
945     const double qscale = ((double*)extra_args)[3 * i + 0];
946     const int64_t qzero = extra_args[3 * i + 1];
947     const c10::ScalarType qdtype =
948         static_cast<c10::ScalarType>(extra_args[3 * i + 2]);
949     qdata.emplace_back(i + 1u, QIData{qscale, qzero, qdtype});
950   }
951   auto tensors = constructTensors(
952       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes, qdata);
953   const int64_t dim = extra_args[3 * in_bufs_num + 0];
954   auto qxs = c10::List<at::Tensor>(
955       std::vector<at::Tensor>(tensors.begin() + 1, tensors.end()));
956   auto r = quantized_cat(qxs, dim, out_qscale, out_qzero);
957   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
958 }
959 #endif // _WIN32
960 
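// upsample_nearest2d handles both quantized and regular inputs: a dtype of
// -1 in extra_args[2] means the input is not quantized, and output sizes or
// scale factors of -1 mean "not provided".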
961 void nnc_aten_upsample_nearest2d(
962     int64_t bufs_num,
963     void** buf_data,
964     int64_t* buf_ranks,
965     int64_t* buf_dims,
966     int64_t* buf_strides,
967     int8_t* buf_dtypes,
968     int64_t,
969     int64_t* extra_args) {
970   // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
971   const double x_qscale = ((double*)extra_args)[0];
972   const int64_t x_qzero = extra_args[1];
973   const int64_t x_qdtype = extra_args[2];
974   const auto is_quantized = x_qdtype != -1;
975   std::optional<std::vector<std::pair<size_t, QIData>>> qdata;
976   if (is_quantized) {
977     qdata = {
978         {1u,
979          {x_qscale,
980           x_qzero,
981           at::toQIntType(static_cast<c10::ScalarType>(x_qdtype))}}};
982   }
983   auto tensors = constructTensors(
984       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes, qdata);
985   auto x = tensors[1];
986 
987   int64_t output_size_h = extra_args[3];
988   int64_t output_size_w = extra_args[4];
989   double scale_factor_h = ((double*)extra_args)[5];
990   double scale_factor_w = ((double*)extra_args)[6];
991 
992   auto r = at::upsample_nearest2d(
993       x,
994       (output_size_h != -1)
995           ? std::optional<at::IntArrayRef>({output_size_h, output_size_w})
996           : std::nullopt,
997       (scale_factor_h != -1.f) ? std::optional<at::ArrayRef<double>>(
998                                      {scale_factor_h, scale_factor_w})
999                                : std::nullopt);
1000   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
1001 }
1002 
1003 void nnc_aten_upsample_nearest2d_out(
1004     int64_t bufs_in_num,
1005     void** buf_data,
1006     int64_t* buf_ranks,
1007     int64_t* buf_dims,
1008     int64_t* buf_strides,
1009     int8_t* buf_dtypes,
1010     int64_t,
1011     int64_t* extra_args) {
1012   const size_t bufs_out_num = 1u;
1013   // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
1014   const double x_qscale = ((double*)extra_args)[0];
1015   const int64_t x_qzero = extra_args[1];
1016   const int64_t x_qdtype = extra_args[2];
1017   const auto is_quantized = x_qdtype != -1;
1018   std::optional<std::vector<std::pair<size_t, QIData>>> qdata;
1019   if (is_quantized) {
1020     qdata = {
1021         {1u,
1022          {x_qscale,
1023           x_qzero,
1024           at::toQIntType(static_cast<c10::ScalarType>(x_qdtype))}}};
1025   }
1026   auto tensors = constructTensors2(
1027       bufs_in_num,
1028       buf_data,
1029       buf_ranks,
1030       buf_dims,
1031       buf_strides,
1032       buf_dtypes,
1033       qdata,
1034       bufs_out_num);
1035   auto x = tensors[1];
1036 
1037   int64_t output_size_h = extra_args[3];
1038   int64_t output_size_w = extra_args[4];
1039   double scale_factor_h = ((double*)extra_args)[5];
1040   double scale_factor_w = ((double*)extra_args)[6];
1041 
1042   auto r = at::upsample_nearest2d(
1043       x,
1044       (output_size_h != -1)
1045           ? std::optional<at::IntArrayRef>({output_size_h, output_size_w})
1046           : std::nullopt,
1047       (scale_factor_h != -1.f) ? std::optional<at::ArrayRef<double>>(
1048                                      {scale_factor_h, scale_factor_w})
1049                                : std::nullopt);
1050   buf_data[0] = r.data_ptr();
1051   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
1052   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
1053 }
1054 
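// quantize_per_tensor / dequantize entry points: extra_args carries the
// scale (stored as double), zero point and quantized dtype.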
1055 void nnc_aten_quantize_per_tensor(
1056     int64_t bufs_num,
1057     void** buf_data,
1058     int64_t* buf_ranks,
1059     int64_t* buf_dims,
1060     int64_t* buf_strides,
1061     int8_t* buf_dtypes,
1062     int64_t,
1063     int64_t* extra_args) {
1064   auto tensors = constructTensors(
1065       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1066   // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
1067   at::Tensor x = tensors[1];
1068   const double qscale = ((double*)extra_args)[0];
1069   const int64_t qzero = extra_args[1];
1070   const c10::ScalarType qdtype = static_cast<c10::ScalarType>(extra_args[2]);
1071   auto r = at::quantize_per_tensor(x, qscale, qzero, qdtype);
1072   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
1073 }
1074 
1075 void nnc_aten_quantize_per_tensor_out(
1076     int64_t bufs_in_num,
1077     void** buf_data,
1078     int64_t* buf_ranks,
1079     int64_t* buf_dims,
1080     int64_t* buf_strides,
1081     int8_t* buf_dtypes,
1082     int64_t,
1083     int64_t* extra_args) {
1084   const size_t bufs_out_num = 1u;
1085   auto tensors = constructTensors2(
1086       bufs_in_num,
1087       buf_data,
1088       buf_ranks,
1089       buf_dims,
1090       buf_strides,
1091       buf_dtypes,
1092       std::nullopt,
1093       bufs_out_num);
1094   // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
1095   at::Tensor x = tensors[1];
1096   const double qscale = ((double*)extra_args)[0];
1097   const int64_t qzero = extra_args[1];
1098   const c10::ScalarType qdtype = static_cast<c10::ScalarType>(extra_args[2]);
1099   auto r = at::quantize_per_tensor(x, qscale, qzero, qdtype);
1100   buf_data[0] = r.data_ptr();
1101   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
1102   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
1103 }
1104 
1105 void nnc_aten_dequantize(
1106     int64_t bufs_num,
1107     void** buf_data,
1108     int64_t* buf_ranks,
1109     int64_t* buf_dims,
1110     int64_t* buf_strides,
1111     int8_t* buf_dtypes,
1112     int64_t,
1113     int64_t* extra_args) {
1114   const double qscale = ((double*)extra_args)[0];
1115   const int64_t qzero = extra_args[1];
1116   const int64_t qdtype = extra_args[2];
1117   auto tensors = constructTensors(
1118       bufs_num,
1119       buf_data,
1120       buf_ranks,
1121       buf_dims,
1122       buf_strides,
1123       buf_dtypes,
1124       {{1u,
1125         {qscale, qzero, toQIntType(static_cast<c10::ScalarType>(qdtype))}}});
1126   // NOLINTNEXTLINE
1127   auto r = at::dequantize(tensors[1]);
1128   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
1129 }
1130 
1131 void nnc_aten_dequantize_out(
1132     int64_t bufs_in_num,
1133     void** buf_data,
1134     int64_t* buf_ranks,
1135     int64_t* buf_dims,
1136     int64_t* buf_strides,
1137     int8_t* buf_dtypes,
1138     int64_t,
1139     int64_t* extra_args) {
1140   const size_t bufs_out_num = 1u;
1141   const double qscale = ((double*)extra_args)[0];
1142   const int64_t qzero = extra_args[1];
1143   const int64_t qdtype = extra_args[2];
1144   auto tensors = constructTensors2(
1145       bufs_in_num,
1146       buf_data,
1147       buf_ranks,
1148       buf_dims,
1149       buf_strides,
1150       buf_dtypes,
1151       {{1u, {qscale, qzero, toQIntType(static_cast<c10::ScalarType>(qdtype))}}},
1152       bufs_out_num);
1153   // NOLINTNEXTLINE
1154   auto r = at::dequantize(tensors[1]);
1155   buf_data[0] = r.data_ptr();
1156   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
1157   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
1158 }
1159 
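// Non-quantized conv1d: when extra args are present they are
// [stride, padding, dilation, groups] and a bias tensor is expected as the
// fourth buffer.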
1160 void nnc_aten_conv1d(
1161     int64_t bufs_num,
1162     void** buf_data,
1163     int64_t* buf_ranks,
1164     int64_t* buf_dims,
1165     int64_t* buf_strides,
1166     int8_t* buf_dtypes,
1167     int64_t args_num,
1168     int64_t* extra_args) {
1169   auto tensors = constructTensors(
1170       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1171 
1172   at::Tensor& r = tensors[0];
1173   const at::Tensor& x = tensors[1];
1174   const at::Tensor& w = tensors[2];
1175   if (args_num > 0) {
1176     // Check that if the extra arguments are provided, then the bias tensor is
1177     // also present
1178     TORCH_INTERNAL_ASSERT(args_num == 4 && bufs_num == 4);
1179     const at::Tensor& b = tensors[3];
1180 
1181     int64_t stride = extra_args[0];
1182     int64_t padding = extra_args[1];
1183     int64_t dilation = extra_args[2];
1184     int64_t groups = extra_args[3];
1185 
1186     try {
1187       r = at::conv1d(x, w, b, {stride}, {padding}, {dilation}, groups);
1188     } catch (...) {
1189     }
1190   } else {
1191     try {
1192       r = at::conv1d(x, w);
1193     } catch (...) {
1194     }
1195   }
1196 
1197   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
1198 }
1199 
1200 void nnc_aten_conv1d_out(
1201     int64_t bufs_in_num,
1202     void** buf_data,
1203     int64_t* buf_ranks,
1204     int64_t* buf_dims,
1205     int64_t* buf_strides,
1206     int8_t* buf_dtypes,
1207     int64_t args_num,
1208     int64_t* extra_args) {
1209   const size_t bufs_out_num = 1u;
1210   auto tensors = constructTensors2(
1211       bufs_in_num,
1212       buf_data,
1213       buf_ranks,
1214       buf_dims,
1215       buf_strides,
1216       buf_dtypes,
1217       std::nullopt,
1218       bufs_out_num);
1219 
1220   at::Tensor r;
1221   const at::Tensor& x = tensors[1];
1222   const at::Tensor& w = tensors[2];
1223   if (args_num > 0) {
1224     // Check that if the extra arguments are provided, then the bias tensor is
1225     // also present
1226     TORCH_INTERNAL_ASSERT(args_num == 4 && bufs_in_num == 3);
1227     const at::Tensor& b = tensors[3];
1228 
1229     int64_t stride = extra_args[0];
1230     int64_t padding = extra_args[1];
1231     int64_t dilation = extra_args[2];
1232     int64_t groups = extra_args[3];
1233 
1234     try {
1235       r = at::conv1d(x, w, b, {stride}, {padding}, {dilation}, groups);
1236     } catch (...) {
1237     }
1238   } else {
1239     try {
1240       r = at::conv1d(x, w);
1241     } catch (...) {
1242     }
1243   }
1244 
1245   buf_data[0] = r.data_ptr();
1246   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
1247   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
1248 }
1249 
1250 void nnc_aten_adaptive_avg_pool2d(
1251     int64_t bufs_num,
1252     void** buf_data,
1253     int64_t* buf_ranks,
1254     int64_t* buf_dims,
1255     int64_t* buf_strides,
1256     int8_t* buf_dtypes,
1257     int64_t args_num,
1258     int64_t* extra_args) {
1259   auto tensors = constructTensors(
1260       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1261 
1262   at::Tensor& r = tensors[0];
1263   const at::Tensor& x = tensors[1];
1264   int64_t H = extra_args[0];
1265   int64_t W = H;
1266   if (args_num > 1) {
1267     W = extra_args[1];
1268   }
1269   try {
1270     r = at::adaptive_avg_pool2d(x, {H, W});
1271   } catch (...) {
1272   }
1273   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
1274 }
1275 
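// nnc_aten_mean writes straight into the output buffer via at::mean_out;
// ops without an out variant (e.g. at::max below) compute into a temporary
// and memcpy the result back.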
1276 void nnc_aten_mean(
1277     int64_t bufs_num,
1278     void** buf_data,
1279     int64_t* buf_ranks,
1280     int64_t* buf_dims,
1281     int64_t* buf_strides,
1282     int8_t* buf_dtypes,
1283     int64_t args_num,
1284     int64_t* extra_args) {
1285   auto tensors = constructTensors(
1286       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1287 
1288   at::Tensor& r = tensors[0];
1289   const at::Tensor& x = tensors[1];
1290   std::vector<int64_t> mean_dims(args_num - 1);
1291   bool keepdim = (bool)extra_args[args_num - 1];
1292   if (args_num > 1) {
1293     memcpy(mean_dims.data(), extra_args, sizeof(int64_t) * (args_num - 1));
1294   }
1295   try {
1296     at::mean_out(r, x, mean_dims, keepdim);
1297   } catch (...) {
1298   }
1299 }
1300 
1301 void nnc_aten_max_red(
1302     int64_t bufs_num,
1303     void** buf_data,
1304     int64_t* buf_ranks,
1305     int64_t* buf_dims,
1306     int64_t* buf_strides,
1307     int8_t* buf_dtypes,
1308     int64_t args_num,
1309     int64_t* extra_args) {
1310   auto tensors = constructTensors(
1311       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1312 
1313   at::Tensor& r = tensors[0];
1314   const at::Tensor& x = tensors[1];
1315   int64_t max_dim = extra_args[0];
1316   bool keep_dim = extra_args[1];
1317   try {
1318     r = std::get<0>(at::max(x, max_dim, keep_dim));
1319   } catch (...) {
1320   }
1321   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
1322 }
1323 
1324 void nnc_aten_max_red_out(
1325     int64_t bufs_in_num,
1326     void** buf_data,
1327     int64_t* buf_ranks,
1328     int64_t* buf_dims,
1329     int64_t* buf_strides,
1330     int8_t* buf_dtypes,
1331     int64_t,
1332     int64_t* extra_args) {
1333   size_t bufs_out_num = 1u;
1334   auto tensors = constructTensors2(
1335       bufs_in_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1336 
1337   at::Tensor r;
1338   // @lint-ignore CLANGTIDY
1339   const at::Tensor& x = tensors[1];
1340   int64_t max_dim = extra_args[0];
1341   bool keep_dim = extra_args[1];
1342   try {
1343     r = std::get<0>(at::max(x, max_dim, keep_dim));
1344   } catch (...) {
1345   }
1346   buf_data[0] = r.data_ptr();
1347   c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
1348   buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
1349 }
1350 
1351 void nnc_aten_addmm(
1352     int64_t bufs_num,
1353     void** buf_data,
1354     int64_t* buf_ranks,
1355     int64_t* buf_dims,
1356     int64_t* buf_strides,
1357     int8_t* buf_dtypes,
1358     int64_t args_num,
1359     int64_t* extra_args) {
1360   auto tensors = constructTensors(
1361       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1362 
1363   at::Tensor& r = tensors[0];
1364   const at::Tensor& x = tensors[1];
1365   const at::Tensor& y = tensors[2];
1366   const at::Tensor& z = tensors[3];
1367   // TODO: handle other alpha and beta dtypes, e.g. alpha=0.6, beta=0.2
1368   int64_t beta = extra_args[0], alpha = extra_args[1];
1369 
1370   try {
1371     at::addmm_out(r, x, y, z, beta, alpha);
1372   } catch (...) {
1373   }
1374 }
1375 
1376 // Only provides the first output; the second output is just a copy of one of
1377 // the inputs.
1378 void nnc_aten_triangular_solve(
1379     int64_t bufs_num,
1380     void** buf_data,
1381     int64_t* buf_ranks,
1382     int64_t* buf_dims,
1383     int64_t* buf_strides,
1384     int8_t* buf_dtypes,
1385     int64_t args_num,
1386     int64_t* extra_args) {
1387   auto tensors = constructTensors(
1388       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1389   at::Tensor& r = tensors[0];
1390   at::Tensor r2 = tensors[2].clone();
1391   const at::Tensor& input = tensors[1];
1392   const at::Tensor& A = tensors[2];
1393   try {
1394     at::triangular_solve_out(
1395         r, r2, input, A, extra_args[0], extra_args[2], extra_args[3]);
1396   } catch (...) {
1397   }
1398 }
1399 
1400 #if AT_MKLDNN_ENABLED()
1401 
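// Prepacked runners: buf_data[2] holds a pointer to a prepacked op context
// (e.g. mkldnn ConvOpContext or xnnpack LinearOpContext/Conv2dOpContext),
// which is why only bufs_num - 1 tensors are reconstructed here.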
1402 void nnc_mkldnn_prepacked_conv_run(
1403     int64_t bufs_num,
1404     void** buf_data,
1405     int64_t* buf_ranks,
1406     int64_t* buf_dims,
1407     int64_t* buf_strides,
1408     int8_t* buf_dtypes,
1409     int64_t args_num,
1410     int64_t* extra_args) {
1411   using namespace at::native::mkldnn;
1412 
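  // The last buffer holds the prepacked ConvOpContext rather than tensor data,
  // so only the first bufs_num - 1 buffers are materialized as tensors.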
1413   auto tensors = constructTensors(
1414       bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1415 
1416   const at::Tensor& x = tensors[1];
1417   auto context = reinterpret_cast<ConvOpContext*>(buf_data[2]);
1418 
1419   context->run(x, buf_data[0]);
1420 }
1421 
1422 #endif // AT_MKLDNN_ENABLED()
1423 
1424 #ifdef USE_XNNPACK
1425 
1426 void nnc_prepacked_linear_clamp_run(
1427     int64_t bufs_num,
1428     void** buf_data,
1429     int64_t* buf_ranks,
1430     int64_t* buf_dims,
1431     int64_t* buf_strides,
1432     int8_t* buf_dtypes,
1433     int64_t args_num,
1434     int64_t* extra_args) {
1435   using namespace at::native::xnnpack;
1436 
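  // As with the MKLDNN path, the last buffer carries the prepacked
  // LinearOpContext, so only bufs_num - 1 tensor buffers are constructed; the
  // result is copied into the preallocated output buffer below.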
1437   auto tensors = constructTensors(
1438       bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1439 
1440   const at::Tensor& x = tensors[1];
1441   auto context = reinterpret_cast<LinearOpContext*>(buf_data[2]);
1442   at::Tensor output = context->run(x);
1443   memcpy(
1444       buf_data[0],
1445       output.const_data_ptr(),
1446       output.element_size() * output.numel());
1447 }
1448 
1449 void nnc_prepacked_conv2d_clamp_run(
1450     int64_t bufs_num,
1451     void** buf_data,
1452     int64_t* buf_ranks,
1453     int64_t* buf_dims,
1454     int64_t* buf_strides,
1455     int8_t* buf_dtypes,
1456     int64_t args_num,
1457     int64_t* extra_args) {
1458   using namespace at::native::xnnpack;
1459 
1460   auto tensors = constructTensors(
1461       bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1462 
1463   const at::Tensor& x = tensors[1];
1464   auto context = reinterpret_cast<Conv2dOpContext*>(buf_data[2]);
1465   at::Tensor output = context->run(x);
1466   memcpy(
1467       buf_data[0],
1468       output.const_data_ptr(),
1469       output.element_size() * output.numel());
1470 }
1471 
1472 #endif // USE_XNNPACK
1473 
1474 void nnc_aten_embedding(
1475     int64_t bufs_num,
1476     void** buf_data,
1477     int64_t* buf_ranks,
1478     int64_t* buf_dims,
1479     int64_t* buf_strides,
1480     int8_t* buf_dtypes,
1481     int64_t args_num,
1482     int64_t* extra_args) {
1483   auto tensors = constructTensors(
1484       bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
1485 
1486   at::Tensor& r = tensors[0];
1487   const at::Tensor& weight = tensors[1];
1488   const at::Tensor& indices = tensors[2];
1489   try {
1490     r = at::embedding(weight, indices);
1491   } catch (...) {
1492   }
1493   // TODO: have to copy the output because at::embedding doesn't have an out
1494   // variant and NNC's external calls don't support allocations.
1495   memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel());
1496 }
1497 
1498 #ifndef C10_MOBILE
1499 
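// Each RegisterNNCExternalFunction below registers its kernel in NNC's global
// external-function registry under the given name, so generated code can
// dispatch to it by string lookup at runtime. A minimal sketch of adding a
// custom kernel (nnc_aten_my_op is hypothetical; the signature mirrors the
// non-out functions in this file):
//
//   void nnc_aten_my_op(
//       int64_t bufs_num,
//       void** buf_data,
//       int64_t* buf_ranks,
//       int64_t* buf_dims,
//       int64_t* buf_strides,
//       int8_t* buf_dtypes,
//       int64_t args_num,
//       int64_t* extra_args) {
//     auto tensors = constructTensors(
//         bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
//     // compute into tensors[0], or memcpy the result into buf_data[0]
//   }
//   const static RegisterNNCExternalFunction reg_my_op(
//       "nnc_aten_my_op",
//       nnc_aten_my_op);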
1500 const static RegisterNNCExternalFunction nnc_conv2d(
1501     "nnc_aten_conv2d",
1502     nnc_aten_conv2d);
1503 
1504 const static RegisterNNCExternalFunction nnc_quantized_conv1d(
1505     "nnc_aten_quantized_conv1d",
1506     nnc_aten_quantized_conv1d);
1507 const static RegisterNNCExternalFunction nnc_quantized_conv1d_out(
1508     "nnc_aten_quantized_conv1d_out",
1509     nnc_aten_quantized_conv1d_out);
1510 const static RegisterNNCExternalFunction nnc_quantized_conv2d(
1511     "nnc_aten_quantized_conv2d",
1512     nnc_aten_quantized_conv2d);
1513 const static RegisterNNCExternalFunction nnc_quantized_conv2d_out(
1514     "nnc_aten_quantized_conv2d_out",
1515     nnc_aten_quantized_conv2d_out);
1516 const static RegisterNNCExternalFunction nnc_quantized_conv2d_relu(
1517     "nnc_aten_quantized_conv2d_relu",
1518     nnc_aten_quantized_conv2d_relu);
1519 const static RegisterNNCExternalFunction nnc_quantized_conv2d_relu_out(
1520     "nnc_aten_quantized_conv2d_relu_out",
1521     nnc_aten_quantized_conv2d_relu_out);
1522 const static RegisterNNCExternalFunction nnc_quantized_linear(
1523     "nnc_aten_quantized_linear",
1524     nnc_aten_quantized_linear);
1525 const static RegisterNNCExternalFunction nnc_quantized_linear_out(
1526     "nnc_aten_quantized_linear_out",
1527     nnc_aten_quantized_linear_out);
1528 #ifndef _WIN32
1529 const static RegisterNNCExternalFunction nnc_quantized_add(
1530     "nnc_aten_quantized_add",
1531     nnc_aten_quantized_add);
1532 const static RegisterNNCExternalFunction nnc_quantized_mul(
1533     "nnc_aten_quantized_mul",
1534     nnc_aten_quantized_mul);
1535 const static RegisterNNCExternalFunction nnc_quantized_mul_out(
1536     "nnc_aten_quantized_mul_out",
1537     nnc_aten_quantized_mul_out);
1538 const static RegisterNNCExternalFunction nnc_quantized_mul_scalar(
1539     "nnc_aten_quantized_mul_scalar",
1540     nnc_aten_quantized_mul_scalar);
1541 const static RegisterNNCExternalFunction nnc_quantized_mul_scalar_out(
1542     "nnc_aten_quantized_mul_scalar_out",
1543     nnc_aten_quantized_mul_scalar_out);
1544 const static RegisterNNCExternalFunction nnc_quantized_sigmoid(
1545     "nnc_aten_quantized_sigmoid",
1546     nnc_aten_quantized_sigmoid);
1547 const static RegisterNNCExternalFunction nnc_quantized_sigmoid_out(
1548     "nnc_aten_quantized_sigmoid_out",
1549     nnc_aten_quantized_sigmoid_out);
1550 const static RegisterNNCExternalFunction nnc_quantized_cat(
1551     "nnc_aten_quantized_cat",
1552     nnc_aten_quantized_cat);
1553 const static RegisterNNCExternalFunction nnc_quantized_relu(
1554     "nnc_aten_quantized_relu",
1555     nnc_aten_quantized_relu);
1556 #endif // _WIN32
1557 const static RegisterNNCExternalFunction nnc_quantize_per_tensor(
1558     "nnc_aten_quantize_per_tensor",
1559     nnc_aten_quantize_per_tensor);
1560 const static RegisterNNCExternalFunction nnc_quantize_per_tensor_out(
1561     "nnc_aten_quantize_per_tensor_out",
1562     nnc_aten_quantize_per_tensor_out);
1563 const static RegisterNNCExternalFunction nnc_dequantize(
1564     "nnc_aten_dequantize",
1565     nnc_aten_dequantize);
1566 const static RegisterNNCExternalFunction nnc_dequantize_out(
1567     "nnc_aten_dequantize_out",
1568     nnc_aten_dequantize_out);
1569 
1570 const static RegisterNNCExternalFunction nnc_upsample_nearest2d(
1571     "nnc_aten_upsample_nearest2d",
1572     nnc_aten_upsample_nearest2d);
1573 const static RegisterNNCExternalFunction nnc_upsample_nearest2d_out(
1574     "nnc_aten_upsample_nearest2d_out",
1575     nnc_aten_upsample_nearest2d_out);
1576 const static RegisterNNCExternalFunction nnc_conv1d(
1577     "nnc_aten_conv1d",
1578     nnc_aten_conv1d);
1579 const static RegisterNNCExternalFunction nnc_conv1d_out(
1580     "nnc_aten_conv1d_out",
1581     nnc_aten_conv1d_out);
1582 const static RegisterNNCExternalFunction nnc_adaptive_avg_pool2d(
1583     "nnc_aten_adaptive_avg_pool2d",
1584     nnc_aten_adaptive_avg_pool2d);
1585 const static RegisterNNCExternalFunction nnc_mean(
1586     "nnc_aten_mean",
1587     nnc_aten_mean);
1588 const static RegisterNNCExternalFunction nnc_max_red(
1589     "nnc_aten_max_red",
1590     nnc_aten_max_red);
1591 const static RegisterNNCExternalFunction nnc_max_red_out(
1592     "nnc_aten_max_red_out",
1593     nnc_aten_max_red_out);
1594 const static RegisterNNCExternalFunction nnc_addmm(
1595     "nnc_aten_addmm",
1596     nnc_aten_addmm);
1597 
1598 const static RegisterNNCExternalFunction nnc_triangular_solve(
1599     "nnc_aten_triangular_solve",
1600     nnc_aten_triangular_solve);
1601 
1602 const static RegisterNNCExternalFunction nnc_embedding(
1603     "nnc_aten_embedding",
1604     nnc_aten_embedding);
1605 
1606 #if AT_MKLDNN_ENABLED()
1607 const static RegisterNNCExternalFunction reg_nnc_mkldnn_prepacked_conv_run(
1608     "nnc_mkldnn_prepacked_conv_run",
1609     nnc_mkldnn_prepacked_conv_run);
1610 #endif // AT_MKLDNN_ENABLED()
1611 
1612 #ifdef USE_XNNPACK
1613 const static RegisterNNCExternalFunction reg_nnc_prepacked_linear_clamp_run(
1614     "nnc_prepacked_linear_clamp_run",
1615     nnc_prepacked_linear_clamp_run);
1616 const static RegisterNNCExternalFunction reg_nnc_prepacked_conv2d_clamp_run(
1617     "nnc_prepacked_conv2d_clamp_run",
1618     nnc_prepacked_conv2d_clamp_run);
1619 #endif // USE_XNNPACK
1620 
1621 #endif // C10_MOBILE
1622 
1623 #ifdef C10_MOBILE
1624 } // extern "C"
1625 #endif
1626 
1627 } // namespace torch::jit::tensorexpr
1628