xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/static/ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/static/ops.h>
2 
3 #include <ATen/CPUFunctions.h>
4 #include <ATen/InferSize.h>
5 #include <ATen/NativeFunctions.h>
6 #include <ATen/Parallel.h>
7 #include <ATen/ScalarOps.h>
8 #include <ATen/TensorUtils.h>
9 #include <ATen/cpu/vec/functional.h>
10 #include <ATen/cpu/vec/vec.h>
11 #include <ATen/native/Fill.h>
12 #include <ATen/native/IndexingUtils.h>
13 #include <ATen/native/NonSymbolicBC.h>
14 #include <ATen/native/Resize.h>
15 #include <ATen/native/SharedReduceOps.h>
16 #include <ATen/native/TensorAdvancedIndexing.h>
17 #include <ATen/native/TensorConversions.h>
18 #include <ATen/native/cpu/SerialStackImpl.h>
19 #include <ATen/native/layer_norm.h>
20 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
21 #include <ATen/native/quantized/cpu/qembeddingbag.h>
22 #include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
23 #include <ATen/quantized/QTensorImpl.h>
24 #include <ATen/quantized/Quantizer.h>
25 #include <c10/core/ScalarType.h>
26 #include <c10/core/WrapDimMinimal.h>
27 #include <c10/util/irange.h>
28 #include <torch/csrc/jit/ir/constants.h>
29 #include <torch/csrc/jit/ir/ir.h>
30 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
31 #include <torch/csrc/jit/runtime/static/impl.h>
32 #include <torch/csrc/jit/runtime/static/processed_node_wrapper.h>
33 #include <torch/csrc/jit/runtime/static/te_wrapper.h>
34 #include <torch/csrc/jit/runtime/vararg_functions.h>
35 #include <torch/csrc/jit/tensorexpr/ir.h>
36 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
37 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
38 #include <torch/csrc/jit/tensorexpr/loopnest.h>
39 #include <iterator>
40 #include <mutex>
41 #include <unordered_map>
42 
43 #include <ATen/CompositeExplicitAutogradFunctions.h>
44 
45 C10_DEFINE_bool(
46     static_runtime_enable_fast_math,
47     true,
48     "If on, static runtime may use use optimizations that cause accuracy loss "
49     "vs the jit interpreter");
50 
51 namespace at::native {
52 
53 static void repeat_out(
54     at::Tensor& result,
55     const Tensor& self,
56     IntArrayRef repeats) {
57   TORCH_CHECK(
58       repeats.size() >= static_cast<size_t>(self.dim()),
59       "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
60 
61   // Add new leading dimensions to the tensor if the
62   // number of target dimensions is larger than the
63   // number of source dimensions.
64   int64_t num_new_dimensions = repeats.size() - self.dim();
65   DimVector padded_size(num_new_dimensions, 1);
66   padded_size.insert(
67       padded_size.end(), self.sizes().begin(), self.sizes().end());
68   DimVector target_size(repeats.size());
69   bool zero_tensor = false;
70   for (const auto idx : c10::irange(repeats.size())) {
71     if (repeats[idx] == 0) {
72       zero_tensor = true;
73     }
74     target_size[idx] = padded_size[idx] * repeats[idx];
75   }
76 
77   // return an empty tensor if one of the repeat dimensions is zero
78   at::native::resize_(result, target_size, std::nullopt);
79   if (zero_tensor) {
80     return;
81   }
82 
83   Tensor xtensor = at::compositeexplicitautograd::expand(self, padded_size);
84   Tensor urtensor = at::native::alias(result);
85   for (const auto i : c10::irange(xtensor.dim())) {
86     // can't unfold with step 0, so make sure step is at least 1
87     // (it doesn't matter what it is in that case, because the size is 0).
88     urtensor = urtensor.unfold(
89         i, xtensor.size(i), std::max<int64_t>(xtensor.size(i), 1));
90   }
91 
92   at::native::copy_(urtensor, xtensor.expand_as(urtensor));
93 }
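// Usage sketch (illustrative only; repeat_out is file-local): for a 2x3 input
// and repeats {4, 2, 3}, padded_size is {1, 2, 3}, so the result gets sizes
// {4*1, 2*2, 3*3} = {4, 4, 9}, matching at::Tensor::repeat():
//
//   at::Tensor src = at::arange(6, at::kFloat).reshape({2, 3});
//   at::Tensor dst = at::empty({0}, src.options());
//   repeat_out(dst, src, {4, 2, 3});   // dst.sizes() == [4, 4, 9]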
94 
95 // copy versions of view ops
96 at::Tensor& reshape_copy_out(
97     at::Tensor& out,
98     const at::Tensor& self,
99     const at::DimVector& proposed_shape,
100     bool infer_size) {
101   const auto& shape = infer_size
102       ? at::infer_size_dv(proposed_shape, self.numel())
103       : proposed_shape;
104   at::native::resize_(out, shape, std::nullopt);
105 
106   auto self_contig = self.expect_contiguous();
107 
108   size_t nbytes = self.nbytes();
109   if (nbytes == 0) {
110     return out;
111   }
112 
113   const void* self_data = self_contig->const_data_ptr();
114   void* out_data = out.mutable_data_ptr();
115   memcpy(out_data, self_data, nbytes);
116 
117   return out;
118 }
119 
120 static at::Tensor& flatten_copy_out(
121     at::Tensor& out,
122     const at::Tensor& self,
123     int64_t start_dim,
124     int64_t end_dim) {
125   start_dim =
126       start_dim < 0 ? c10::maybe_wrap_dim(start_dim, self.dim()) : start_dim;
127   end_dim = end_dim < 0 ? c10::maybe_wrap_dim(end_dim, self.dim()) : end_dim;
128   TORCH_CHECK(
129       start_dim <= end_dim,
130       "flatten() has invalid args: start_dim cannot come after end_dim");
131 
132   if (self.dim() == 0) {
133     return reshape_copy_out(out, self, at::DimVector{1}, false);
134   }
135 
136   if (start_dim == end_dim) {
137     auto shape = at::DimVector{self.sizes()};
138     return reshape_copy_out(out, self, shape, false);
139   }
140 
141   // We don't want to infer_size on the entire shape, because that can give us
142   // an extra degree of freedom we don't want; for example, consider shape [0,
143   // 1, 3, 0], with start_dim=1, end_dim=2. It's clear we want result shape [0,
144   // 3, 0] but passing [0, -1, 0] to infer_size means the -1 can take on any
145   // value and satisfy the constraints.
146   auto iter = self.sizes().data();
147   auto slice_numel = std::accumulate(
148       iter + start_dim,
149       iter + end_dim + 1,
150       static_cast<int64_t>(1),
151       // NOLINTNEXTLINE(modernize-use-transparent-functors)
152       std::multiplies<int64_t>());
153 
154   at::DimVector shape;
155   shape.reserve(self.dim() - end_dim + start_dim);
156   for (const auto i : c10::irange(start_dim)) {
157     shape.push_back(self.sizes()[i]);
158   }
159   shape.push_back(slice_numel);
160   for (int64_t i = end_dim + 1; i < self.dim(); i++) {
161     shape.push_back(self.sizes()[i]);
162   }
163   return reshape_copy_out(out, self, shape, false);
164 }
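// Worked example (illustrative only; flatten_copy_out is file-local): for an
// input of shape [2, 3, 4] with start_dim=1 and end_dim=2, slice_numel is
// 3 * 4 = 12, so the copied result has shape [2, 12]. For the edge case in the
// comment above, shape [0, 1, 3, 0] with start_dim=1, end_dim=2 yields
// [0, 3, 0] without ever calling infer_size.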
165 
166 namespace {
167 
168 // This is annoying and silly, but it's solving a real problem: the
169 // _MSC_VER version causes an ICE on our old clang5 builds. The
170 // non-_MSC_VER version is a syntax error according to MSVC. Use the
171 // appropriate version depending on whether we're building with MSVC.
172 
173 #define TO_COPY_OUT_FAST_PATH_LOGIC(out, self, self_t)             \
174   do {                                                             \
175     const auto N = self.numel();                                   \
176     const auto self_data = self.const_data_ptr<self_t>();          \
177     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(                        \
178         kHalf,                                                     \
179         kBFloat16,                                                 \
180         kBool,                                                     \
181         out.scalar_type(),                                         \
182         "to_copy_out_inner_loop",                                  \
183         [&]() {                                                    \
184           const auto out_data = out.mutable_data_ptr<scalar_t>();  \
185           for (const auto idx : c10::irange(N)) {                  \
186             /* NOLINTNEXTLINE(bugprone-signed-char-misuse) */      \
187             out_data[idx] = static_cast<scalar_t>(self_data[idx]); \
188           }                                                        \
189         });                                                        \
190   } while (0)
191 
192 #ifdef _MSC_VER
193 template <typename T>
194 void to_copy_out_fast_path(Tensor& out, const Tensor& self) {
195   TO_COPY_OUT_FAST_PATH_LOGIC(out, self, T);
196 }
197 
198 #define TO_COPY_OUT_FAST_PATH_BODY(out, self) \
199   to_copy_out_fast_path<scalar_t>(out, self)
200 #else
201 #define TO_COPY_OUT_FAST_PATH_BODY(out, self) \
202   using self_t = scalar_t;                    \
203   TO_COPY_OUT_FAST_PATH_LOGIC(out, self, self_t)
204 #endif
205 } // namespace
206 
207 at::Tensor& to_copy_out(
208     Tensor& out,
209     const Tensor& self,
210     bool non_blocking,
211     bool copy_strides,
212     std::optional<MemoryFormat> memory_format) {
213   if (copy_strides) {
214     at::native::resize_impl_cpu_(
215         out.unsafeGetTensorImpl(), self.sizes(), self.strides());
216   } else {
217     at::native::resize_(out, self.sizes(), std::nullopt);
218   }
219   auto is_unsupported_dtype = [](ScalarType t) {
220 #define TORCH_OPS_UNSUPPORTED_TYPE(_, type) \
221   case k##type:                             \
222     return true;
223     switch (t) {
224       AT_FORALL_QINT_TYPES(TORCH_OPS_UNSUPPORTED_TYPE)
225       AT_FORALL_COMPLEX_TYPES(TORCH_OPS_UNSUPPORTED_TYPE)
226       default:
227         return false;
228     }
229 #undef TORCH_OPS_UNSUPPORTED_TYPE
230   };
231   // Fast path: can we just copy the data ourselves? Avoids creating a
232   // TensorIterator in at::native::copy_, which is relatively
233   // expensive.
234   if (self.is_contiguous() && !non_blocking &&
235       // Did the user request us to make a copy that isn't contiguous?
236       (memory_format == std::nullopt ||
237        memory_format == c10::MemoryFormat::Preserve ||
238        memory_format == c10::MemoryFormat::Contiguous) &&
239       // CopyKernel.cpp handles this case specially, so let's not mess
240       // with it.
241       !self.is_neg() && !is_unsupported_dtype(self.dtype().toScalarType()) &&
242       !is_unsupported_dtype(out.dtype().toScalarType()) &&
243       !(
244           // FBGEMM optimization might kick in, don't interfere with
245           // that.
246           (self.dtype() == kFloat && out.dtype() == kHalf) ||
247           (self.dtype() == kHalf && out.dtype() == kFloat))) {
248     AT_DISPATCH_ALL_TYPES_AND3(
249         kHalf, kBFloat16, kBool, self.scalar_type(), "to_copy_out", [&]() {
250           TO_COPY_OUT_FAST_PATH_BODY(out, self);
251         });
252     return out;
253   }
254   at::native::copy_(out, self, non_blocking);
255   return out;
256 }
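// Usage sketch (illustrative only): a contiguous float -> double copy with
// non_blocking=false takes the fast path above, which is a plain element-wise
// static_cast loop; a float -> half copy instead falls through to
// at::native::copy_ so the FBGEMM-accelerated conversion can kick in:
//
//   at::Tensor src = at::rand({4, 4});   // kFloat, contiguous
//   at::Tensor dst = at::empty({0}, src.options().dtype(at::kDouble));
//   at::native::to_copy_out(dst, src, /*non_blocking=*/false,
//                           /*copy_strides=*/false, std::nullopt);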
257 
258 static Tensor& linear_out(
259     Tensor& output,
260     const Tensor& input,
261     const Tensor& weight,
262     const std::optional<Tensor>& bias_opt) {
263   TORCH_CHECK(!input.is_mkldnn());
264 
265   auto bias = bias_opt.has_value()
266       ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
267       : c10::MaybeOwned<Tensor>::owned(std::in_place);
268 
269   if (input.dim() == 2 && bias->defined()) {
270     // Fused op is marginally faster.
271     return at::cpu::addmm_out(output, *bias, input, weight.t());
272   }
273   at::native::matmul_out(input, weight.t(), output);
274   if (bias->defined()) {
275     at::cpu::add_(output, *bias);
276   }
277   return output;
278 }
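// Dispatch note (illustrative only; linear_out is file-local): for a 2-D input
// with a defined bias, e.g. input [B, in_features], weight
// [out_features, in_features], bias [out_features], the call lowers to a single
// fused at::cpu::addmm_out(output, bias, input, weight.t()); otherwise it is
// matmul_out followed by an in-place add_ of the bias.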
279 
280 static Tensor& c2_argmin_out(
281     Tensor& output,
282     const Tensor& input,
283     const int64_t dim,
284     const bool keepdim) {
285   const auto ndim = input.dim();
286   int64_t dim_ = maybe_wrap_dim(dim, ndim);
287   TORCH_CHECK(dim_ >= 0 && dim_ < ndim);
288 
289   const auto in_dims = input.sizes();
290 
291   c10::SmallVector<int64_t, 5> out_dims;
292   out_dims.reserve(ndim);
293   int prev_size = 1;
294   int next_size = 1;
295   for (int i = 0; i < dim_; ++i) {
296     out_dims.push_back(in_dims[i]);
297     prev_size *= in_dims[i];
298   }
299   if (keepdim) {
300     out_dims.push_back(1);
301   }
302   for (auto i = dim_ + 1; i < ndim; ++i) {
303     out_dims.push_back(in_dims[i]);
304     next_size *= in_dims[i];
305   }
306   at::native::resize_(output, out_dims, std::nullopt);
307 
308   const auto n = in_dims[dim_];
309 
310   if (next_size == 1) {
311     AT_DISPATCH_ALL_TYPES_AND2(
312         kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
313           const auto in_ptr = input.const_data_ptr<scalar_t>();
314           const auto out_ptr = output.mutable_data_ptr<int64_t>();
315           // input is a [prev_size, n] tensor.
316           // output is a [prev_size,] tensor.
317           // Thus, access is contiguous/coalesced.
318           for (int i = 0; i < prev_size; ++i) {
319             auto v = std::min_element(
320                 in_ptr + i * n,
321                 in_ptr + (i + 1) * n,
322                 [](scalar_t a, scalar_t b) {
323                   // if a is nan, then a is *less* than b with LessOrNan
324                   // semantics
325                   if (at::_isnan(a)) {
326                     return true;
327                   }
328                   // if a is not nan and b is nan, then a is not less than b
329                   // with LessOrNan semantics; otherwise, act normally. If `b` is
330                   // NaN then a < b will always return false, so this is
331                   // equivalent to the first snippet.
332                   return a < b;
333                 });
334             out_ptr[i] = std::distance(in_ptr + i * n, v);
335           }
336         });
337   } else {
338     AT_DISPATCH_ALL_TYPES_AND2(
339         kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
340           const auto less_or_nan = native::detail::LessOrNan<scalar_t>{};
341 
342           const auto in_ptr = input.const_data_ptr<scalar_t>();
343           const auto out_ptr = output.mutable_data_ptr<int64_t>();
344 
345           std::memset(out_ptr, 0, prev_size * next_size * sizeof(int64_t));
346 
347           for (int i = 0; i < prev_size; ++i) {
348             const scalar_t* cur_in_ptr = in_ptr + i * n * next_size + next_size;
349             for (int k = 1; k < n; ++k) {
350               for (int j = 0; j < next_size; ++j) {
351                 int64_t* cur_out_ptr = out_ptr + i * next_size + j;
352                 if (less_or_nan(
353                         *cur_in_ptr,
354                         in_ptr
355                             [i * n * next_size + *cur_out_ptr * next_size + j],
356                         *cur_out_ptr,
357                         k)) {
358                   *cur_out_ptr = k;
359                 }
360                 ++cur_in_ptr;
361               }
362             }
363           }
364         });
365   }
366   return output;
367 }
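// Semantics note (illustrative only; c2_argmin_out is file-local): with the
// caffe2-style LessOrNan comparator used above, NaN compares as smaller than
// everything, so for a row {2.0, NaN, 1.0} the reported argmin index is 1 (the
// NaN), not 2 -- this intentionally differs from a naive std::min_element over
// operator< alone.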
368 
369 static at::Tensor& dequantize_copy_out(Tensor& out, const Tensor& self) {
370   if (C10_UNLIKELY(!self.is_quantized())) {
371     // fallback to dequantize_cpu equivalent case: make sure out is at::kFloat
372     DCHECK(out.scalar_type() == kFloat);
373     return at::native::to_copy_out(out, self, false, false, std::nullopt);
374   }
375   return get_qtensorimpl(self)->quantizer()->dequantize_out(out, self);
376 }
377 } // namespace at::native
378 
379 namespace torch::jit {
380 
381 C10_DEFINE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
382 
383 bool opIsRegistered(const c10::Symbol& op_name) {
384   const std::string name(op_name.toQualString());
385   return SROperatorRegistry()->Has(name);
386 }
387 
388 static bool disableUnsafeMathOp(const char* op_name) {
389   if (FLAGS_static_runtime_enable_fast_math) {
390     return false;
391   }
392   // This list contains ops that use the caffe2 math library or NNC, which do
393   // not guarantee bit exactness vs the jit interpreter. Note that aten::relu is
394   // not included even though it uses NNC, because the results of relu should
395   // always match.
396   static const c10::FastSet<std::string> fast_ops{
397       "aten::add", "aten::tanh", "aten::sigmoid", "aten::logit"};
398   return fast_ops.count(op_name) > 0;
399 }
400 
401 SROperator getOutOfPlaceOperation(Node* n) {
402   auto op_name = n->kind().toQualString();
403   if (SROperatorRegistry()->Has(op_name) && !disableUnsafeMathOp(op_name)) {
404     return SROperatorRegistry()->Create(op_name)->Generate(n);
405   }
406 
407   return nullptr;
408 }
409 
410 // Returns true if the node represents an op with variadic arguments.
411 bool hasVarArgs(Node* n) {
412   if (n->kind() == prim::VarConcat || n->kind() == prim::VarStack) {
413     return true;
414   }
415   return false;
416 }
417 
418 bool canReuseInputsOutputs(
419     Node* n,
420     const c10::FastMap<Node*, bool>& node_has_out_variant) {
421   auto it = node_has_out_variant.find(n);
422   if (it != node_has_out_variant.end()) {
423     return it->second;
424   }
425   return getOutOfPlaceOperation(n) != nullptr;
426 }
427 
428 // Returns true if the producers of the inputs
429 // to this operation are out of place.
430 // This means the IValues will not change from run to run.
431 static bool inputsCanRunOutOfPlace(
432     Node* n,
433     const c10::FastMap<Node*, bool>& node_has_out_variant) {
434   for (auto* input : n->inputs()) {
435     if (!canReuseInputsOutputs(input->node(), node_has_out_variant)) {
436       return false;
437     }
438   }
439   return true;
440 }
441 
442 bool isOptimizableContainerType(
443     Node* n,
444     const c10::FastMap<Node*, bool>& node_has_out_variant) {
445   const auto& type = n->output()->type();
446   bool is_supported_type = false;
447   if (type->kind() == TypeKind::ListType) {
448     const auto& list_type = type->expectRef<ListType>();
449     is_supported_type =
450         list_type.getElementType()->kind() == TypeKind::TensorType;
451   } else if (type->kind() == TypeKind::TupleType) {
452     const auto& tuple_type = type->expectRef<TupleType>();
453     auto types = tuple_type.containedTypes();
454     const auto& iter =
455         std::find_if(types.begin(), types.end(), [](const TypePtr& elem) {
456           return elem->kind() == TypeKind::TensorType;
457         });
458     is_supported_type = iter != types.end();
459   }
460   return is_supported_type && inputsCanRunOutOfPlace(n, node_has_out_variant);
461 }
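// Example (illustrative only): a prim::ListConstruct node whose output type is
// List[Tensor] and whose inputs are all produced by out-variant ops counts as
// optimizable; the ListConstruct functor below can then skip rebuilding the
// list on later iterations because the element IValues are reused in place.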
462 
463 static inline void listConstructSlowPath(
464     const ListType& list_type,
465     const size_t size,
466     ProcessedNode* p_node) {
467   c10::List<IValue> vals(list_type.getElementType());
468   vals.reserve(size);
469   for (const auto i : c10::irange(size)) {
470     vals.push_back(p_node->Input(i));
471   }
472   p_node->Output(0) = vals;
473 }
474 
475 bool sr_schema_check_kind(torch::jit::Node* node, c10::Symbol node_kind) {
476   auto is_match = node->kind() == node_kind;
477   if (!is_match) {
478     torch::jit::LogAndDumpSchema(node);
479   }
480   return is_match;
481 }
482 
483 REGISTER_OPERATOR_FUNCTOR(
484     prim::ListConstruct,
485     prim_ListConstruct,
486     [](Node* n) -> SROperator {
487       if (!sr_schema_check_kind(n, prim::ListConstruct)) {
488         return nullptr;
489       }
490       const bool can_optimize =
491           isOptimizableContainerType(n, c10::FastMap<Node*, bool>());
492       const auto& type = n->output()->type()->expectRef<ListType>();
493       const size_t size = n->inputs().size();
494       if (!can_optimize) {
495         return [&type, size](ProcessedNode* p_node) {
496           DCHECK(p_node->num_inputs() == size);
497           listConstructSlowPath(type, size, p_node);
498         };
499       }
500       return [&type, size](ProcessedNode* p_node) {
501         DCHECK(p_node->num_inputs() == size);
502         const auto& out_l = p_node->Output(0);
503         if (!out_l.isNone()) {
504           return;
505         }
506         listConstructSlowPath(type, size, p_node);
507       };
508     });
509 
510 static inline void tupleConstructSlowPath(
511     const size_t size,
512     ProcessedNode* p_node) {
513   // prepare inputs
514   switch (size) {
515     case 1:
516       p_node->Output(0) = c10::ivalue::Tuple::create(p_node->Input(0));
517       break;
518     case 2:
519       p_node->Output(0) =
520           c10::ivalue::Tuple::create(p_node->Input(0), p_node->Input(1));
521       break;
522     case 3:
523       p_node->Output(0) = c10::ivalue::Tuple::create(
524           p_node->Input(0), p_node->Input(1), p_node->Input(2));
525       break;
526     default: {
527       std::vector<IValue> vals;
528       vals.reserve(size);
529       for (const auto i : c10::irange(size)) {
530         vals.push_back(p_node->Input(i));
531       }
532       p_node->Output(0) = c10::ivalue::Tuple::create(std::move(vals));
533       break;
534     }
535   }
536 }
537 
538 REGISTER_OPERATOR_FUNCTOR(
539     prim::TupleConstruct,
540     prim_TupleConstruct,
541     [](Node* n) -> SROperator {
542       if (!sr_schema_check_kind(n, prim::TupleConstruct)) {
543         return nullptr;
544       }
545       const bool can_optimize =
546           isOptimizableContainerType(n, c10::FastMap<Node*, bool>());
547       const size_t size = n->inputs().size();
548       if (!can_optimize) {
549         return [size](ProcessedNode* p_node) {
550           DCHECK(p_node->num_inputs() == size);
551           tupleConstructSlowPath(size, p_node);
552         };
553       }
554       return [size](ProcessedNode* p_node) {
555         DCHECK(p_node->num_inputs() == size);
556         const auto& out_l = p_node->Output(0);
557         if (!out_l.isNone()) {
558           return;
559         }
560         tupleConstructSlowPath(size, p_node);
561       };
562     });
563 
564 REGISTER_OPERATOR_FUNCTOR(aten::abs, aten_abs, [](Node* n) -> SROperator {
565   if (!n->matches(torch::schema("aten::abs(Tensor self) -> Tensor"))) {
566     LogAndDumpSchema(n);
567     return nullptr;
568   }
569   return [](ProcessedNode* p_node) {
570     const auto& in0_t = p_node->Input(0).toTensor();
571     if (p_node->Output(0).isNone()) {
572       p_node->Output(0) = at::native::abs(in0_t);
573       return;
574     }
575     auto& out_t = p_node->Output(0).toTensor();
576     fastResizeToZero(out_t);
577     at::native::abs_out(in0_t, out_t);
578   };
579 });
580 
581 REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator {
582   if (!n->matches(torch::schema(
583           "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor"))) {
584     LogAndDumpSchema(n);
585     return nullptr;
586   }
587 
588   return [](ProcessedNode* p_node) {
589     const auto& in0_t = p_node->Input(0).toTensor();
590     const auto& in1_t = p_node->Input(1).toTensor();
591     if (p_node->Output(0).isNone()) {
592       p_node->Output(0) = at::cpu::mul(in0_t, in1_t);
593       return;
594     }
595     auto& out_t = p_node->Output(0).toTensor();
596     fastResizeToZero(out_t);
597     at::cpu::mul_out(out_t, in0_t, in1_t);
598   };
599 });
600 
601 REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
602   if (!n->matches(torch::schema(
603           "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"))) {
604     LogAndDumpSchema(n);
605     return nullptr;
606   }
607   return [](ProcessedNode* p_node) {
608     const auto& in0_t = p_node->Input(0).toTensor();
609     const auto& in1_t = p_node->Input(1).toTensor();
610     const auto& in2_t = p_node->Input(2).toTensor();
611     const auto in3_s = p_node->Input(3).toScalar();
612     const auto in4_s = p_node->Input(4).toScalar();
613     if (p_node->Output(0).isNone()) {
614       p_node->Output(0) = at::cpu::addmm(in0_t, in1_t, in2_t, in3_s, in4_s);
615       return;
616     }
617     auto& out_t = p_node->Output(0).toTensor();
618     fastResizeToZero(out_t);
619     at::cpu::addmm_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s);
620   };
621 });
622 
623 #ifdef FBCODE_CAFFE2
624 // Disable externally to avoid MSVC errors in open-source CI
625 
626 REGISTER_OPERATOR_FUNCTOR(
627     static_runtime::clamp_nan_to_num,
628     static_runtime_clamp_nan_to_num,
629     [](Node* n) -> SROperator {
630       if (!sr_schema_check(
631               n,
632               "static_runtime::clamp_nan_to_num(Tensor input, Scalar? min, Scalar? max, float? nan, float? posinf, float? posinf) -> Tensor")) {
633         return nullptr;
634       }
635       auto clamp_min_ival_opt = toIValue(n->input(1));
636       auto clamp_max_ival_opt = toIValue(n->input(2));
637       TORCH_CHECK(
638           clamp_min_ival_opt.has_value() && clamp_max_ival_opt.has_value());
639 
640       auto clamp_min_opt = clamp_min_ival_opt->toOptional<at::Scalar>();
641       auto clamp_max_opt = clamp_max_ival_opt->toOptional<at::Scalar>();
642       TORCH_CHECK(clamp_min_opt.has_value() && clamp_max_opt.has_value());
643 
644       return [te = createClampNanToNum(),
645               clamp_min = clamp_min_opt->to<float>(),
646               clamp_max =
647                   clamp_max_opt->to<float>()](ProcessedNode* p_node) mutable {
648         const auto& in0_t = p_node->Input(0).toTensor();
649         if (p_node->Output(0).isNone()) {
650           p_node->Output(0) = create_empty_from(in0_t);
651         }
652         auto& out_t = p_node->Output(0).toTensor();
653         fastResizeToZero(out_t);
654         auto in3_s = p_node->Input(3).toOptional<double>();
655 
656         if (!te || !te->checkInput<float>(in0_t)) {
657           at::cpu::nan_to_num_out(
658               out_t,
659               at::cpu::clamp(in0_t, clamp_min, clamp_max),
660               in3_s,
661               std::nullopt,
662               std::nullopt);
663           return;
664         }
665         at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
666 
667         auto output_size = in0_t.numel();
668 
669         // This might be UB if in3_s is absurdly large, but most implementations
670         // just turn it into `inf` in that case. The PyTorch core nan_to_num
671         // kernel just static_cast()s the limits to the destination type, so
672         // we'll ignore overflow issues here as well.
673         auto nan = in3_s.has_value() ? static_cast<float>(*in3_s) : 0.f;
674 
675         te->call(
676             {out_t.data_ptr(),
677              in0_t.data_ptr(),
678              &clamp_min,
679              &clamp_max,
680              &nan,
681              &output_size});
682       };
683     });
684 
685 #endif
686 
687 REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
688   if (n->matches(torch::schema(
689           "aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"))) {
690     return [te = createClamp()](ProcessedNode* p_node) {
691       const auto& in0_t = p_node->Input(0).toTensor();
692       if (p_node->Output(0).isNone()) {
693         p_node->Output(0) = create_empty_from(in0_t);
694       }
695       auto& out_t = p_node->Output(0).toTensor();
696       fastResizeToZero(out_t);
697       auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
698       auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
699       if (!te->checkInput<float>(in0_t)) {
700         at::cpu::clamp_out(out_t, in0_t, in1_s, in2_s);
701         return;
702       }
703       at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
704       auto output_size = in0_t.numel();
705       auto min = in1_s.has_value() ? in1_s->toFloat()
706                                    : -std::numeric_limits<float>::infinity();
707       auto max = in2_s.has_value() ? in2_s->toFloat()
708                                    : std::numeric_limits<float>::infinity();
709       te->call({out_t.data_ptr(), in0_t.data_ptr(), &min, &max, &output_size});
710     };
711   }
712   if (n->matches(
713           "aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor")) {
714     return [](ProcessedNode* p_node) {
715       const auto& in0_t = p_node->Input(0).toTensor();
716       if (p_node->Output(0).isNone()) {
717         p_node->Output(0) = create_empty_from(in0_t);
718       }
719       auto& out_t = p_node->Output(0).toTensor();
720       fastResizeToZero(out_t);
721       auto in1_t = p_node->Input(1).toOptional<at::Tensor>();
722       auto in2_t = p_node->Input(2).toOptional<at::Tensor>();
723       at::cpu::clamp_out(out_t, in0_t, in1_t, in2_t);
724     };
725   }
726   LogAndDumpSchema(n);
727   return nullptr;
728 });
729 
730 REGISTER_OPERATOR_FUNCTOR(aten::bmm, aten_bmm, [](Node* n) -> SROperator {
731   if (!n->matches(
732           torch::schema("aten::bmm(Tensor self, Tensor mat2) -> Tensor"))) {
733     LogAndDumpSchema(n);
734     return nullptr;
735   }
736   return [](ProcessedNode* p_node) {
737     const auto& in0_t = p_node->Input(0).toTensor();
738     const auto& in1_t = p_node->Input(1).toTensor();
739     if (p_node->Output(0).isNone()) {
740       p_node->Output(0) = create_empty_from(in0_t);
741     }
742     auto& out_t = p_node->Output(0).toTensor();
743     fastResizeToZero(out_t);
744     at::cpu::bmm_out(out_t, in0_t, in1_t);
745   };
746 });
747 
748 REGISTER_OPERATOR_FUNCTOR(aten::nan_to_num, aten_nan_to_num, [](Node* n) -> SROperator {
749   if (!n->matches(torch::schema(
750           "aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor"))) {
751     LogAndDumpSchema(n);
752     return nullptr;
753   }
754   return [](ProcessedNode* p_node) {
755     const auto& in0_t = p_node->Input(0).toTensor();
756     const auto in1_d = p_node->Input(1).toOptional<double>();
757     const auto in2_d = p_node->Input(2).toOptional<double>();
758     const auto in3_d = p_node->Input(3).toOptional<double>();
759     if (p_node->Output(0).isNone()) {
760       p_node->Output(0) = at::native::nan_to_num(in0_t, in1_d, in2_d, in3_d);
761       return;
762     }
763     auto& out_t = p_node->Output(0).toTensor();
764     fastResizeToZero(out_t);
765     at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t);
766   };
767 });
768 
769 namespace {
770 
771 void varStackSerialOut(
772     at::Tensor& result,
773     int64_t dim,
774     const ProcessedNodeInputWrapper& inputs) {
775   auto result_sizes = inputs[0].sizes().vec();
776   result_sizes.insert(result_sizes.begin() + dim, inputs.size());
777   at::native::resize_(result, result_sizes);
778 
779   AT_DISPATCH_FLOATING_TYPES(
780       result.scalar_type(), "varstack_serial_kernel", [&]() {
781         at::native::detail::
782             stack_serial_kernel_impl<scalar_t, ProcessedNodeInputWrapper>(
783                 result, inputs, dim);
784       });
785 }
786 
787 std::vector<at::Tensor> unsqueezeVarStackInputs(
788     const ProcessedNodeInputWrapper& inputs,
789     const int64_t dim) {
790   std::vector<at::Tensor> result;
791   result.reserve(inputs.size());
792   for (const auto i : c10::irange(inputs.size())) {
793     result.push_back(at::native::unsqueeze(inputs[i], dim));
794   }
795   return result;
796 }
797 
798 void varstackNonserialOut(
799     at::Tensor& result,
800     const int64_t dim,
801     const ProcessedNodeInputWrapper& inputs) {
802   std::vector<at::Tensor> inputs_unsqueezed =
803       unsqueezeVarStackInputs(inputs, dim);
804   fastResizeToZero(result);
805   at::cpu::cat_outf(inputs_unsqueezed, dim, result);
806 }
807 
808 void varStackFastOut(
809     at::Tensor& out,
810     int64_t dim,
811     const ProcessedNodeInputWrapper& inputs) {
812   DCHECK(out.is_contiguous());
813   const auto num_inputs = static_cast<int64_t>(inputs.size());
814   TORCH_CHECK(num_inputs > 0, "stack expects a non-empty list of tensors");
815 
816   const auto first_tensor_shape = inputs[0].sizes();
817   for (const auto i : c10::irange(1, num_inputs)) {
818     const auto shape = inputs[i].sizes();
819     TORCH_CHECK(
820         shape == first_tensor_shape,
821         "Stack expects each tensor to be the same size, but got ",
822         first_tensor_shape,
823         " at position 0 and ",
824         shape,
825         " at position ",
826         i);
827   }
828 
829   const std::array<int64_t, 2> output_size = (dim == 0 || dim == -2)
830       ? std::array<int64_t, 2>{num_inputs, 1}
831       : std::array<int64_t, 2>{1, num_inputs};
832 
833   at::native::resize_(out, output_size, std::nullopt);
834 
835   AT_DISPATCH_ALL_TYPES(out.scalar_type(), "varStackFastOut", [&]() {
836     auto* out_data = out.mutable_data_ptr<scalar_t>();
837     for (const auto i : c10::irange(num_inputs)) {
838       auto& tensor = inputs[i];
839       auto* input_ptr = tensor.const_data_ptr<scalar_t>();
840       out_data[i] = *input_ptr;
841     }
842   });
843 }
844 
845 bool inputsAreScalars(const ProcessedNodeInputWrapper& inputs) {
846   // All stack inputs should have the same size, so we only check
847   // the first one. If this isn't true, an exception will be thrown
848   // in the VarStack implementation
849   const auto& first_tensor = inputs[0];
850   return first_tensor.sizes()[0] == 1 && first_tensor.dim() == 1;
851 }
852 
853 void varStackOut(ProcessedNode& pnode, int64_t dim) {
854   const auto num_inputs = pnode.num_inputs();
855   TORCH_CHECK(num_inputs > 1, "stack expects a non-empty list of tensors");
856   dim = c10::maybe_wrap_dim(dim, pnode.Input(0).toTensor().dim() + 1);
857 
858   auto inputs = ProcessedNodeInputWrapper(pnode);
859   auto& output = pnode.Output(0).toTensor();
860 
861   if (output.is_contiguous() && inputsAreScalars(inputs)) {
862     varStackFastOut(output, dim, inputs);
863     return;
864   }
865 
866   bool can_use_serial = at::native::detail::CanUseNativeSerialStack<
867       ProcessedNodeInputWrapper,
868       /*skip_overlap_check*/ true>::call(output, inputs, dim);
869 
870   if (can_use_serial) {
871     varStackSerialOut(output, dim, inputs);
872     return;
873   }
874   varstackNonserialOut(output, dim, inputs);
875 }
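// Dispatch summary (illustrative only): varStackOut picks between three
// strategies -- a scalar fast path that writes one element per input into a
// contiguous [N, 1] or [1, N] output, the serial stack kernel when
// CanUseNativeSerialStack allows it, and a generic unsqueeze-then-cat fallback
// via at::cpu::cat_outf otherwise.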
876 
877 } // namespace
878 
879 // Split out into a function to appease MSVC's pre-processor
880 static SROperator aten_stack(Node* n) {
881   if (!n->matches(torch::schema(
882           "aten::stack(Tensor[] tensors, int dim=0) -> Tensor"))) {
883     LogAndDumpSchema(n);
884     return nullptr;
885   }
886   return [](ProcessedNode* p_node) {
887     const auto inputs = p_node->Input(0).toTensorVector();
888     TORCH_CHECK(!inputs.empty(), "stack expects non-empty tensor list");
889     const auto dim = p_node->Input(1).toInt();
890     if (p_node->Output(0).isNone()) {
891       p_node->Output(0) = at::native::_stack_cpu(inputs, dim);
892       return;
893     }
894     auto& out_t = p_node->Output(0).toTensor();
895     fastResizeToZero(out_t);
896     at::native::_stack_out_cpu(inputs, dim, out_t);
897   };
898 }
899 
900 REGISTER_OPERATOR_FUNCTOR(aten::stack, aten_stack, aten_stack);
901 
902 REGISTER_OPERATOR_FUNCTOR(
903     prim::VarStack,
904     prim_VarStack,
905     [](Node* n) -> SROperator {
906       if (!sr_schema_check_kind(n, prim::VarStack)) {
907         return nullptr;
908       }
909       return [](ProcessedNode* p_node) {
910         const size_t num_inputs = p_node->num_inputs();
911         const auto dim = p_node->Input(num_inputs - 1).toInt();
912 
913         if (p_node->Output(0).isNone()) {
914           p_node->Output(0) = create_empty_from(p_node->Input(0).toTensor());
915         }
916         varStackOut(*p_node, dim);
917       };
918     });
919 
920 REGISTER_OPERATOR_FUNCTOR(aten::leaky_relu, aten_leaky_relu, [](Node* n) -> SROperator {
921   if (!n->matches(torch::schema(
922           "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"))) {
923     LogAndDumpSchema(n);
924     return nullptr;
925   }
926   return [](ProcessedNode* p_node) {
927     const auto& in0_t = p_node->Input(0).toTensor();
928     const auto in1_s = p_node->Input(1).toScalar();
929     if (p_node->Output(0).isNone()) {
930       p_node->Output(0) = at::cpu::leaky_relu(in0_t, in1_s);
931       return;
932     }
933     auto& out_t = p_node->Output(0).toTensor();
934     at::cpu::leaky_relu_out(out_t, in0_t, in1_s);
935   };
936 });
937 
938 REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
939   if (!n->matches(torch::schema("aten::relu(Tensor self) -> Tensor"))) {
940     LogAndDumpSchema(n);
941     return nullptr;
942   }
943   auto te = createRelu();
944   return [te](ProcessedNode* p_node) {
945     const auto& in0_t = p_node->Input(0).toTensor();
946     if (p_node->Output(0).isNone()) {
947       p_node->Output(0) = create_empty_from(in0_t);
948     }
949     auto& out_t = p_node->Output(0).toTensor();
950     if (!te->checkInput<float>(in0_t)) {
951       fastResizeToZero(out_t);
952       at::cpu::threshold_out(out_t, in0_t, 0, 0);
953       return;
954     }
955     at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
956     int64_t nn = in0_t.numel();
957     te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
958   };
959 });
960 
961 REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator {
962   if (!n->matches(torch::schema("aten::tanh(Tensor self) -> Tensor"))) {
963     LogAndDumpSchema(n);
964     return nullptr;
965   }
966   auto te = createTanh();
967   return [te](ProcessedNode* p_node) {
968     const auto& in0_t = p_node->Input(0).toTensor();
969     if (p_node->Output(0).isNone()) {
970       p_node->Output(0) = create_empty_from(in0_t);
971     }
972     auto& out_t = p_node->Output(0).toTensor();
973     if (!te->checkInput<float>(in0_t)) {
974       fastResizeToZero(out_t);
975       at::cpu::tanh_out(out_t, in0_t);
976       return;
977     }
978     at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
979     int64_t nn = in0_t.numel();
980     te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
981   };
982 });
983 
984 REGISTER_OPERATOR_FUNCTOR(
985     prim::TensorExprDynamicGroup,
986     prim_TensorExprDynamicGroup,
987     [](Node* n) -> SROperator {
988       if (!sr_schema_check_kind(n, prim::TensorExprDynamicGroup)) {
989         return nullptr;
990       }
991       auto graph = n->g(attr::Subgraph);
992       Code code(graph, "");
993       return [code](ProcessedNode* p_node) {
994         auto num_outputs = p_node->num_outputs();
995         Stack stack;
996         if (p_node->Output(0).isNone()) {
997           stack.reserve(p_node->num_inputs());
998         } else {
999           stack.reserve(p_node->num_inputs() + num_outputs);
1000           for (const auto& o : p_node->outputs()) {
1001             stack.emplace_back(o);
1002           }
1003         }
1004         for (auto i : c10::irange(p_node->num_inputs())) {
1005           stack.emplace_back(p_node->Input(i));
1006         }
1007         runTensorExprDynamicGroup(code, stack);
1008         if (p_node->Output(0).isNone()) {
1009           TORCH_INTERNAL_ASSERT(
1010               stack.size() == num_outputs,
1011               "Unexpected # of outputs on stack after executing TensorExprDynamicGroup");
1012           for (auto i : c10::irange(num_outputs)) {
1013             p_node->Output(i) = std::move(stack[i]);
1014           }
1015         }
1016       };
1017     });
1018 
1019 REGISTER_OPERATOR_FUNCTOR(
1020     aten::sigmoid,
1021     aten_sigmoid,
1022     [](Node* n) -> SROperator {
1023       if (!n->matches(torch::schema("aten::sigmoid(Tensor self) -> Tensor"))) {
1024         LogAndDumpSchema(n);
1025         return nullptr;
1026       }
1027       auto te = createSigmoid();
1028       return [te](ProcessedNode* p_node) {
1029         const auto& in0_t = p_node->Input(0).toTensor();
1030         if (p_node->Output(0).isNone()) {
1031           p_node->Output(0) = create_empty_from(in0_t);
1032         }
1033         auto& out_t = p_node->Output(0).toTensor();
1034         if (!te->checkInput<float>(in0_t)) {
1035           fastResizeToZero(out_t);
1036           at::cpu::sigmoid_out(out_t, in0_t);
1037           return;
1038         }
1039         at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
1040         int64_t nn = in0_t.numel();
1041         te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
1042       };
1043     });
1044 
1045 REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
1046   if (!n->matches(torch::schema(
1047           "aten::logit(Tensor self, float? eps=None) -> Tensor"))) {
1048     LogAndDumpSchema(n);
1049     return nullptr;
1050   }
1051   std::optional<float> clamp = std::nullopt;
1052   if (n->inputs()[1]->node()->kind() == prim::Constant) {
1053     auto clamp_d = toIValue(n->inputs()[1])->toOptional<double>();
1054     clamp = clamp_d
1055         ? std::make_optional<float>(static_cast<float>(clamp_d.value()))
1056         : std::nullopt;
1057   }
1058   auto te = clamp ? createLogit() : nullptr;
1059   float clamp_value = clamp ? *clamp : 0.0f;
1060   return [te, clamp_value](ProcessedNode* p_node) {
1061     const auto& in0_t = p_node->Input(0).toTensor();
1062     if (p_node->Output(0).isNone()) {
1063       p_node->Output(0) = create_empty_from(in0_t);
1064     }
1065     auto& out_t = p_node->Output(0).toTensor();
1066     if (!te || !te->checkInput<float>(in0_t)) {
1067       const auto& in0_t = p_node->Input(0).toTensor();
1068       const auto in1_d = p_node->Input(1).toOptional<double>();
1069       fastResizeToZero(out_t);
1070       at::native::logit_out(in0_t, in1_d, out_t);
1071       return;
1072     }
1073     at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
1074     int64_t nn = in0_t.numel();
1075     float c = clamp_value;
1076     te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn, &c});
1077   };
1078 });
1079 
1080 REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
1081   if (!n->matches(torch::schema(
1082           "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) ->Tensor"))) {
1083     LogAndDumpSchema(n);
1084     return nullptr;
1085   }
1086   return [](ProcessedNode* p_node) {
1087     const auto& src = p_node->Input(0).toTensor();
1088     const auto& optional_memory_format =
1089         p_node->Input(1).toOptional<c10::MemoryFormat>();
1090     auto memory_format =
1091         optional_memory_format.value_or(c10::MemoryFormat::Preserve);
1092     /*
1093       disable the out variant of clone for the case with stride = 0 and
1094       memory formats other than Preserve. Perform dynamic allocation
1095       instead of memory reuse for a simpler implementation. We could,
1096       in principle, figure out how to copy the strides.
1097     */
1098     if ((at::has_internal_overlap(src.unsafeGetTensorImpl()) ==
1099          at::MemOverlap::Yes) ||
1100         (memory_format != c10::MemoryFormat::Preserve)) {
1101       p_node->Output(0) = at::native::clone(src, memory_format);
1102       return;
1103     }
1104     if (p_node->Output(0).isNone()) {
1105       if (src.is_non_overlapping_and_dense()) {
1106         // Copy all strides
1107         p_node->Output(0) =
1108             at::empty_strided(src.sizes(), src.strides(), src.options());
1109       } else {
1110         memory_format = src.suggest_memory_format();
1111         p_node->Output(0) = create_empty_from(src, memory_format);
1112       }
1113     }
1114     auto& out_t = p_node->Output(0).toTensor();
1115     at::native::resize_impl_cpu_(
1116         out_t.unsafeGetTensorImpl(), src.sizes(), src.strides());
1117     at::native::copy_(out_t, src, false);
1118   };
1119 });
1120 
1121 REGISTER_OPERATOR_FUNCTOR(
1122     quantized::embedding_bag_byte_rowwise_offsets,
1123     quantized_embedding_bag_byte_rowwise_offsets,
1124     [](Node* n) -> SROperator {
1125       if (!n->matches(torch::schema(
1126               "quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"))) {
1127         LogAndDumpSchema(n);
1128         return nullptr;
1129       }
1130       return [](ProcessedNode* p_node) {
1131         const auto& weight = p_node->Input(0).toTensor();
1132         const auto& indices = p_node->Input(1).toTensor();
1133         const auto offsets = p_node->Input(2).toOptional<at::Tensor>();
1134         const auto pruned_weights = p_node->Input(5).toBool();
1135         const auto per_sample_weights =
1136             p_node->Input(6).toOptional<at::Tensor>();
1137         const auto compressed_indices_mapping =
1138             p_node->Input(7).toOptional<at::Tensor>();
1139         const auto include_last_offset = p_node->Input(8).toBool();
1140         if (p_node->Output(0).isNone()) {
1141           p_node->Output(0) = create_empty_from(weight, at::kFloat);
1142         }
1143         auto& out_t = p_node->Output(0).toTensor();
1144         fastResizeToZero(out_t);
1145         return at::native::embedding_bag_byte_rowwise_offsets_out(
1146             out_t,
1147             weight,
1148             indices,
1149             offsets,
1150             false, // unused scale_grad_by_freq
1151             0, // unused mode
1152             pruned_weights,
1153             per_sample_weights,
1154             compressed_indices_mapping,
1155             include_last_offset);
1156       };
1157     });
1158 
1159 REGISTER_OPERATOR_FUNCTOR(
1160     quantized::embedding_bag_4bit_rowwise_offsets,
1161     embedding_bag_4bit_rowwise_offsets,
1162     [](Node* n) -> SROperator {
1163       if (!n->matches(torch::schema(
1164               "quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"))) {
1165         LogAndDumpSchema(n);
1166         return nullptr;
1167       }
1168       return [](ProcessedNode* p_node) {
1169         const auto& weight = p_node->Input(0).toTensor();
1170         const auto& indices = p_node->Input(1).toTensor();
1171         const auto offsets = p_node->Input(2).toOptional<at::Tensor>();
1172         const auto pruned_weights = p_node->Input(5).toBool();
1173         const auto per_sample_weights =
1174             p_node->Input(6).toOptional<at::Tensor>();
1175         const auto compressed_indices_mapping =
1176             p_node->Input(7).toOptional<at::Tensor>();
1177         const auto include_last_offset = p_node->Input(8).toBool();
1178         if (p_node->Output(0).isNone()) {
1179           p_node->Output(0) = create_empty_from(weight, at::kFloat);
1180         }
1181         auto& out_t = p_node->Output(0).toTensor();
1182         fastResizeToZero(out_t);
1183         return at::native::embedding_bag_4bit_rowwise_offsets_out(
1184             out_t,
1185             weight,
1186             indices,
1187             offsets,
1188             false, // unused scale_grad_by_freq
1189             0, // unused mode
1190             pruned_weights,
1191             per_sample_weights,
1192             compressed_indices_mapping,
1193             include_last_offset);
1194       };
1195     });
1196 
1197 REGISTER_OPERATOR_FUNCTOR(
1198     quantized::embedding_bag_byte_prepack,
1199     embedding_bag_byte_prepack,
1200     [](Node* n) -> SROperator {
1201       if (!n->matches(torch::schema(
1202               "quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"))) {
1203         LogAndDumpSchema(n);
1204         return nullptr;
1205       }
1206       return [](ProcessedNode* p_node) {
1207         const auto& weight = p_node->Input(0).toTensor();
1208         if (p_node->Output(0).isNone()) {
1209           p_node->Output(0) = at::native::qembeddingbag_byte_prepack(weight);
1210           return;
1211         }
1212         auto& out_t = p_node->Output(0).toTensor();
1213         fastResizeToZero(out_t);
1214         at::native::qembeddingbag_byte_prepack_out(out_t, weight);
1215       };
1216     });
1217 
1218 // The out variant takes precedence over native
1219 REGISTER_OPERATOR_FUNCTOR(aten::narrow_copy, aten_narrow_copy, [](Node* n) -> SROperator {
1220   if (!n->matches(torch::schema(
1221           "aten::narrow_copy(Tensor self, int dim, int start, int length) -> Tensor"))) {
1222     LogAndDumpSchema(n);
1223     return nullptr;
1224   }
1225   return [](ProcessedNode* p_node) {
1226     const auto& self = p_node->Input(0).toTensor(); // self
1227     const auto dim = p_node->Input(1).toInt(); // dim
1228     int64_t start = 0;
1229     if (p_node->Input(2).isScalar()) {
1230       start = p_node->Input(2).toInt();
1231     } else {
1232       auto& t = p_node->Input(2).toTensor();
1233       start = t.item<int64_t>();
1234     }
1235     auto length = p_node->Input(3).toInt(); // length
1236 
1237     if (p_node->Output(0).isNone()) {
1238       p_node->Output(0) =
1239           at::native::narrow_copy_dense_cpu(self, dim, start, length);
1240       return;
1241     }
1242     auto& output = p_node->Output(0).toTensor();
1243     fastResizeToZero(output);
1244     at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output);
1245   };
1246 });
1247 REGISTER_OPERATOR_FUNCTOR(aten::index, aten_index, [](Node* n) -> SROperator {
1248   if (!n->matches(torch::schema(
1249           "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"))) {
1250     LogAndDumpSchema(n);
1251     return nullptr;
1252   }
1253   return [](ProcessedNode* p_node) {
1254     const auto& in0_t = p_node->Input(0).toTensor();
1255     const auto in1_l =
1256         at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
1257     if (p_node->Output(0).isNone()) {
1258       p_node->Output(0) = at::cpu::index(in0_t, in1_l);
1259       return;
1260     }
1261     auto& out_t = p_node->Output(0).toTensor();
1262     fastResizeToZero(out_t);
1263     at::cpu::index_out(out_t, in0_t, in1_l);
1264   };
1265 });
1266 
1267 REGISTER_OPERATOR_FUNCTOR(
1268     aten::index_select,
1269     aten_index_select,
1270     [](Node* n) -> SROperator {
1271       if (!n->matches(torch::schema(
1272               "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor"))) {
1273         LogAndDumpSchema(n);
1274         return nullptr;
1275       }
1276       return [](ProcessedNode* p_node) {
1277         const auto& self = p_node->Input(0).toTensor();
1278         const auto dim = p_node->Input(1).toInt();
1279         const auto& index = p_node->Input(2).toTensor();
1280         if (p_node->Output(0).isNone()) {
1281           p_node->Output(0) = at::native::index_select_cpu_(self, dim, index);
1282           return;
1283         }
1284         auto& out = p_node->Output(0).toTensor();
1285         fastResizeToZero(out);
1286         at::native::index_select_out_cpu_(self, dim, index, out);
1287       };
1288     });
1289 
1290 REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator {
1291   if (n->matches(torch::schema(
1292           "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor"))) {
1293     return [](ProcessedNode* p_node) {
1294       if (p_node->Output(0).isNone()) {
1295         const auto& in0_t = p_node->Input(0).toTensor();
1296         auto dtype =
1297             at::native::result_type(in0_t, p_node->Input(1).toTensor());
1298         p_node->Output(0) = create_empty_from(in0_t, dtype);
1299       }
1300       auto& out_t = p_node->Output(0).toTensor();
1301       fastResizeToZero(out_t);
1302       at::cpu::pow_out(
1303           out_t, p_node->Input(0).toTensor(), p_node->Input(1).toTensor());
1304     };
1305   }
1306   if (n->matches(torch::schema(
1307           "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor"))) {
1308     return [](ProcessedNode* p_node) {
1309       if (p_node->Output(0).isNone()) {
1310         const auto& in1_t = p_node->Input(1).toTensor();
1311         auto dtype =
1312             at::native::result_type(p_node->Input(0).toScalar(), in1_t);
1313         p_node->Output(0) = at::native::empty_like(
1314             in1_t,
1315             dtype,
1316             in1_t.options().layout_opt(),
1317             in1_t.options().device_opt(),
1318             in1_t.options().pinned_memory_opt(),
1319             at::MemoryFormat::Preserve);
1320       }
1321       auto& out_t = p_node->Output(0).toTensor();
1322       fastResizeToZero(out_t);
1323       at::cpu::pow_out(
1324           out_t, p_node->Input(0).toScalar(), p_node->Input(1).toTensor());
1325     };
1326   }
1327   if (n->matches(torch::schema(
1328           "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor"))) {
1329     return [](ProcessedNode* p_node) {
1330       if (p_node->Output(0).isNone()) {
1331         const auto& in0_t = p_node->Input(0).toTensor();
1332         auto dtype =
1333             at::native::result_type(in0_t, p_node->Input(1).toScalar());
1334         p_node->Output(0) = at::native::empty_like(
1335             in0_t,
1336             dtype,
1337             in0_t.options().layout_opt(),
1338             in0_t.options().device_opt(),
1339             in0_t.options().pinned_memory_opt(),
1340             at::MemoryFormat::Preserve);
1341       }
1342       auto& out_t = p_node->Output(0).toTensor();
1343       fastResizeToZero(out_t);
1344       at::cpu::pow_out(
1345           out_t, p_node->Input(0).toTensor(), p_node->Input(1).toScalar());
1346     };
1347   }
1348   LogAndDumpSchema(n);
1349   return nullptr;
1350 });
1351 
1352 namespace {
1353 
1354 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1355 struct ToArgs {
1356   std::optional<at::ScalarType> dtype;
1357   c10::Layout layout;
1358   bool know_to_will_alias = false;
1359   std::optional<c10::MemoryFormat> memory_format;
1360 };
1361 
1362 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1363 ToArgs extract_to_args(ProcessedNode* p_node) {
1364   ToArgs result;
1365   if (!has_constant_non_tensor_dtype_and_flags && p_node->Input(1).isTensor()) {
1366     const auto& other = p_node->Input(1).toTensor();
1367     result.dtype = other.scalar_type();
1368     result.layout = other.layout();
1369     TORCH_DCHECK_EQ(other.device().type(), c10::DeviceType::CPU);
1370   } else {
1371     const auto& self = p_node->Input(0).toTensor();
1372     result.dtype = p_node->Input(1).toOptional<at::ScalarType>();
1373     result.layout = self.layout();
1374     // Static runtime only works with CPU tensors; don't need to read this.
1375     TORCH_DCHECK_EQ(self.device().type(), c10::DeviceType::CPU);
1376     result.know_to_will_alias = has_constant_non_tensor_dtype_and_flags &&
1377         (!result.dtype.has_value() ||
1378          result.dtype.value() == self.dtype().toScalarType());
1379   }
1380   if (has_memory_format) {
1381     TORCH_DCHECK_EQ(p_node->num_inputs(), 5);
1382     result.memory_format = p_node->Input(4).toOptional<c10::MemoryFormat>();
1383     result.know_to_will_alias = result.know_to_will_alias &&
1384         (result.memory_format.value_or(c10::MemoryFormat::Preserve) ==
1385          c10::MemoryFormat::Preserve);
1386   }
1387 
1388   return result;
1389 }
1390 
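// Decides whether aten::to would alias `self` (i.e. no copy is needed), either
// from the precomputed `know_to_will_alias` flag or by consulting
// at::native::to_will_alias.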
1391 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1392 struct CheckToWillAlias {
1393   static bool call(
1394       ProcessedNode* p_node,
1395       const at::Tensor& self,
1396       const ToArgs& to_args) {
1397     // The to_maybe_copy_out operator functor should have detected a
1398     // constant true `copy` argument and used to_copy instead.
1399     bool copy = false;
1400     if (has_constant_non_tensor_dtype_and_flags) {
1401       DCHECK(!p_node->Input(3).toBool());
1402       copy = false;
1403     } else {
1404       copy = p_node->Input(3).toBool();
1405     }
1406     return !copy &&
1407         (to_args.know_to_will_alias ||
1408          at::native::to_will_alias(
1409              self,
1410              to_args.dtype,
1411              to_args.layout,
1412              c10::Device{c10::DeviceType::CPU},
1413              copy,
1414              has_memory_format ? to_args.memory_format
1415                                : c10::MemoryFormat::Preserve));
1416   }
1417 };
1418 
1419 template <>
1420 struct CheckToWillAlias<true, false> {
1421   // Special case! First, there is no memory format to check. Second,
1422   // we know that the layout and device will match self, so we only
1423   // need to check the dtype.
1424   static bool call(ProcessedNode* p_node, const at::Tensor& self) {
1425     DCHECK(!p_node->Input(3).toBool()); // !copy
1426     const auto dtype_opt = p_node->Input(1).toOptional<at::ScalarType>();
1427     return !dtype_opt.has_value() || *dtype_opt == self.dtype().toScalarType();
1428   }
1429 };
1430 
1431 // Force inlining so we don't have to branch on whether args is null
1432 // at runtime.
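//
// Overall flow: resolve the target options (extracting them here when the
// caller did not pass precomputed args), reuse the existing output tensor when
// its dtype/layout/contiguity still match, otherwise allocate a fresh CPU
// tensor, then copy via at::native::to_copy_out.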
1433 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1434 C10_ALWAYS_INLINE void to_copy_functor_impl(
1435     ProcessedNode* p_node,
1436     const ToArgs* args) {
1437   const auto& self = p_node->Input(0).toTensor();
1438   // ignore input 3 (copy)
1439   auto non_blocking = p_node->Input(2).toBool(); // non_blocking
1440   // handle memory format
1441   bool copy_strides = false;
1442 
1443   std::optional<c10::MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
1444   std::optional<ToArgs> my_args;
1445   if (!args) {
1446     my_args = extract_to_args<
1447         has_constant_non_tensor_dtype_and_flags,
1448         has_memory_format>(p_node);
1449     args = &my_args.value();
1450   }
1451   if (has_memory_format) {
1452     memory_format = args->memory_format.value_or(c10::MemoryFormat::Preserve);
1453   }
1454 
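  // MemoryFormat::Preserve: if self is non-overlapping and dense we can copy
  // its strides verbatim; otherwise fall back to the suggested memory format.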
1455   if (memory_format == c10::MemoryFormat::Preserve) {
1456     if (self.is_non_overlapping_and_dense()) {
1457       memory_format = std::nullopt;
1458       copy_strides = true;
1459     } else {
1460       memory_format = self.suggest_memory_format();
1461     }
1462   }
1463 
1464   bool need_to_allocate_output = true;
1465   if (p_node->Output(0).isTensor()) {
1466     const auto& existing_output = p_node->Output(0).toTensor();
1467     if ((!has_constant_non_tensor_dtype_and_flags &&
1468          (existing_output.dtype() != args->dtype ||
1469           existing_output.layout() != args->layout ||
1470           existing_output.device() != self.device())) ||
1471         (has_memory_format &&
1472          !existing_output.is_contiguous(
1473              memory_format.value_or(c10::MemoryFormat::Contiguous)))) {
1474       need_to_allocate_output = true;
1475     } else {
1476       need_to_allocate_output = false;
1477     }
1478   }
1479 
1480   // See Note [Explicit nullopt MemoryFormat argument]
1481   // Can't use size {0} if memory_format is ChannelLast
1482   if (need_to_allocate_output) {
1483     p_node->Output(0) = at::detail::empty_cpu(
1484         self.sizes(),
1485         args->dtype,
1486         args->layout,
1487         self.device(),
1488         std::nullopt,
1489         memory_format);
1490   } else {
1491     if (has_memory_format) {
1492       memory_format = p_node->Input(4).toOptional<c10::MemoryFormat>().value_or(
1493           c10::MemoryFormat::Preserve);
1494     } else {
1495       memory_format = c10::MemoryFormat::Preserve;
1496     }
1497   }
1498 
1499   copy_strides = copy_strides ||
1500       (memory_format == c10::MemoryFormat::Preserve &&
1501        self.is_non_overlapping_and_dense());
1502 
1503   auto& out_t = p_node->Output(0).toTensor();
1504   fastResizeToZero(out_t);
1505   at::native::to_copy_out(
1506       out_t, self, non_blocking, copy_strides, memory_format);
1507 }
1508 
1509 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1510 void to_copy_functor(ProcessedNode* p_node) {
1511   to_copy_functor_impl<
1512       has_constant_non_tensor_dtype_and_flags,
1513       has_memory_format>(p_node, nullptr);
1514 }
1515 
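// to_maybe_copy_out produces two outputs: Output(0) is the (possibly
// unwritten) tensor and Output(1) is a bool telling downstream code whether a
// copy was actually performed. Illustrative node shape (see the schemas
// registered below):
//   %y, %did_copy = static_runtime::to_maybe_copy_out(%self, %dtype, %non_blocking, %copy)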
1516 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1517 void to_maybe_copy_out_functor(ProcessedNode* p_node) {
1518   // It would be great if we could avoid checking our arguments every
1519   // time. However, we need to account for the possibility that
1520   // the dtype (and layout, memory format, etc.) of self changed
1521   // between iterations.
1522   ToArgs args = extract_to_args<
1523       has_constant_non_tensor_dtype_and_flags,
1524       has_memory_format>(p_node);
1525   const auto& self = p_node->Input(0).toTensor();
1526   if (CheckToWillAlias<
1527           has_constant_non_tensor_dtype_and_flags,
1528           has_memory_format>::call(p_node, self, args)) {
1529     // Don't write our Tensor output. This leaves it None if it
1530     // was never allocated (and there is a special case in the
1531     // memory planner to not start managing in this case), but
1532     // if we are oscillating between aliasing and needing to
1533     // copy, we should just leave our output in place so as not
1534     // to confuse the memory planner.
1535     p_node->Output(1) = false;
1536   } else {
1537     p_node->Output(1) = true;
1538     to_copy_functor_impl<
1539         has_constant_non_tensor_dtype_and_flags,
1540         has_memory_format>(p_node, &args);
1541   }
1542 }
1543 
1544 // Take advantage of CheckToWillAlias not requiring the args in this
1545 // case.
1546 template <>
1547 void to_maybe_copy_out_functor<true, false>(ProcessedNode* p_node) {
1548   const auto& self = p_node->Input(0).toTensor();
1549   if (CheckToWillAlias<true, false>::call(p_node, self)) {
1550     p_node->Output(1) = false;
1551   } else {
1552     p_node->Output(1) = true;
1553     auto args = extract_to_args<true, false>(p_node);
1554     to_copy_functor_impl<true, false>(p_node, &args);
1555   }
1556 }
1557 
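// True when input 1 is a non-tensor constant (the dtype) and the
// non_blocking/copy flags are also constants, which lets us pick the
// specialized <true, ...> template instantiations above.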
1558 bool node_has_constant_non_tensor_dtype_and_flags(Node* n) {
1559   const auto* input1 = n->inputs()[1];
1560   return input1->type()->kind() != TypeKind::TensorType &&
1561       input1->node()->kind() == prim::Constant &&
1562       n->inputs()[2]->node()->kind() == prim::Constant &&
1563       n->inputs()[3]->node()->kind() == prim::Constant;
1564 }
1565 
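// Maps the two runtime booleans onto the matching compile-time instantiation
// of to_copy_functor.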
1566 auto get_to_copy_functor(
1567     bool has_constant_non_tensor_dtype_and_flags,
1568     bool has_memory_format) {
1569   if (has_constant_non_tensor_dtype_and_flags) {
1570     if (has_memory_format) {
1571       return to_copy_functor<true, true>;
1572     } else {
1573       return to_copy_functor<true, false>;
1574     }
1575   } else {
1576     if (has_memory_format) {
1577       return to_copy_functor<false, true>;
1578     } else {
1579       return to_copy_functor<false, false>;
1580     }
1581   }
1582 }
1583 
1584 } // namespace
1585 
1586 REGISTER_OPERATOR_FUNCTOR(
1587     static_runtime::to_maybe_copy_out,
1588     aten_to_maybe_copy,
1589     [](Node* n) -> SROperator {
1590       // support 4- or 5-arg for adindexer/adfinder models
1591       // Keep TORCH_CHECK here because there is no alternative for fallback
1592       if (!sr_schema_check(
1593               n,
1594               "static_runtime::to_maybe_copy_out.prim_dtype(Tensor self, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor, bool)",
1595               "static_runtime::to_maybe_copy_out.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> (Tensor, bool)",
1596               "static_runtime::to_maybe_copy_out.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> (Tensor, bool)")) {
1597         return nullptr;
1598       }
1599       TORCH_CHECK(n->inputs().size() == 4 || n->inputs().size() == 5);
1600       const bool has_constant_non_tensor_dtype_and_flags =
1601           node_has_constant_non_tensor_dtype_and_flags(n);
1602       const bool has_memory_format = n->inputs().size() == 5;
1603 
1604       // If we are always going to copy, just use the to_copy path so
1605       // that the to_maybe_copy path can assume that won't happen.
1606       if (has_constant_non_tensor_dtype_and_flags) {
1607         const auto copyArg =
1608             torch::jit::constant_as<bool>(n->inputs()[3]->node()->output());
1609         DCHECK(copyArg.has_value());
1610         if (*copyArg) {
1611           return get_to_copy_functor(
1612               has_constant_non_tensor_dtype_and_flags, has_memory_format);
1613         }
1614       }
1615       if (has_constant_non_tensor_dtype_and_flags) {
1616         if (has_memory_format) {
1617           return to_maybe_copy_out_functor<true, true>;
1618         } else {
1619           return to_maybe_copy_out_functor<true, false>;
1620         }
1621       } else {
1622         if (has_memory_format) {
1623           return to_maybe_copy_out_functor<false, true>;
1624         } else {
1625           return to_maybe_copy_out_functor<false, false>;
1626         }
1627       }
1628     });
1629 
1630 // out variant takes precedence over native
1631 // NB: This impl doesn't work for cpu->cuda copy/cast or vice versa.
1632 REGISTER_OPERATOR_FUNCTOR(
1633     static_runtime::to_copy,
1634     static_runtime_to_copy,
1635     [](Node* n) -> SROperator {
1636       // support 4- or 5-arg for adindexer/adfinder models
1637       // Keep TORCH_CHECK here because there is no alternative for fallback
1638       if (!sr_schema_check(
1639               n,
1640               "static_runtime::to_copy.prim_dtype(Tensor self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor",
1641               "static_runtime::to_copy.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor",
1642               "static_runtime::to_copy.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor")) {
1643         return nullptr;
1644       }
1645       TORCH_CHECK(n->inputs().size() == 4 || n->inputs().size() == 5);
1646       const bool has_constant_non_tensor_dtype_and_flags =
1647           node_has_constant_non_tensor_dtype_and_flags(n);
1648       const bool has_memory_format = n->inputs().size() == 5;
1649       return get_to_copy_functor(
1650           has_constant_non_tensor_dtype_and_flags, has_memory_format);
1651     });
1652 
1653 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
1654 REGISTER_OPERATOR_FUNCTOR(
1655     static_runtime::dequantize_copy,
1656     aten_dequantize_copy,
1657     [](Node* n) -> SROperator {
1658       if (!n->matches(torch::schema(
1659               "static_runtime::dequantize_copy.self(Tensor self) -> Tensor"))) {
1660         // please implement static runtime support for aten::dequantize with
1661         // TensorList
1662         LogAndDumpSchema(n);
1663         return nullptr;
1664       }
1665       return [](ProcessedNode* p_node) {
1666         const auto& self = p_node->Input(0).toTensor();
1667         if (p_node->Output(0).isNone()) {
1668           p_node->Output(0) =
1669               create_empty_from(self, at::kFloat, self.suggest_memory_format());
1670         }
1671 
1672         auto& out_t = p_node->Output(0).toTensor();
1673         fastResizeToZero(out_t);
1674         at::native::dequantize_copy_out(out_t, self);
1675       };
1676     });
1677 
1678 // Out variants for view ops are registered to a separate registry because
1679 // their outputs (views) can't participate in memory reuse.
1680 REGISTER_OPERATOR_FUNCTOR(
1681     static_runtime::reshape_copy,
1682     aten_reshape,
1683     [](Node* n) -> SROperator {
1684       if (!sr_schema_check(
1685               n,
1686               "static_runtime::reshape_copy(Tensor self, int[] shape) -> Tensor")) {
1687         return nullptr;
1688       }
1689       TORCH_CHECK(n->inputs().size() == 2);
1690       return [](ProcessedNode* p_node) {
1691         const auto& self = p_node->Input(0).toTensor(); // self
1692         const auto proposed_shape = p_node->Input(1).toDimVector(); // shape
1693 
1694         if (p_node->Output(0).isNone()) {
1695           p_node->Output(0) = create_empty_from(self);
1696         }
1697         auto& out = p_node->Output(0).toTensor();
1698         at::native::reshape_copy_out(out, self, proposed_shape, true);
1699       };
1700     });
1701 
1702 REGISTER_OPERATOR_FUNCTOR(
1703     static_runtime::flatten_copy,
1704     aten_flatten,
1705     [](Node* n) -> SROperator {
1706       if (!sr_schema_check(
1707               n,
1708               "static_runtime::flatten_copy.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor")) {
1709         return nullptr;
1710       }
1711       TORCH_CHECK(n->inputs().size() == 3);
1712       return [](ProcessedNode* p_node) {
1713         const auto& self = p_node->Input(0).toTensor();
1714         const auto start_dim = p_node->Input(1).toInt();
1715         const auto end_dim = p_node->Input(2).toInt();
1716 
1717         if (p_node->Output(0).isNone()) {
1718           p_node->Output(0) = create_empty_from(self);
1719         }
1720         auto& out = p_node->Output(0).toTensor();
1721         at::native::flatten_copy_out(out, self, start_dim, end_dim);
1722       };
1723     });
1724 
1725 REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
1726   if (n->inputs().size() != 2 && n->inputs().size() != 4) {
1727     return nullptr;
1728   }
1729   if (n->matches(torch::schema(
1730           "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"))) {
1731     return [](ProcessedNode* p_node) {
1732       const at::Tensor& self = p_node->Input(0).toTensor();
1733       auto dtype = p_node->Input(1).toOptional<at::ScalarType>();
1734       std::vector<int64_t> dim = {};
1735       bool keepdim = false;
1736       if (p_node->Output(0).isNone()) {
1737         p_node->Output(0) = at::cpu::sum(self, dim, keepdim, dtype);
1738       } else {
1739         auto& output = p_node->Output(0).toTensor();
1740         fastResizeToZero(output);
1741         at::cpu::sum_out(output, self, dim, keepdim, dtype);
1742       }
1743     };
1744   }
1745   if (n->matches(torch::schema(
1746           "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
1747     return [](ProcessedNode* p_node) {
1748       const at::Tensor& self = p_node->Input(0).toTensor();
1749       auto dim = p_node->Input(1).toDimVector();
1750       auto keepdim = p_node->Input(2).toBool();
1751       auto dtype = p_node->Input(3).toOptional<at::ScalarType>();
1752       if (p_node->Output(0).isNone()) {
1753         p_node->Output(0) = at::cpu::sum(self, dim, keepdim, dtype);
1754       } else {
1755         auto& output = p_node->Output(0).toTensor();
1756         fastResizeToZero(output);
1757         at::cpu::sum_out(output, self, dim, keepdim, dtype);
1758       }
1759     };
1760   }
1761   LogAndDumpSchema(n);
1762   return nullptr;
1763 });
1764 
1765 REGISTER_OPERATOR_FUNCTOR(aten::mean, aten_mean, [](Node* n) -> SROperator {
1766   if (n->matches(torch::schema(
1767           "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
1768     return [](ProcessedNode* p_node) {
1769       const auto& self = p_node->Input(0).toTensor();
1770       const auto dim = p_node->Input(1).toDimVector();
1771       const bool keepdim = p_node->Input(2).toBool();
1772       const auto dtype = p_node->Input(3).toOptional<at::ScalarType>();
1773       if (p_node->Output(0).isNone()) {
1774         p_node->Output(0) = create_empty_from(
1775             self, dtype.value_or(self.dtype().toScalarType()));
1776       }
1777       auto& output = p_node->Output(0).toTensor();
1778       fastResizeToZero(output);
1779       at::cpu::mean_out(output, self, dim, keepdim, dtype);
1780     };
1781   }
1782 
1783   if (n->matches(torch::schema(
1784           "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor"))) {
1785     return [](ProcessedNode* p_node) {
1786       const auto& self = p_node->Input(0).toTensor();
1787       const auto dtype = p_node->Input(1).toOptional<at::ScalarType>();
1788       if (p_node->Output(0).isNone()) {
1789         p_node->Output(0) = create_empty_from(
1790             self, dtype.value_or(self.dtype().toScalarType()));
1791       }
1792       auto& output = p_node->Output(0).toTensor();
1793       fastResizeToZero(output);
1794       at::cpu::mean_out(output, self, /*dim=*/{}, /*keepdim=*/false, dtype);
1795     };
1796   }
1797 
1798   LogAndDumpSchema(n);
1799   return nullptr;
1800 });
1801 
1802 REGISTER_OPERATOR_FUNCTOR(aten::repeat, aten_repeat, [](Node* n) -> SROperator {
1803   if (!n->matches(torch::schema(
1804           "aten::repeat(Tensor self, int[] repeats) -> Tensor"))) {
1805     LogAndDumpSchema(n);
1806     return nullptr;
1807   }
1808   return [](ProcessedNode* p_node) {
1809     const auto& self = p_node->Input(0).toTensor();
1810     const auto repeats = p_node->Input(1).toDimVector();
1811 
1812     if (p_node->Output(0).isNone()) {
1813       p_node->Output(0) = at::native::repeat(self, repeats);
1814       return;
1815     }
1816     at::Tensor& output = p_node->Output(0).toTensor();
1817     at::native::repeat_out(output, self, repeats);
1818   };
1819 });
1820 
1821 REGISTER_OPERATOR_FUNCTOR(aten::max, aten_max, [](Node* n) -> SROperator {
1822   if (n->matches(torch::schema(
1823           "aten::max.other(Tensor self, Tensor other) -> Tensor"))) {
1824     return [](ProcessedNode* p_node) {
1825       const auto& self = p_node->Input(0).toTensor();
1826       const auto& other = p_node->Input(1).toTensor();
1827       if (p_node->Output(0).isNone()) {
1828         p_node->Output(0) = at::native::max(self, other);
1829         return;
1830       }
1831       auto& out = p_node->Output(0).toTensor();
1832       fastResizeToZero(out);
1833       at::native::max_out(self, other, out);
1834     };
1835   }
1836 
1837   if (n->matches(torch::schema(
1838           "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"))) {
1839     return [](ProcessedNode* p_node) {
1840       const auto& self = p_node->Input(0).toTensor();
1841       auto dim = p_node->Input(1).toInt();
1842       const auto keepdim = p_node->Input(2).toBool();
1843 
1844       if (p_node->Output(0).isNone()) {
1845         p_node->Output(0) = create_empty_from(self);
1846       }
1847 
1848       if (p_node->Output(1).isNone()) {
1849         p_node->Output(1) = create_empty_from(self, at::kLong);
1850       }
1851 
1852       auto& values = p_node->Output(0).toTensor();
1853       auto& indices = p_node->Output(1).toTensor();
1854       fastResizeToZero(values);
1855       fastResizeToZero(indices);
1856       at::cpu::max_out(values, indices, self, dim, keepdim);
1857     };
1858   }
1859 
1860   if (n->matches(torch::schema("aten::max(Tensor self) -> Tensor"))) {
1861     return [](ProcessedNode* p_node) {
1862       const auto& self = p_node->Input(0).toTensor();
1863       if (p_node->Output(0).isNone()) {
1864         p_node->Output(0) = create_empty_from(self);
1865       }
1866       auto& value = p_node->Output(0).toTensor();
1867       fastResizeToZero(value);
1868       at::cpu::amax_out(value, self);
1869     };
1870   }
1871 
1872   LogAndDumpSchema(n);
1873   return nullptr;
1874 });
1875 
1876 REGISTER_OPERATOR_FUNCTOR(aten::sign, aten_sign, [](Node* n) -> SROperator {
1877   if (!n->matches(torch::schema("aten::sign.Tensor(Tensor input) -> Tensor"))) {
1878     LogAndDumpSchema(n);
1879     return nullptr;
1880   }
1881   return [](ProcessedNode* p_node) {
1882     const auto& in0_t = p_node->Input(0).toTensor();
1883     if (p_node->Output(0).isNone()) {
1884       p_node->Output(0) = at::cpu::sign(in0_t);
1885       return;
1886     }
1887     auto& out_t = p_node->Output(0).toTensor();
1888     fastResizeToZero(out_t);
1889     at::cpu::sign_out(out_t, in0_t);
1890   };
1891 });
1892 
1893 REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
1894   if (!n->matches(torch::schema(
1895           "aten::div.Tensor(Tensor self, Tensor other) -> Tensor")) &&
1896       !n->matches(torch::schema(
1897           "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor")) &&
1898       !n->matches(torch::schema(
1899           "aten::div.Scalar(Tensor self, Scalar other) -> Tensor")) &&
1900       !n->matches(torch::schema(
1901           "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor"))) {
1902     LogAndDumpSchema(n);
1903     return nullptr;
1904   }
1905 
1906   return [te = createDiv()](ProcessedNode* p_node) {
1907     const auto& in0_t = p_node->Input(0).toTensor();
1908     std::optional<c10::string_view> rounding_mode = std::nullopt;
1909     if (p_node->num_inputs() > 2) {
1910       rounding_mode = p_node->Input(2).toOptional<c10::string_view>();
1911     }
1912     const auto& in1_t = p_node->Input(1).isTensor()
1913         ? p_node->Input(1).toTensor()
1914         : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
1915 
1916     if (p_node->Output(0).isNone()) {
1917       p_node->Output(0) = create_empty_from(in0_t);
1918     }
1919     auto& out_t = p_node->Output(0).toTensor();
1920 
1921     if (in0_t.sizes() == in1_t.sizes() &&
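    // Fast path: when both operands are same-shape, same-stride, contiguous
    // float tensors, call the captured TensorExpr kernel directly;
    // i_rounding_mode encodes no rounding / trunc / floor as 0 / 1 / 2.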
1922         in0_t.scalar_type() == in1_t.scalar_type() &&
1923         in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() &&
1924         in0_t.scalar_type() == at::kFloat) {
1925       int64_t dim = in0_t.numel();
1926       int i_rounding_mode = 0;
1927       if (rounding_mode && !rounding_mode.value().empty()) {
1928         const char peek_rounding_mode = rounding_mode.value().at(0);
1929         if (peek_rounding_mode == 't') {
1930           // trunc after div
1931           i_rounding_mode = 1;
1932         } else if (peek_rounding_mode == 'f') {
1933           // floor after div
1934           i_rounding_mode = 2;
1935         }
1936       }
1937       at::native::resize_(out_t, in0_t.sizes());
1938       te->call(
1939           {out_t.data_ptr(),
1940            in0_t.data_ptr(),
1941            in1_t.data_ptr(),
1942            &i_rounding_mode,
1943            &dim});
1944     } else {
1945       fastResizeToZero(out_t);
1946       at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
1947     }
1948   };
1949 });
1950 
1951 REGISTER_OPERATOR_FUNCTOR(aten::log, aten_log, [](Node* n) -> SROperator {
1952   if (!n->matches(torch::schema("aten::log.Tensor(Tensor input) -> Tensor"))) {
1953     LogAndDumpSchema(n);
1954     return nullptr;
1955   }
1956   return [](ProcessedNode* p_node) {
1957     const auto& in0_t = p_node->Input(0).toTensor();
1958     if (p_node->Output(0).isNone()) {
1959       p_node->Output(0) = at::cpu::log(in0_t);
1960       return;
1961     }
1962     auto& out_t = p_node->Output(0).toTensor();
1963     fastResizeToZero(out_t);
1964     at::cpu::log_out(out_t, in0_t);
1965   };
1966 });
1967 
1968 REGISTER_OPERATOR_FUNCTOR(aten::sub, aten_sub, [](Node* n) -> SROperator {
1969   if (n->matches(torch::schema(
1970           "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"))) {
1971     return [](ProcessedNode* p_node) {
1972       const auto& in0_t = p_node->Input(0).toTensor();
1973       const auto& in1_t = p_node->Input(1).toTensor();
1974       const auto alpha = p_node->Input(2).toScalar();
1975       if (p_node->Output(0).isNone()) {
1976         p_node->Output(0) = at::cpu::sub(in0_t, in1_t, alpha);
1977         return;
1978       }
1979       auto& out_t = p_node->Output(0).toTensor();
1980       fastResizeToZero(out_t);
1981       at::cpu::sub_out(out_t, in0_t, in1_t, alpha);
1982     };
1983   }
1984   if (n->matches(torch::schema(
1985           "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"))) {
1986     return [](ProcessedNode* p_node) {
1987       const auto& in0_t = p_node->Input(0).toTensor();
1988       const auto& in1_t =
1989           at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
1990       const auto alpha = p_node->Input(2).toScalar();
1991       if (p_node->Output(0).isNone()) {
1992         p_node->Output(0) = at::cpu::sub(in0_t, in1_t, alpha);
1993         return;
1994       }
1995       auto& out_t = p_node->Output(0).toTensor();
1996       fastResizeToZero(out_t);
1997       at::cpu::sub_out(out_t, in0_t, in1_t, alpha);
1998     };
1999   }
2000   LogAndDumpSchema(n);
2001   return nullptr;
2002 });
2003 
2004 // TODO: support clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
2005 REGISTER_OPERATOR_FUNCTOR(
2006     aten::clamp_min,
2007     aten_clamp_min,
2008     [](Node* n) -> SROperator {
2009       if (!n->matches(torch::schema(
2010               "aten::clamp_min(Tensor self, Scalar min) -> Tensor"))) {
2011         LogAndDumpSchema(n);
2012         return nullptr;
2013       }
2014       return [](ProcessedNode* p_node) {
2015         const auto& in0_t = p_node->Input(0).toTensor();
2016         const auto in1_s = p_node->Input(1).toScalar();
2017         if (p_node->Output(0).isNone()) {
2018           p_node->Output(0) = at::cpu::clamp_min(in0_t, in1_s);
2019           return;
2020         }
2021         auto& out_t = p_node->Output(0).toTensor();
2022         fastResizeToZero(out_t);
2023         at::cpu::clamp_min_out(out_t, in0_t, in1_s);
2024       };
2025     });
2026 
2027 REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator {
2028   if (!n->matches(torch::schema(
2029           "aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"))) {
2030     LogAndDumpSchema(n);
2031     return nullptr;
2032   }
2033   return [](ProcessedNode* p_node) {
2034     const auto& in0_t = p_node->Input(0).toTensor();
2035     const auto dim = p_node->Input(1).toOptional<int64_t>();
2036     const auto keepdim = p_node->Input(2).toBool();
2037     if (p_node->Output(0).isNone()) {
2038       p_node->Output(0) = at::cpu::argmin(in0_t, dim, keepdim);
2039       return;
2040     }
2041     auto& out_t = p_node->Output(0).toTensor();
2042     fastResizeToZero(out_t);
2043     if (in0_t.is_contiguous() && dim.has_value()) {
2044       at::native::c2_argmin_out(out_t, in0_t, dim.value(), keepdim);
2045       return;
2046     }
2047     at::cpu::argmin_out(out_t, in0_t, dim, keepdim);
2048   };
2049 });
2050 
2051 REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator {
2052   if (!n->matches(torch::schema(
2053           "aten::softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) {
2054     LogAndDumpSchema(n);
2055     return nullptr;
2056   }
2057   return [](ProcessedNode* p_node) {
2058     const auto& in_t = p_node->Input(0).toTensor();
2059     const auto& dim = p_node->Input(1).toInt();
2060     const auto& dtype = p_node->Input(2).toOptional<c10::ScalarType>();
2061     if (p_node->Output(0).isNone()) {
2062       p_node->Output(0) = at::native::softmax(in_t, dim, dtype);
2063       return;
2064     }
2065     auto& out_t = p_node->Output(0).toTensor();
2066     fastResizeToZero(out_t);
2067     auto half_to_float = in_t.scalar_type() == at::ScalarType::Half &&
2068         dtype == at::ScalarType::Float;
2069     at::cpu::_softmax_out(out_t, in_t, dim, half_to_float);
2070   };
2071 });
2072 
2073 namespace {
2074 
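// Borrows the tensor held by an optional-tensor IValue without bumping its
// refcount; a None input yields an owned, default-constructed (undefined)
// tensor.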
2075 c10::MaybeOwned<at::Tensor> borrow_from_optional_tensor_ivalue(
2076     const IValue& iv) {
2077   if (iv.isNone()) {
2078     return c10::MaybeOwned<at::Tensor>::owned(std::in_place);
2079   }
2080   return c10::MaybeOwned<at::Tensor>::borrowed(iv.toTensor());
2081 }
2082 
2083 } // namespace
2084 
2085 REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROperator {
2086   if (!sr_schema_check(
2087           n,
2088           "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")) {
2089     return nullptr;
2090   }
2091   return [](ProcessedNode* p_node) {
2092     // ignore Input(5): `bool cudnn_enable=True`
2093     const auto& input = p_node->Input(0).toTensor();
2094     const auto normalized_shape = p_node->Input(1).toDimVector();
2095     float eps = p_node->Input(4).toDouble();
2096 
2097     c10::MaybeOwned<at::Tensor> weight_maybe_owned =
2098         borrow_from_optional_tensor_ivalue(p_node->Input(2));
2099     const at::Tensor& weight = *weight_maybe_owned;
2100     c10::MaybeOwned<at::Tensor> bias_maybe_owned =
2101         borrow_from_optional_tensor_ivalue(p_node->Input(3));
2102     const at::Tensor& bias = *bias_maybe_owned;
2103 
2104     auto M_N = at::native::_check_layer_norm_inputs(
2105         input, normalized_shape, weight, bias);
2106     auto M = M_N.first;
2107     auto N = M_N.second;
2108     auto X = input.expect_contiguous();
2109     auto gamma = weight.expect_contiguous();
2110     auto beta = bias.expect_contiguous();
2111 
2112     if (p_node->Output(0).isNone()) {
2113       p_node->Output(0) = at::native::empty_like(
2114           *X,
2115           std::nullopt /* dtype */,
2116           std::nullopt /* layout */,
2117           std::nullopt /* device */,
2118           std::nullopt /* pin_memory */,
2119           at::MemoryFormat::Contiguous);
2120     } else {
2121       at::native::resize_(
2122           p_node->Output(0).toTensor(), X->sizes(), std::nullopt);
2123     }
2124     at::Tensor& output = p_node->Output(0).toTensor();
2125     at::native::layer_norm_cpu_out(output, *X, *gamma, *beta, eps, M, N);
2126   };
2127 });
2128 
2129 REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
2130   if (n->matches(torch::schema(
2131           "aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor"))) {
2132     return [](ProcessedNode* p_node) {
2133       const auto& in0_t = p_node->Input(0).toTensor();
2134       if (p_node->Output(0).isNone()) {
2135         p_node->Output(0) = create_empty_from(in0_t);
2136       }
2137       auto& out_t = p_node->Output(0).toTensor();
2138       fastResizeToZero(out_t);
2139       const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
2140       at::cpu::norm_outf(
2141           in0_t,
2142           in1_s,
2143           c10::IntArrayRef{},
2144           false,
2145           p_node->Input(2).toScalarType(),
2146           out_t);
2147     };
2148   }
2149   if (n->matches(torch::schema(
2150           "aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"))) {
2151     return [](ProcessedNode* p_node) {
2152       const auto& in0_t = p_node->Input(0).toTensor();
2153 
2154       if (p_node->Output(0).isNone()) {
2155         p_node->Output(0) = create_empty_from(in0_t);
2156       }
2157       auto& out_t = p_node->Output(0).toTensor();
2158       fastResizeToZero(out_t);
2159 
2160       const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
2161       at::cpu::norm_outf(
2162           in0_t,
2163           in1_s,
2164           p_node->Input(2).toDimVector(), // dim
2165           p_node->Input(3).toBool(), // keepdim
2166           p_node->Input(4).toScalarType(), // dtype
2167           out_t);
2168     };
2169   }
2170   if (n->matches(torch::schema(
2171           "aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor"))) {
2172     return [](ProcessedNode* p_node) {
2173       const auto& in0_t = p_node->Input(0).toTensor();
2174 
2175       if (p_node->Output(0).isNone()) {
2176         p_node->Output(0) = create_empty_from(in0_t);
2177       }
2178       auto& out_t = p_node->Output(0).toTensor();
2179       fastResizeToZero(out_t);
2180 
2181       const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
2182       at::cpu::norm_outf(
2183           in0_t,
2184           in1_s,
2185           p_node->Input(2).toDimVector(), // dim
2186           p_node->Input(3).toBool(), // keepdim
2187           out_t);
2188     };
2189   }
2190   LogAndDumpSchema(n);
2191   return nullptr;
2192 });
2193 
2194 REGISTER_OPERATOR_FUNCTOR(aten::matmul, aten_matmul, [](Node* n) -> SROperator {
2195   if (!n->matches(
2196           torch::schema("aten::matmul(Tensor self, Tensor other) -> Tensor"))) {
2197     LogAndDumpSchema(n);
2198     return nullptr;
2199   }
2200   return [](ProcessedNode* p_node) {
2201     const auto& in0_t = p_node->Input(0).toTensor();
2202     const auto& in1_t = p_node->Input(1).toTensor();
2203 
2204     if (p_node->Output(0).isNone()) {
2205       p_node->Output(0) = at::native::matmul(in0_t, in1_t);
2206       return;
2207     }
2208     auto& out_t = p_node->Output(0).toTensor();
2209     fastResizeToZero(out_t);
2210     at::native::matmul_out(in0_t, in1_t, out_t);
2211   };
2212 });
2213 
2214 REGISTER_OPERATOR_FUNCTOR(quantized::linear, quantized_linear, [](Node* n) -> SROperator {
2215   if (!n->matches(torch::schema(
2216           "quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"))) {
2217     LogAndDumpSchema(n);
2218     return nullptr;
2219   }
2220   const auto w = toIValue(n->inputs()[1]);
2221   c10::intrusive_ptr<LinearPackedParamsBase> packed_weight;
2222   if (w) {
2223     packed_weight = w->toCustomClass<LinearPackedParamsBase>();
2224   }
2225   return [packed_weight](ProcessedNode* p_node) {
2226     const auto& input = p_node->Input(0).toTensor();
2227     const auto output_scale = p_node->Input(2).toDouble();
2228     const auto output_zero_point = p_node->Input(3).toInt();
2229 
2230     if (p_node->Output(0).isNone()) {
2231       p_node->Output(0) = at::native::empty_affine_quantized(
2232           {0},
2233           c10::kQUInt8,
2234           std::nullopt,
2235           c10::kCPU,
2236           false,
2237           output_scale,
2238           output_zero_point,
2239           std::nullopt);
2240     }
2241     auto& out_t = p_node->Output(0).toTensor();
2242     fastResizeToZero(out_t);
2243 
2244     if (packed_weight) {
2245       packed_weight->apply_out(input, output_scale, output_zero_point, out_t);
2246     } else {
2247       // Weights could be quantized on the fly
2248       auto packed_weight_tmp =
2249           p_node->Input(1).toCustomClass<LinearPackedParamsBase>();
2250       packed_weight_tmp->apply_out(
2251           input, output_scale, output_zero_point, out_t);
2252     }
2253   };
2254 });
2255 
2256 REGISTER_OPERATOR_FUNCTOR(
2257     fb::quantized_linear,
2258     fb_quantized_linear,
2259     [](Node* n) -> SROperator {
2260       if (!n->matches(torch::schema(
2261               "fb::quantized_linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase w_prepack, Tensor Y_scale_i, Tensor Y_zero_point_i) -> Tensor"))) {
2262         LogAndDumpSchema(n);
2263         return nullptr;
2264       }
2265       const auto w = toIValue(n->inputs()[1]);
2266       c10::intrusive_ptr<LinearPackedParamsBase> packed_weight;
2267       if (w) {
2268         packed_weight = w->toCustomClass<LinearPackedParamsBase>();
2269       }
2270       return [packed_weight](ProcessedNode* p_node) {
2271         const auto& input = p_node->Input(0).toTensor();
2272         const auto output_scale = p_node->Input(2).toTensor().item().toFloat();
2273         const auto output_zero_point =
2274             p_node->Input(3).toTensor().item().toLong();
2275 
2276         if (p_node->Output(0).isNone()) {
2277           p_node->Output(0) = at::native::empty_affine_quantized(
2278               {0},
2279               c10::kQUInt8,
2280               std::nullopt,
2281               c10::kCPU,
2282               false,
2283               output_scale,
2284               output_zero_point,
2285               std::nullopt);
2286         }
2287         auto& out_t = p_node->Output(0).toTensor();
2288         fastResizeToZero(out_t);
2289 
2290         if (packed_weight) {
2291           packed_weight->apply_out(
2292               input, output_scale, output_zero_point, out_t);
2293         } else {
2294           // Weights could be quantized on the fly
2295           auto packed_weight_tmp =
2296               p_node->Input(1).toCustomClass<LinearPackedParamsBase>();
2297           packed_weight_tmp->apply_out(
2298               input, output_scale, output_zero_point, out_t);
2299         }
2300       };
2301     });
2302 
2303 namespace {
2304 
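// Compile-time dispatch for dynamic fp16 linear: the <true> specialization
// additionally applies relu after apply_dynamic_out.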
2305 template <bool has_relu>
2306 void apply_dynamic_out_functor(
2307     c10::intrusive_ptr<LinearPackedParamsBase> packed_weight,
2308     const at::Tensor& input,
2309     at::Tensor& out,
2310     bool reduce_range);
2311 
2312 template <>
2313 void apply_dynamic_out_functor<false>(
2314     c10::intrusive_ptr<LinearPackedParamsBase> packed_weight,
2315     const at::Tensor& input,
2316     at::Tensor& out,
2317     bool reduce_range) {
2318   packed_weight->apply_dynamic_out(input, out, reduce_range);
2319 }
2320 
2321 template <>
2322 void apply_dynamic_out_functor<true>(
2323     c10::intrusive_ptr<LinearPackedParamsBase> packed_weight,
2324     const at::Tensor& input,
2325     at::Tensor& out,
2326     bool reduce_range) {
2327   // The implementation of PackedLinearWeightFp16::apply_dynamic_impl does not
2328   // handle relu. Currently, it ignores the `ReluFused` template parameter.
2329   // So, we explicitly do the relu here.
2330   packed_weight->apply_dynamic_out(input, out, reduce_range);
2331   out.relu_();
2332 }
2333 
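// Builds the SROperator for quantized::linear_dynamic_fp16 (and its relu
// variant). If the prepacked weight is a graph constant it is captured once
// here; otherwise it is fetched from Input(1) on every invocation.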
2334 template <bool has_relu>
2335 SROperator quantized_linear_dynamic_fp16_impl(Node* n) {
2336   const auto weight = toIValue(n->inputs()[1]);
2337   c10::intrusive_ptr<LinearPackedParamsBase> packed_weight;
2338   if (weight) {
2339     packed_weight = weight->toCustomClass<LinearPackedParamsBase>();
2340   }
2341   if (packed_weight) {
2342     return [packed_weight](ProcessedNode* p_node) {
2343       const auto& input = p_node->Input(0).toTensor();
2344       if (p_node->Output(0).isNone()) {
2345         p_node->Output(0) = create_empty_from(input, at::kFloat);
2346       }
2347       auto& out_t = p_node->Output(0).toTensor();
2348       fastResizeToZero(out_t);
2349       apply_dynamic_out_functor<has_relu>(packed_weight, input, out_t, false);
2350     };
2351   } else {
2352     return [](ProcessedNode* p_node) {
2353       const auto& input = p_node->Input(0).toTensor();
2354       if (p_node->Output(0).isNone()) {
2355         p_node->Output(0) = create_empty_from(input, at::kFloat);
2356       }
2357       auto& out_t = p_node->Output(0).toTensor();
2358       fastResizeToZero(out_t);
2359       // Weights could be quantized on the fly
2360       auto packed_weight_tmp =
2361           p_node->Input(1).toCustomClass<LinearPackedParamsBase>();
2362       apply_dynamic_out_functor<has_relu>(
2363           packed_weight_tmp, input, out_t, false);
2364     };
2365   }
2366 }
2367 
2368 } // namespace
2369 
2370 REGISTER_OPERATOR_FUNCTOR(
2371     quantized::linear_dynamic_fp16,
2372     quantized_linear_dynamic_fp16,
2373     [](Node* n) -> SROperator {
2374       if (!n->matches(torch::schema(
2375               "quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes."
2376               "quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"))) {
2377         LogAndDumpSchema(n);
2378         return nullptr;
2379       }
2380       return quantized_linear_dynamic_fp16_impl<false>(n);
2381     });
2382 
2383 REGISTER_OPERATOR_FUNCTOR(
2384     quantized::linear_relu_dynamic_fp16,
2385     quantized_linear_relu_dynamic_fp16,
2386     [](Node* n) -> SROperator {
2387       if (!n->matches(torch::schema(
2388               "quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes."
2389               "quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"))) {
2390         LogAndDumpSchema(n);
2391         return nullptr;
2392       }
2393       return quantized_linear_dynamic_fp16_impl<true>(n);
2394     });
2395 
2396 // device & pin_memory matter only when CUDA is enabled.
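// Returns true when `ivalue` already holds a tensor with the requested dtype
// and layout (and, in the overload below, memory format), so the existing
// output can be reused rather than re-allocated.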
2397 static bool hasTensorWithOptions(
2398     const IValue& ivalue,
2399     std::optional<c10::ScalarType> dtype,
2400     std::optional<c10::Layout> layout) {
2401   if (!ivalue.isTensor()) {
2402     return false;
2403   }
2404   const auto& tensor = ivalue.toTensor();
2405   if (dtype == tensor.dtype().toScalarType() &&
2406       layout == tensor.options().layout_opt()) {
2407     return true;
2408   }
2409   VLOG(1) << "tensor exists, but tensor options were different";
2410   return false;
2411 }
2412 
2413 static bool hasTensorWithOptions(
2414     const IValue& ivalue,
2415     std::optional<c10::ScalarType> dtype,
2416     std::optional<c10::Layout> layout,
2417     std::optional<c10::MemoryFormat> memory_format) {
2418   return hasTensorWithOptions(ivalue, dtype, layout) &&
2419       (memory_format == ivalue.toTensor().options().memory_format_opt());
2420 }
2421 
2422 REGISTER_OPERATOR_FUNCTOR(aten::full, aten_full, [](Node* n) -> SROperator {
2423   if (!n->matches(torch::schema(
2424           "aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) {
2425     LogAndDumpSchema(n);
2426     return nullptr;
2427   }
2428   return [](ProcessedNode* p_node) {
2429     const auto& size = p_node->Input(0).toDimVector();
2430     const auto fill_value = p_node->Input(1).toScalar();
2431     const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
2432     const auto layout = p_node->Input(3).toOptional<c10::Layout>();
2433     if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
2434       const auto device = p_node->Input(4).toOptional<c10::Device>();
2435       const auto pin_memory = p_node->Input(5).toOptional<bool>();
2436       p_node->Output(0) =
2437           at::native::full(size, fill_value, dtype, layout, device, pin_memory);
2438       return;
2439     }
2440     p_node->Output(0) =
2441         at::native::full_out(size, fill_value, p_node->Output(0).toTensor());
2442   };
2443 });
2444 
2445 REGISTER_OPERATOR_FUNCTOR(aten::full_like, aten_full_like, [](Node* n) -> SROperator {
2446   if (!n->matches(torch::schema(
2447           "aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"))) {
2448     LogAndDumpSchema(n);
2449     return nullptr;
2450   }
2451   return [](ProcessedNode* p_node) {
2452     const auto in1_s = p_node->Input(1).toScalar();
2453     const auto& in0_t = p_node->Input(0).toTensor();
2454     const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
2455     const auto layout = p_node->Input(3).toOptional<c10::Layout>();
2456     if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
2457       const auto device = p_node->Input(4).toOptional<c10::Device>();
2458       const auto pin_memory = p_node->Input(5).toOptional<bool>();
2459       const auto memory_format =
2460           p_node->Input(6).toOptional<c10::MemoryFormat>();
2461 
2462       p_node->Output(0) = at::native::empty_like(
2463           in0_t, dtype, layout, device, pin_memory, memory_format);
2464     }
2465     auto& out_t = p_node->Output(0).toTensor();
2466     at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
2467     at::native::fill_out(out_t, in1_s);
2468   };
2469 });
2470 
2471 REGISTER_OPERATOR_FUNCTOR(aten::ones, aten_ones, [](Node* n) -> SROperator {
2472   if (!n->matches(torch::schema(
2473           "aten::ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) {
2474     LogAndDumpSchema(n);
2475     return nullptr;
2476   }
2477   return [](ProcessedNode* p_node) {
2478     const auto size = p_node->Input(0).toDimVector();
2479     if (p_node->Output(0).isNone()) {
2480       const auto dtype = p_node->Input(1).toOptional<c10::ScalarType>();
2481       const auto layout = p_node->Input(2).toOptional<c10::Layout>();
2482       const auto device = p_node->Input(3).toOptional<c10::Device>();
2483       const auto pin_memory = p_node->Input(4).toOptional<bool>();
2484       p_node->Output(0) =
2485           at::native::ones(size, dtype, layout, device, pin_memory);
2486       return;
2487     }
2488     auto& out_t = p_node->Output(0).toTensor();
2489     fastResizeToZero(out_t);
2490     at::native::ones_out(size, out_t);
2491   };
2492 });
2493 
2494 REGISTER_OPERATOR_FUNCTOR(aten::ones_like, aten_ones_like, [](Node* n) -> SROperator {
2495   if (!n->matches(torch::schema(
2496           "aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"))) {
2497     LogAndDumpSchema(n);
2498     return nullptr;
2499   }
2500   return [](ProcessedNode* p_node) {
2501     const auto& self = p_node->Input(0).toTensor();
2502     const auto dtype = p_node->Input(1).toOptional<c10::ScalarType>();
2503     const auto layout = p_node->Input(2).toOptional<c10::Layout>();
2504     const auto device = p_node->Input(3).toOptional<c10::Device>();
2505     const auto pin_memory = p_node->Input(4).toOptional<bool>();
2506     const auto memory_format = p_node->Input(5).toOptional<c10::MemoryFormat>();
2507     if (!hasTensorWithOptions(
2508             p_node->Output(0), dtype, layout, memory_format)) {
2509       p_node->Output(0) = at::native::ones_like(
2510           self, dtype, layout, device, pin_memory, memory_format);
2511       return;
2512     }
2513     auto& out_t = p_node->Output(0).toTensor();
2514     fastResizeToZero(out_t);
2515     at::native::ones_out(self.sizes(), out_t);
2516   };
2517 });
2518 
2519 REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator {
2520   if (!n->matches(torch::schema(
2521           "aten::zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) {
2522     LogAndDumpSchema(n);
2523     return nullptr;
2524   }
2525   return [](ProcessedNode* p_node) {
2526     const auto size = p_node->Input(0).toDimVector();
2527     const auto dtype = p_node->Input(1).toOptional<c10::ScalarType>();
2528     const auto layout = p_node->Input(2).toOptional<c10::Layout>();
2529     if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
2530       p_node->Output(0) = at::compositeexplicitautograd::zeros(
2531           size, dtype, layout, std::nullopt, std::nullopt);
2532       return;
2533     }
2534     auto& out_t = p_node->Output(0).toTensor();
2535     fastResizeToZero(out_t);
2536     at::compositeexplicitautograd::zeros_out(out_t, size);
2537   };
2538 });
2539 
2540 REGISTER_OPERATOR_FUNCTOR(aten::linear, aten_linear, [](Node* n) -> SROperator {
2541   if (!n->matches(torch::schema(
2542           "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"))) {
2543     LogAndDumpSchema(n);
2544     return nullptr;
2545   }
2546 
2547   return [](ProcessedNode* p_node) {
2548     const auto& in0_t = p_node->Input(0).toTensor();
2549     const auto& in1_t = p_node->Input(1).toTensor();
2550     auto in2_t = p_node->Input(2).toOptional<at::Tensor>();
2551 
2552     if (p_node->Output(0).isNone()) {
2553       p_node->Output(0) = at::native::linear(in0_t, in1_t, in2_t);
2554       return;
2555     }
2556     auto& out_t = p_node->Output(0).toTensor();
2557     fastResizeToZero(out_t);
2558     at::native::linear_out(out_t, in0_t, in1_t, in2_t);
2559   };
2560 });
2561 
2562 REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SROperator {
2563   if (n->matches(torch::schema(
2564           "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
2565     return [](ProcessedNode* p_node) {
2566       const auto& input = p_node->Input(0).toTensor();
2567       const auto dim = p_node->Input(2).toDimVector();
2568       const auto keepdim = p_node->Input(3).toBool();
2569       const auto dtype = p_node->Input(4).toOptional<c10::ScalarType>();
2570       if (p_node->Output(0).isNone()) {
2571         p_node->Output(0) = at::native::linalg_norm(
2572             input,
2573             p_node->Input(1).toOptional<at::Scalar>(),
2574             dim,
2575             keepdim,
2576             dtype);
2577         return;
2578       }
2579       auto& output = p_node->Output(0).toTensor();
2580       fastResizeToZero(output);
2581       at::native::linalg_norm_out(
2582           input,
2583           p_node->Input(1).toOptional<at::Scalar>(),
2584           dim,
2585           keepdim,
2586           dtype,
2587           output);
2588     };
2589   }
2590   if (n->matches(torch::schema(
2591           "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
2592     return [](ProcessedNode* p_node) {
2593       const auto& input = p_node->Input(0).toTensor();
2594       const auto dim = p_node->Input(2).toDimVector();
2595       const auto keepdim = p_node->Input(3).toBool();
2596       const auto dtype = p_node->Input(4).toOptional<c10::ScalarType>();
2597       if (p_node->Output(0).isNone()) {
2598         p_node->Output(0) = at::native::linalg_norm(
2599             input, p_node->Input(1).toStringView(), dim, keepdim, dtype);
2600         return;
2601       }
2602       auto& output = p_node->Output(0).toTensor();
2603       fastResizeToZero(output);
2604       at::native::linalg_norm_out(
2605           input, p_node->Input(1).toStringRef(), dim, keepdim, dtype, output);
2606     };
2607   }
2608   LogAndDumpSchema(n);
2609   return nullptr;
2610 });
2611 
2612 REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
2613   if (!n->matches(
2614           torch::schema("aten::cat(Tensor[] tensors, int dim=0) -> Tensor"))) {
2615     LogAndDumpSchema(n);
2616     return nullptr;
2617   }
2618   return [](ProcessedNode* p_node) {
2619     const auto inputs = p_node->Input(0).toTensorVector();
2620     TORCH_CHECK(!inputs.empty(), "concat expects non-empty tensor list");
2621     const auto dim = p_node->Input(1).toInt();
2622     if (p_node->Output(0).isNone()) {
2623       p_node->Output(0) = at::cpu::cat(inputs, dim);
2624       return;
2625     }
2626     auto& output = p_node->Output(0).toTensor();
2627     fastResizeToZero(output);
2628     at::cpu::cat_outf(inputs, dim, output);
2629   };
2630 });
2631 
2632 REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator {
2633   if (!n->matches(torch::schema(
2634           "aten::cumsum(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) {
2635     LogAndDumpSchema(n);
2636     return nullptr;
2637   }
2638   return [](ProcessedNode* p_node) {
2639     const auto& input = p_node->Input(0).toTensor();
2640     const auto dim = p_node->Input(1).toInt();
2641     const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
2642     if (p_node->Output(0).isNone()) {
2643       p_node->Output(0) = at::cpu::cumsum(input, dim, dtype);
2644       return;
2645     }
2646     auto& output = p_node->Output(0).toTensor();
2647     fastResizeToZero(output);
2648     at::cpu::cumsum_out(output, input, dim, dtype);
2649   };
2650 });
2651 
2652 REGISTER_OPERATOR_FUNCTOR(
2653     aten::nonzero,
2654     aten_nonzero,
2655     [](Node* n) -> SROperator {
2656       if (!n->matches(torch::schema("aten::nonzero(Tensor self) -> Tensor"))) {
2657         LogAndDumpSchema(n);
2658         return nullptr;
2659       }
2660       return [](ProcessedNode* p_node) {
2661         const auto& input = p_node->Input(0).toTensor();
2662         if (p_node->Output(0).isNone()) {
2663           p_node->Output(0) = at::native::nonzero_cpu(input);
2664           return;
2665         }
2666         auto& output = p_node->Output(0).toTensor();
2667         fastResizeToZero(output);
2668         at::native::nonzero_out_cpu(input, output);
2669       };
2670     });
2671 
2672 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2673 REGISTER_OPERATOR_FUNCTOR(
2674     prim::VarConcat,
2675     prim_VarConcat,
2676     [](Node* n) -> SROperator {
2677       if (!sr_schema_check_kind(n, prim::VarConcat)) {
2678         return nullptr;
2679       }
2680       return [](ProcessedNode* p_node) {
2681         const size_t num_inputs = p_node->num_inputs();
2682         std::vector<at::Tensor> inputs(num_inputs - 1);
2683         for (const auto i : c10::irange(num_inputs - 1)) {
2684           inputs[i] = p_node->Input(i).toTensor();
2685         }
2686         auto dim = p_node->Input(num_inputs - 1).toInt();
2687         if (p_node->Output(0).isNone()) {
2688           p_node->Output(0) = at::cpu::cat(inputs, dim);
2689           return;
2690         }
2691         auto& out_t = p_node->Output(0).toTensor();
2692         fastResizeToZero(out_t);
2693         at::cpu::cat_outf(inputs, dim, out_t);
2694       };
2695     });
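// prim::VarConcat is the variadic form of aten::cat used by static runtime:
// the tensors arrive as individual inputs and the concat dimension is the
// last input, so the functor above first gathers the leading num_inputs - 1
// inputs into a std::vector<at::Tensor> before dispatching to the same
// at::cpu::cat / at::cpu::cat_outf pair as aten::cat.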

namespace {
// This template and its specialization help us avoid compiler warnings
// about taking the absolute value of an unsigned type in signed_log1p
template <class T>
T abs_if_signed(T val) {
  return std::abs(val);
}

template <>
unsigned char abs_if_signed<unsigned char>(unsigned char val) {
  return val;
}
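// Of the types covered by AT_DISPATCH_ALL_TYPES below, uint8_t (at::kByte)
// is the only unsigned one, so a single specialization is enough; calling
// std::abs on it would trigger warnings such as clang's -Wabsolute-value.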

// Computes f(x) = sign(x) * ln(1 + |x|) for each x in the input tensor
void signed_log1p_out(at::Tensor& out, const at::Tensor& input) {
  at::native::resize_(out, input.sizes(), std::nullopt);

  const auto input_contig = input.expect_contiguous();
  auto output_contig = out.expect_contiguous();

  AT_DISPATCH_ALL_TYPES(input.scalar_type(), "signed_log1p_kernel", [&]() {
    const auto input_data = input_contig->const_data_ptr<scalar_t>();
    auto output_data = output_contig->mutable_data_ptr<float>();
    const auto N = input.numel();

    for (const auto i : c10::irange(N)) {
      const int sign = input_data[i] < 0 ? -1 : 1;
      output_data[i] = std::log1p(abs_if_signed(input_data[i])) * sign;
    }
  });
}
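// A few sample values of the reference kernel above, for orientation:
//   signed_log1p(-3) = -ln(1 + 3) ~ -1.3863
//   signed_log1p( 0) =  ln(1 + 0) =  0
//   signed_log1p( 3) =  ln(1 + 3) ~  1.3863
// Note that the result is always written as float, regardless of the
// dispatched input dtype.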

} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_OPERATOR_FUNCTOR(
    static_runtime::signed_log1p,
    static_runtime_signed_log1p,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "static_runtime::signed_log1p(Tensor x) -> Tensor"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      auto te = createSignedLog1p();
      return [te](ProcessedNode* p_node) {
        const auto& input = p_node->Input(0).toTensor();
        if (p_node->Output(0).isNone()) {
          p_node->Output(0) = create_empty_from(input);
        }
        auto& out = p_node->Output(0).toTensor();
        if (!te || !te->checkInput<float>(input)) {
          fastResizeToZero(out);
          signed_log1p_out(out, input);
          return;
        }
        at::native::resize_(out, input.sizes(), std::nullopt);
        int64_t nn = input.numel();
        te->call({out.data_ptr(), input.data_ptr(), &nn});
      };
    });
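// The static_runtime::signed_log1p functor above prefers the NNC/tensorexpr
// kernel built by createSignedLog1p and only falls back to the reference
// signed_log1p_out implementation when no kernel was generated or when
// checkInput<float>() rejects the input (e.g. a non-float input, depending
// on what the wrapper checks).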

REGISTER_OPERATOR_FUNCTOR(
    aten::remainder,
    aten_remainder,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema(
              "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor"))) {
        return [](ProcessedNode* p_node) {
          const auto& self = p_node->Input(0).toTensor();
          if (p_node->Output(0).isNone()) {
            p_node->Output(0) =
                at::cpu::remainder(self, p_node->Input(1).toTensor());
            return;
          }
          auto& out = p_node->Output(0).toTensor();
          fastResizeToZero(out);
          at::cpu::remainder_out(out, self, p_node->Input(1).toTensor());
        };
      }
      if (n->matches(torch::schema(
              "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor"))) {
        return [](ProcessedNode* p_node) {
          const auto& self = p_node->Input(0).toTensor();
          if (p_node->Output(0).isNone()) {
            p_node->Output(0) =
                at::native::remainder(self, p_node->Input(1).toScalar());
            return;
          }
          auto& out = p_node->Output(0).toTensor();
          fastResizeToZero(out);
          at::native::remainder_out(self, p_node->Input(1).toScalar(), out);
        };
      }

      // Unrecognized overload
      LogAndDumpSchema(n);
      return nullptr;
    });
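// Two overloads are handled above: remainder.Tensor and remainder.Scalar.
// As in eager mode, aten::remainder uses Python-style modulo semantics
// (the result takes the sign of the divisor), unlike aten::fmod. A small
// illustrative sketch:
//
//   at::Tensor t = at::tensor({-3.0f});
//   at::remainder(t, 2.0);  // ->  1.0
//   at::fmod(t, 2.0);       // -> -1.0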

REGISTER_OPERATOR_FUNCTOR(aten::where, aten_where, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& cond = p_node->Input(0).toTensor();
      const auto& self = p_node->Input(1).toTensor();
      const auto& other = p_node->Input(2).toTensor();

      if (p_node->Output(0).isNone()) {
        p_node->Output(0) = create_empty_from(self);
      }
      auto& out = p_node->Output(0).toTensor();
      fastResizeToZero(out);
      at::native::where_self_out(cond, self, other, out);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});
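// aten::where.self broadcasts condition, self and other; the output is
// created from self's options via create_empty_from on the first iteration
// and then resized and filled by at::native::where_self_out every time.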

REGISTER_OPERATOR_FUNCTOR(
    prim::NumToTensor,
    prim_NumToTensor,
    [](Node* n) -> SROperator {
      if (n->matches(
              torch::schema("prim::NumToTensor.Scalar(Scalar s) -> Tensor")) ||
          n->matches(
              torch::schema("prim::NumToTensor.bool(bool a) -> Tensor"))) {
        return [](ProcessedNode* pnode) {
          const auto scalar = pnode->Input(0).toScalar();
          if (pnode->Output(0).isNone()) {
            pnode->Output(0) = at::scalar_to_tensor(scalar);
            return;
          }
          auto& out = pnode->Output(0).toTensor();
          at::detail::scalar_fill(out, scalar);
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });
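// prim::NumToTensor wraps a Scalar (or bool) into a 0-dim tensor. On the
// first iteration at::scalar_to_tensor allocates that tensor; afterwards
// at::detail::scalar_fill overwrites the value in place, so no new
// allocation is needed.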

REGISTER_OPERATOR_FUNCTOR(
    quantized::embedding_bag_byte_unpack,
    quantized_embedding_bag_byte_unpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n,
              "quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        auto& weight = pnode->Input(0).toTensor();
        if (pnode->Output(0).isNone()) {
          pnode->Output(0) = at::empty(
              {},
              weight.options().dtype(at::kFloat),
              weight.suggest_memory_format());
        }
        auto& out = pnode->Output(0).toTensor();
        at::native::qembeddingbag_byte_unpack_out(out, weight);
      };
    });
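// quantized::embedding_bag_byte_unpack dequantizes a byte-packed embedding
// table (uint8 rows with per-row quantization parameters produced by the
// corresponding prepack op) back to float. The output is pre-created as an
// empty float tensor in the weight's suggested memory format and is resized
// and filled by qembeddingbag_byte_unpack_out.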

} // namespace torch::jit