1 #include <torch/csrc/jit/runtime/static/ops.h>
2
3 #include <ATen/CPUFunctions.h>
4 #include <ATen/InferSize.h>
5 #include <ATen/NativeFunctions.h>
6 #include <ATen/Parallel.h>
7 #include <ATen/ScalarOps.h>
8 #include <ATen/TensorUtils.h>
9 #include <ATen/cpu/vec/functional.h>
10 #include <ATen/cpu/vec/vec.h>
11 #include <ATen/native/Fill.h>
12 #include <ATen/native/IndexingUtils.h>
13 #include <ATen/native/NonSymbolicBC.h>
14 #include <ATen/native/Resize.h>
15 #include <ATen/native/SharedReduceOps.h>
16 #include <ATen/native/TensorAdvancedIndexing.h>
17 #include <ATen/native/TensorConversions.h>
18 #include <ATen/native/cpu/SerialStackImpl.h>
19 #include <ATen/native/layer_norm.h>
20 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
21 #include <ATen/native/quantized/cpu/qembeddingbag.h>
22 #include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
23 #include <ATen/quantized/QTensorImpl.h>
24 #include <ATen/quantized/Quantizer.h>
25 #include <c10/core/ScalarType.h>
26 #include <c10/core/WrapDimMinimal.h>
27 #include <c10/util/irange.h>
28 #include <torch/csrc/jit/ir/constants.h>
29 #include <torch/csrc/jit/ir/ir.h>
30 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
31 #include <torch/csrc/jit/runtime/static/impl.h>
32 #include <torch/csrc/jit/runtime/static/processed_node_wrapper.h>
33 #include <torch/csrc/jit/runtime/static/te_wrapper.h>
34 #include <torch/csrc/jit/runtime/vararg_functions.h>
35 #include <torch/csrc/jit/tensorexpr/ir.h>
36 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
37 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
38 #include <torch/csrc/jit/tensorexpr/loopnest.h>
39 #include <iterator>
40 #include <mutex>
41 #include <unordered_map>
42
43 #include <ATen/CompositeExplicitAutogradFunctions.h>
44
45 C10_DEFINE_bool(
46 static_runtime_enable_fast_math,
47 true,
48 "If on, static runtime may use use optimizations that cause accuracy loss "
49 "vs the jit interpreter");
50
51 namespace at::native {
52
53 static void repeat_out(
54 at::Tensor& result,
55 const Tensor& self,
56 IntArrayRef repeats) {
57 TORCH_CHECK(
58 repeats.size() >= static_cast<size_t>(self.dim()),
59 "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
60
61 // Add new leading dimensions to the tensor if the
62 // number of target dimensions is larger than the
63 // number of source dimensions.
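// For example, self.sizes() == [2, 3] with repeats == [2, 1, 2] yields
// padded_size == [1, 2, 3] and target_size == [2, 2, 6] below.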
64 int64_t num_new_dimensions = repeats.size() - self.dim();
65 DimVector padded_size(num_new_dimensions, 1);
66 padded_size.insert(
67 padded_size.end(), self.sizes().begin(), self.sizes().end());
68 DimVector target_size(repeats.size());
69 bool zero_tensor = false;
70 for (const auto idx : c10::irange(repeats.size())) {
71 if (repeats[idx] == 0) {
72 zero_tensor = true;
73 }
74 target_size[idx] = padded_size[idx] * repeats[idx];
75 }
76
77 // return an empty tensor if one of the repeat dimensions is zero
78 at::native::resize_(result, target_size, std::nullopt);
79 if (zero_tensor) {
80 return;
81 }
82
83 Tensor xtensor = at::compositeexplicitautograd::expand(self, padded_size);
84 Tensor urtensor = at::native::alias(result);
85 for (const auto i : c10::irange(xtensor.dim())) {
86 // can't unfold with step 0, so make sure step is at least 1
87 // (it doesn't matter what it is in that case, because the size is 0).
88 urtensor = urtensor.unfold(
89 i, xtensor.size(i), std::max<int64_t>(xtensor.size(i), 1));
90 }
91
92 at::native::copy_(urtensor, xtensor.expand_as(urtensor));
93 }
94
95 // copy version of view ops
96 at::Tensor& reshape_copy_out(
97 at::Tensor& out,
98 const at::Tensor& self,
99 const at::DimVector& proposed_shape,
100 bool infer_size) {
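// When infer_size is true, a single -1 in proposed_shape is resolved from
// self.numel(); e.g. a proposed shape of [-1, 4] for a tensor with 12
// elements becomes [3, 4].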
101 const auto& shape = infer_size
102 ? at::infer_size_dv(proposed_shape, self.numel())
103 : proposed_shape;
104 at::native::resize_(out, shape, std::nullopt);
105
106 auto self_contig = self.expect_contiguous();
107
108 size_t nbytes = self.nbytes();
109 if (nbytes == 0) {
110 return out;
111 }
112
113 const void* self_data = self_contig->const_data_ptr();
114 void* out_data = out.mutable_data_ptr();
115 memcpy(out_data, self_data, nbytes);
116
117 return out;
118 }
119
120 static at::Tensor& flatten_copy_out(
121 at::Tensor& out,
122 const at::Tensor& self,
123 int64_t start_dim,
124 int64_t end_dim) {
125 start_dim =
126 start_dim < 0 ? c10::maybe_wrap_dim(start_dim, self.dim()) : start_dim;
127 end_dim = end_dim < 0 ? c10::maybe_wrap_dim(end_dim, self.dim()) : end_dim;
128 TORCH_CHECK(
129 start_dim <= end_dim,
130 "flatten() has invalid args: start_dim cannot come after end_dim");
131
132 if (self.dim() == 0) {
133 return reshape_copy_out(out, self, at::DimVector{1}, false);
134 }
135
136 if (start_dim == end_dim) {
137 auto shape = at::DimVector{self.sizes()};
138 return reshape_copy_out(out, self, shape, false);
139 }
140
141 // We don't want to infer_size on the entire shape, because that can give us
142 // an extra degree of freedom we don't want; for example, consider shape [0,
143 // 1, 3, 0], with start_dim=1, end_dim=2. It's clear we want result shape [0,
144 // 3, 0] but passing [0, -1, 0] to infer_size means the -1 can take on any
145 // value and satisfy the constraints.
146 auto iter = self.sizes().data();
147 auto slice_numel = std::accumulate(
148 iter + start_dim,
149 iter + end_dim + 1,
150 static_cast<int64_t>(1),
151 // NOLINTNEXTLINE(modernize-use-transparent-functors)
152 std::multiplies<int64_t>());
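// e.g. for self.sizes() == [2, 3, 4] with start_dim == 1 and end_dim == 2,
// slice_numel == 12 and the shape built below is [2, 12].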
153
154 at::DimVector shape;
155 shape.reserve(self.dim() - end_dim + start_dim);
156 for (const auto i : c10::irange(start_dim)) {
157 shape.push_back(self.sizes()[i]);
158 }
159 shape.push_back(slice_numel);
160 for (int64_t i = end_dim + 1; i < self.dim(); i++) {
161 shape.push_back(self.sizes()[i]);
162 }
163 return reshape_copy_out(out, self, shape, false);
164 }
165
166 namespace {
167
168 // This is annoying and silly, but it's solving a real problem: the
169 // _MSC_VER version causes an ICE on our old clang5 builds. The
170 // non-_MSC_VER version is a syntax error according to MSVC. Use the
171 // appropriate version depending on whether we're building with MSVC.
172
173 #define TO_COPY_OUT_FAST_PATH_LOGIC(out, self, self_t) \
174 do { \
175 const auto N = self.numel(); \
176 const auto self_data = self.const_data_ptr<self_t>(); \
177 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
178 kHalf, \
179 kBFloat16, \
180 kBool, \
181 out.scalar_type(), \
182 "to_copy_out_inner_loop", \
183 [&]() { \
184 const auto out_data = out.mutable_data_ptr<scalar_t>(); \
185 for (const auto idx : c10::irange(N)) { \
186 /* NOLINTNEXTLINE(bugprone-signed-char-misuse) */ \
187 out_data[idx] = static_cast<scalar_t>(self_data[idx]); \
188 } \
189 }); \
190 } while (0)
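// The fast path above copies element-by-element via static_cast: the outer
// dispatch (or the template parameter on MSVC) fixes the source type self_t,
// and the nested dispatch on out.scalar_type() fixes the destination
// scalar_t, so every supported (source, destination) dtype pair is covered.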
191
192 #ifdef _MSC_VER
193 template <typename T>
194 void to_copy_out_fast_path(Tensor& out, const Tensor& self) {
195 TO_COPY_OUT_FAST_PATH_LOGIC(out, self, T);
196 }
197
198 #define TO_COPY_OUT_FAST_PATH_BODY(out, self) \
199 to_copy_out_fast_path<scalar_t>(out, self)
200 #else
201 #define TO_COPY_OUT_FAST_PATH_BODY(out, self) \
202 using self_t = scalar_t; \
203 TO_COPY_OUT_FAST_PATH_LOGIC(out, self, self_t)
204 #endif
205 } // namespace
206
207 at::Tensor& to_copy_out(
208 Tensor& out,
209 const Tensor& self,
210 bool non_blocking,
211 bool copy_strides,
212 std::optional<MemoryFormat> memory_format) {
213 if (copy_strides) {
214 at::native::resize_impl_cpu_(
215 out.unsafeGetTensorImpl(), self.sizes(), self.strides());
216 } else {
217 at::native::resize_(out, self.sizes(), std::nullopt);
218 }
219 auto is_unsupported_dtype = [](ScalarType t) {
220 #define TORCH_OPS_UNSUPPORTED_TYPE(_, type) \
221 case k##type: \
222 return true;
223 switch (t) {
224 AT_FORALL_QINT_TYPES(TORCH_OPS_UNSUPPORTED_TYPE)
225 AT_FORALL_COMPLEX_TYPES(TORCH_OPS_UNSUPPORTED_TYPE)
226 default:
227 return false;
228 }
229 #undef TORCH_OPS_UNSUPPORTED_TYPE
230 };
231 // Fast path: can we just copy the data ourselves? Avoids creating a
232 // TensorIterator in at::native::copy_, which is relatively
233 // expensive.
234 if (self.is_contiguous() && !non_blocking &&
235 // Did the user request us to make a copy that isn't contiguous?
236 (memory_format == std::nullopt ||
237 memory_format == c10::MemoryFormat::Preserve ||
238 memory_format == c10::MemoryFormat::Contiguous) &&
239 // CopyKernel.cpp handles this case specially, so let's not mess
240 // with it.
241 !self.is_neg() && !is_unsupported_dtype(self.dtype().toScalarType()) &&
242 !is_unsupported_dtype(out.dtype().toScalarType()) &&
243 !(
244 // FBGEMM optimization might kick in, don't interfere with
245 // that.
246 (self.dtype() == kFloat && out.dtype() == kHalf) ||
247 (self.dtype() == kHalf && out.dtype() == kFloat))) {
248 AT_DISPATCH_ALL_TYPES_AND3(
249 kHalf, kBFloat16, kBool, self.scalar_type(), "to_copy_out", [&]() {
250 TO_COPY_OUT_FAST_PATH_BODY(out, self);
251 });
252 return out;
253 }
254 at::native::copy_(out, self, non_blocking);
255 return out;
256 }
257
258 static Tensor& linear_out(
259 Tensor& output,
260 const Tensor& input,
261 const Tensor& weight,
262 const std::optional<Tensor>& bias_opt) {
263 TORCH_CHECK(!input.is_mkldnn());
264
265 auto bias = bias_opt.has_value()
266 ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
267 : c10::MaybeOwned<Tensor>::owned(std::in_place);
268
269 if (input.dim() == 2 && bias->defined()) {
270 // Fused op is marginally faster.
271 return at::cpu::addmm_out(output, *bias, input, weight.t());
272 }
273 at::native::matmul_out(input, weight.t(), output);
274 if (bias->defined()) {
275 at::cpu::add_(output, *bias);
276 }
277 return output;
278 }
279
280 static Tensor& c2_argmin_out(
281 Tensor& output,
282 const Tensor& input,
283 const int64_t dim,
284 const bool keepdim) {
285 const auto ndim = input.dim();
286 int64_t dim_ = maybe_wrap_dim(dim, ndim);
287 TORCH_CHECK(dim_ >= 0 && dim_ < ndim);
288
289 const auto in_dims = input.sizes();
290
291 c10::SmallVector<int64_t, 5> out_dims;
292 out_dims.reserve(ndim);
293 int prev_size = 1;
294 int next_size = 1;
295 for (int i = 0; i < dim_; ++i) {
296 out_dims.push_back(in_dims[i]);
297 prev_size *= in_dims[i];
298 }
299 if (keepdim) {
300 out_dims.push_back(1);
301 }
302 for (auto i = dim_ + 1; i < ndim; ++i) {
303 out_dims.push_back(in_dims[i]);
304 next_size *= in_dims[i];
305 }
306 at::native::resize_(output, out_dims, std::nullopt);
307
308 const auto n = in_dims[dim_];
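// Illustration: for in_dims == [4, 5, 6] and dim_ == 1, prev_size == 4,
// next_size == 6 and n == 5, i.e. the argmin runs over 5 candidates for
// each of the 4 * 6 remaining index pairs.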
309
310 if (next_size == 1) {
311 AT_DISPATCH_ALL_TYPES_AND2(
312 kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
313 const auto in_ptr = input.const_data_ptr<scalar_t>();
314 const auto out_ptr = output.mutable_data_ptr<int64_t>();
315 // input is a [prev_size, n] tensor.
316 // output is a [prev_size,] tensor.
317 // Thus, access is contiguous/coalesced.
318 for (int i = 0; i < prev_size; ++i) {
319 auto v = std::min_element(
320 in_ptr + i * n,
321 in_ptr + (i + 1) * n,
322 [](scalar_t a, scalar_t b) {
323 // if a is nan, then a is *less* than b with LessOrNan
324 // semantics
325 if (at::_isnan(a)) {
326 return true;
327 }
328 // if a is not nan and b is nan, then a is not less than b
329 // with LessOrNan semantics; otherwise, act normally. If `b` is
330 // NaN then a < b will always return false, so this is
331 // equivalent to the first snippet.
332 return a < b;
333 });
334 out_ptr[i] = std::distance(in_ptr + i * n, v);
335 }
336 });
337 } else {
338 AT_DISPATCH_ALL_TYPES_AND2(
339 kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() {
340 const auto less_or_nan = native::detail::LessOrNan<scalar_t>{};
341
342 const auto in_ptr = input.const_data_ptr<scalar_t>();
343 const auto out_ptr = output.mutable_data_ptr<int64_t>();
344
345 std::memset(out_ptr, 0, prev_size * next_size * sizeof(int64_t));
346
347 for (int i = 0; i < prev_size; ++i) {
348 const scalar_t* cur_in_ptr = in_ptr + i * n * next_size + next_size;
349 for (int k = 1; k < n; ++k) {
350 for (int j = 0; j < next_size; ++j) {
351 int64_t* cur_out_ptr = out_ptr + i * next_size + j;
352 if (less_or_nan(
353 *cur_in_ptr,
354 in_ptr
355 [i * n * next_size + *cur_out_ptr * next_size + j],
356 *cur_out_ptr,
357 k)) {
358 *cur_out_ptr = k;
359 }
360 ++cur_in_ptr;
361 }
362 }
363 }
364 });
365 }
366 return output;
367 }
368
369 static at::Tensor& dequantize_copy_out(Tensor& out, const Tensor& self) {
370 if (C10_UNLIKELY(!self.is_quantized())) {
371 // fallback to dequantize_cpu equivalent case: make sure out is at::kFloat
372 DCHECK(out.scalar_type() == kFloat);
373 return at::native::to_copy_out(out, self, false, false, std::nullopt);
374 }
375 return get_qtensorimpl(self)->quantizer()->dequantize_out(out, self);
376 }
377 } // namespace at::native
378
379 namespace torch::jit {
380
381 C10_DEFINE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
382
383 bool opIsRegistered(const c10::Symbol& op_name) {
384 const std::string name(op_name.toQualString());
385 return SROperatorRegistry()->Has(name);
386 }
387
388 static bool disableUnsafeMathOp(const char* op_name) {
389 if (FLAGS_static_runtime_enable_fast_math) {
390 return false;
391 }
392 // This list contains ops that use the caffe2 math library or NNC and do
393 // not guarantee bit exactness vs the jit interpreter. Note aten::relu is
394 // not included even though it uses NNC, because the results of relu should
395 // always match.
396 static const c10::FastSet<std::string> fast_ops{
397 "aten::add", "aten::tanh", "aten::sigmoid", "aten::logit"};
398 return fast_ops.count(op_name) > 0;
399 }
400
401 SROperator getOutOfPlaceOperation(Node* n) {
402 auto op_name = n->kind().toQualString();
403 if (SROperatorRegistry()->Has(op_name) && !disableUnsafeMathOp(op_name)) {
404 return SROperatorRegistry()->Create(op_name)->Generate(n);
405 }
406
407 return nullptr;
408 }
409
410 // Returns true if the node represents an op with variadic arguments.
411 bool hasVarArgs(Node* n) {
412 if (n->kind() == prim::VarConcat || n->kind() == prim::VarStack) {
413 return true;
414 }
415 return false;
416 }
417
418 bool canReuseInputsOutputs(
419 Node* n,
420 const c10::FastMap<Node*, bool>& node_has_out_variant) {
421 auto it = node_has_out_variant.find(n);
422 if (it != node_has_out_variant.end()) {
423 return it->second;
424 }
425 return getOutOfPlaceOperation(n) != nullptr;
426 }
427
428 // Returns true if the producers of the inputs
429 // to this operation are out of place.
430 // This means the IValues will not change from run to run.
431 static bool inputsCanRunOutOfPlace(
432 Node* n,
433 const c10::FastMap<Node*, bool>& node_has_out_variant) {
434 for (auto* input : n->inputs()) {
435 if (!canReuseInputsOutputs(input->node(), node_has_out_variant)) {
436 return false;
437 }
438 }
439 return true;
440 }
441
442 bool isOptimizableContainerType(
443 Node* n,
444 const c10::FastMap<Node*, bool>& node_has_out_variant) {
445 const auto& type = n->output()->type();
446 bool is_supported_type = false;
447 if (type->kind() == TypeKind::ListType) {
448 const auto& list_type = type->expectRef<ListType>();
449 is_supported_type =
450 list_type.getElementType()->kind() == TypeKind::TensorType;
451 } else if (type->kind() == TypeKind::TupleType) {
452 const auto& tuple_type = type->expectRef<TupleType>();
453 auto types = tuple_type.containedTypes();
454 const auto& iter =
455 std::find_if(types.begin(), types.end(), [](const TypePtr& elem) {
456 return elem->kind() == TypeKind::TensorType;
457 });
458 is_supported_type = iter != types.end();
459 }
460 return is_supported_type && inputsCanRunOutOfPlace(n, node_has_out_variant);
461 }
462
463 static inline void listConstructSlowPath(
464 const ListType& list_type,
465 const size_t size,
466 ProcessedNode* p_node) {
467 c10::List<IValue> vals(list_type.getElementType());
468 vals.reserve(size);
469 for (const auto i : c10::irange(size)) {
470 vals.push_back(p_node->Input(i));
471 }
472 p_node->Output(0) = vals;
473 }
474
475 bool sr_schema_check_kind(torch::jit::Node* node, c10::Symbol node_kind) {
476 auto is_match = node->kind() == node_kind;
477 if (!is_match) {
478 torch::jit::LogAndDumpSchema(node);
479 }
480 return is_match;
481 }
482
483 REGISTER_OPERATOR_FUNCTOR(
484 prim::ListConstruct,
485 prim_ListConstruct,
486 [](Node* n) -> SROperator {
487 if (!sr_schema_check_kind(n, prim::ListConstruct)) {
488 return nullptr;
489 }
490 const bool can_optimize =
491 isOptimizableContainerType(n, c10::FastMap<Node*, bool>());
492 const auto& type = n->output()->type()->expectRef<ListType>();
493 const size_t size = n->inputs().size();
494 if (!can_optimize) {
495 return [&type, size](ProcessedNode* p_node) {
496 DCHECK(p_node->num_inputs() == size);
497 listConstructSlowPath(type, size, p_node);
498 };
499 }
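// For optimizable container types, the output IValue survives across runs,
// so the fast path below only rebuilds the list while the output is still
// None.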
500 return [&type, size](ProcessedNode* p_node) {
501 DCHECK(p_node->num_inputs() == size);
502 const auto& out_l = p_node->Output(0);
503 if (!out_l.isNone()) {
504 return;
505 }
506 listConstructSlowPath(type, size, p_node);
507 };
508 });
509
510 static inline void tupleConstructSlowPath(
511 const size_t size,
512 ProcessedNode* p_node) {
513 // prepare inputs
514 switch (size) {
515 case 1:
516 p_node->Output(0) = c10::ivalue::Tuple::create(p_node->Input(0));
517 break;
518 case 2:
519 p_node->Output(0) =
520 c10::ivalue::Tuple::create(p_node->Input(0), p_node->Input(1));
521 break;
522 case 3:
523 p_node->Output(0) = c10::ivalue::Tuple::create(
524 p_node->Input(0), p_node->Input(1), p_node->Input(2));
525 break;
526 default: {
527 std::vector<IValue> vals;
528 vals.reserve(size);
529 for (const auto i : c10::irange(size)) {
530 vals.push_back(p_node->Input(i));
531 }
532 p_node->Output(0) = c10::ivalue::Tuple::create(std::move(vals));
533 break;
534 }
535 }
536 }
537
538 REGISTER_OPERATOR_FUNCTOR(
539 prim::TupleConstruct,
540 prim_TupleConstruct,
541 [](Node* n) -> SROperator {
542 if (!sr_schema_check_kind(n, prim::TupleConstruct)) {
543 return nullptr;
544 }
545 const bool can_optimize =
546 isOptimizableContainerType(n, c10::FastMap<Node*, bool>());
547 const size_t size = n->inputs().size();
548 if (!can_optimize) {
549 return [size](ProcessedNode* p_node) {
550 DCHECK(p_node->num_inputs() == size);
551 tupleConstructSlowPath(size, p_node);
552 };
553 }
554 return [size](ProcessedNode* p_node) {
555 DCHECK(p_node->num_inputs() == size);
556 const auto& out_l = p_node->Output(0);
557 if (!out_l.isNone()) {
558 return;
559 }
560 tupleConstructSlowPath(size, p_node);
561 };
562 });
563
564 REGISTER_OPERATOR_FUNCTOR(aten::abs, aten_abs, [](Node* n) -> SROperator {
565 if (!n->matches(torch::schema("aten::abs(Tensor self) -> Tensor"))) {
566 LogAndDumpSchema(n);
567 return nullptr;
568 }
569 return [](ProcessedNode* p_node) {
570 const auto& in0_t = p_node->Input(0).toTensor();
571 if (p_node->Output(0).isNone()) {
572 p_node->Output(0) = at::native::abs(in0_t);
573 return;
574 }
575 auto& out_t = p_node->Output(0).toTensor();
576 fastResizeToZero(out_t);
577 at::native::abs_out(in0_t, out_t);
578 };
579 });
580
581 REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator {
582 if (!n->matches(torch::schema(
583 "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor"))) {
584 LogAndDumpSchema(n);
585 return nullptr;
586 }
587
588 return [](ProcessedNode* p_node) {
589 const auto& in0_t = p_node->Input(0).toTensor();
590 const auto& in1_t = p_node->Input(1).toTensor();
591 if (p_node->Output(0).isNone()) {
592 p_node->Output(0) = at::cpu::mul(in0_t, in1_t);
593 return;
594 }
595 auto& out_t = p_node->Output(0).toTensor();
596 fastResizeToZero(out_t);
597 at::cpu::mul_out(out_t, in0_t, in1_t);
598 };
599 });
600
601 REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
602 if (!n->matches(torch::schema(
603 "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"))) {
604 LogAndDumpSchema(n);
605 return nullptr;
606 }
607 return [](ProcessedNode* p_node) {
608 const auto& in0_t = p_node->Input(0).toTensor();
609 const auto& in1_t = p_node->Input(1).toTensor();
610 const auto& in2_t = p_node->Input(2).toTensor();
611 const auto in3_s = p_node->Input(3).toScalar();
612 const auto in4_s = p_node->Input(4).toScalar();
613 if (p_node->Output(0).isNone()) {
614 p_node->Output(0) = at::cpu::addmm(in0_t, in1_t, in2_t, in3_s, in4_s);
615 return;
616 }
617 auto& out_t = p_node->Output(0).toTensor();
618 fastResizeToZero(out_t);
619 at::cpu::addmm_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s);
620 };
621 });
622
623 #ifdef FBCODE_CAFFE2
624 // Disable externally to avoid MSVC errors in open-source CI
625
626 REGISTER_OPERATOR_FUNCTOR(
627 static_runtime::clamp_nan_to_num,
628 static_runtime_clamp_nan_to_num,
629 [](Node* n) -> SROperator {
630 if (!sr_schema_check(
631 n,
632 "static_runtime::clamp_nan_to_num(Tensor input, Scalar? min, Scalar? max, float? nan, float? posinf, float? posinf) -> Tensor")) {
633 return nullptr;
634 }
635 auto clamp_min_ival_opt = toIValue(n->input(1));
636 auto clamp_max_ival_opt = toIValue(n->input(2));
637 TORCH_CHECK(
638 clamp_min_ival_opt.has_value() && clamp_max_ival_opt.has_value());
639
640 auto clamp_min_opt = clamp_min_ival_opt->toOptional<at::Scalar>();
641 auto clamp_max_opt = clamp_max_ival_opt->toOptional<at::Scalar>();
642 TORCH_CHECK(clamp_min_opt.has_value() && clamp_max_opt.has_value());
643
644 return [te = createClampNanToNum(),
645 clamp_min = clamp_min_opt->to<float>(),
646 clamp_max =
647 clamp_max_opt->to<float>()](ProcessedNode* p_node) mutable {
648 const auto& in0_t = p_node->Input(0).toTensor();
649 if (p_node->Output(0).isNone()) {
650 p_node->Output(0) = create_empty_from(in0_t);
651 }
652 auto& out_t = p_node->Output(0).toTensor();
653 fastResizeToZero(out_t);
654 auto in3_s = p_node->Input(3).toOptional<double>();
655
656 if (!te || !te->checkInput<float>(in0_t)) {
657 at::cpu::nan_to_num_out(
658 out_t,
659 at::cpu::clamp(in0_t, clamp_min, clamp_max),
660 in3_s,
661 std::nullopt,
662 std::nullopt);
663 return;
664 }
665 at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
666
667 auto output_size = in0_t.numel();
668
669 // This might be UB if in3_s is absurdly large, but most implementations
670 // just turn it into `inf` in that case. The PyTorch core nan_to_num
671 // kernel just static_cast()s the limits to the destination type, so
672 // we'll ignore overflow issues here as well.
673 auto nan = in3_s.has_value() ? static_cast<float>(*in3_s) : 0.f;
674
675 te->call(
676 {out_t.data_ptr(),
677 in0_t.data_ptr(),
678 &clamp_min,
679 &clamp_max,
680 &nan,
681 &output_size});
682 };
683 });
684
685 #endif
686
687 REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
688 if (n->matches(torch::schema(
689 "aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"))) {
690 return [te = createClamp()](ProcessedNode* p_node) {
691 const auto& in0_t = p_node->Input(0).toTensor();
692 if (p_node->Output(0).isNone()) {
693 p_node->Output(0) = create_empty_from(in0_t);
694 }
695 auto& out_t = p_node->Output(0).toTensor();
696 fastResizeToZero(out_t);
697 auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
698 auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
699 if (!te->checkInput<float>(in0_t)) {
700 at::cpu::clamp_out(out_t, in0_t, in1_s, in2_s);
701 return;
702 }
703 at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
704 auto output_size = in0_t.numel();
705 auto min = in1_s.has_value() ? in1_s->toFloat()
706 : -std::numeric_limits<float>::infinity();
707 auto max = in2_s.has_value() ? in2_s->toFloat()
708 : std::numeric_limits<float>::infinity();
709 te->call({out_t.data_ptr(), in0_t.data_ptr(), &min, &max, &output_size});
710 };
711 }
712 if (n->matches(
713 "aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor")) {
714 return [](ProcessedNode* p_node) {
715 const auto& in0_t = p_node->Input(0).toTensor();
716 if (p_node->Output(0).isNone()) {
717 p_node->Output(0) = create_empty_from(in0_t);
718 }
719 auto& out_t = p_node->Output(0).toTensor();
720 fastResizeToZero(out_t);
721 auto in1_t = p_node->Input(1).toOptional<at::Tensor>();
722 auto in2_t = p_node->Input(2).toOptional<at::Tensor>();
723 at::cpu::clamp_out(out_t, in0_t, in1_t, in2_t);
724 };
725 }
726 LogAndDumpSchema(n);
727 return nullptr;
728 });
729
730 REGISTER_OPERATOR_FUNCTOR(aten::bmm, aten_bmm, [](Node* n) -> SROperator {
731 if (!n->matches(
732 torch::schema("aten::bmm(Tensor self, Tensor mat2) -> Tensor"))) {
733 LogAndDumpSchema(n);
734 return nullptr;
735 }
736 return [](ProcessedNode* p_node) {
737 const auto& in0_t = p_node->Input(0).toTensor();
738 const auto& in1_t = p_node->Input(1).toTensor();
739 if (p_node->Output(0).isNone()) {
740 p_node->Output(0) = create_empty_from(in0_t);
741 }
742 auto& out_t = p_node->Output(0).toTensor();
743 fastResizeToZero(out_t);
744 at::cpu::bmm_out(out_t, in0_t, in1_t);
745 };
746 });
747
748 REGISTER_OPERATOR_FUNCTOR(aten::nan_to_num, aten_nan_to_num, [](Node* n) -> SROperator {
749 if (!n->matches(torch::schema(
750 "aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor"))) {
751 LogAndDumpSchema(n);
752 return nullptr;
753 }
754 return [](ProcessedNode* p_node) {
755 const auto& in0_t = p_node->Input(0).toTensor();
756 const auto in1_d = p_node->Input(1).toOptional<double>();
757 const auto in2_d = p_node->Input(2).toOptional<double>();
758 const auto in3_d = p_node->Input(3).toOptional<double>();
759 if (p_node->Output(0).isNone()) {
760 p_node->Output(0) = at::native::nan_to_num(in0_t, in1_d, in2_d, in3_d);
761 return;
762 }
763 auto& out_t = p_node->Output(0).toTensor();
764 fastResizeToZero(out_t);
765 at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t);
766 };
767 });
768
769 namespace {
770
771 void varStackSerialOut(
772 at::Tensor& result,
773 int64_t dim,
774 const ProcessedNodeInputWrapper& inputs) {
775 auto result_sizes = inputs[0].sizes().vec();
776 result_sizes.insert(result_sizes.begin() + dim, inputs.size());
777 at::native::resize_(result, result_sizes);
778
779 AT_DISPATCH_FLOATING_TYPES(
780 result.scalar_type(), "varstack_serial_kernel", [&]() {
781 at::native::detail::
782 stack_serial_kernel_impl<scalar_t, ProcessedNodeInputWrapper>(
783 result, inputs, dim);
784 });
785 }
786
787 std::vector<at::Tensor> unsqueezeVarStackInputs(
788 const ProcessedNodeInputWrapper& inputs,
789 const int64_t dim) {
790 std::vector<at::Tensor> result;
791 result.reserve(inputs.size());
792 for (const auto i : c10::irange(inputs.size())) {
793 result.push_back(at::native::unsqueeze(inputs[i], dim));
794 }
795 return result;
796 }
797
798 void varstackNonserialOut(
799 at::Tensor& result,
800 const int64_t dim,
801 const ProcessedNodeInputWrapper& inputs) {
802 std::vector<at::Tensor> inputs_unsqueezed =
803 unsqueezeVarStackInputs(inputs, dim);
804 fastResizeToZero(result);
805 at::cpu::cat_outf(inputs_unsqueezed, dim, result);
806 }
807
808 void varStackFastOut(
809 at::Tensor& out,
810 int64_t dim,
811 const ProcessedNodeInputWrapper& inputs) {
812 DCHECK(out.is_contiguous());
813 const auto num_inputs = static_cast<int64_t>(inputs.size());
814 TORCH_CHECK(num_inputs > 0, "stack expects a non-empty list of tensors");
815
816 const auto first_tensor_shape = inputs[0].sizes();
817 for (const auto i : c10::irange(1, num_inputs)) {
818 const auto shape = inputs[i].sizes();
819 TORCH_CHECK(
820 shape == first_tensor_shape,
821 "Stack expects each tensor to be the same size, but got ",
822 first_tensor_shape,
823 " at position 0 and ",
824 shape,
825 " at position ",
826 i);
827 }
828
829 const std::array<int64_t, 2> output_size = (dim == 0 || dim == -2)
830 ? std::array<int64_t, 2>{num_inputs, 1}
831 : std::array<int64_t, 2>{1, num_inputs};
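// With scalar-like inputs (each of shape [1]), stacking N tensors produces
// shape [N, 1] for dim 0 (or -2) and [1, N] otherwise.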
832
833 at::native::resize_(out, output_size, std::nullopt);
834
835 AT_DISPATCH_ALL_TYPES(out.scalar_type(), "varStackFastOut", [&]() {
836 auto* out_data = out.mutable_data_ptr<scalar_t>();
837 for (const auto i : c10::irange(num_inputs)) {
838 auto& tensor = inputs[i];
839 auto* input_ptr = tensor.const_data_ptr<scalar_t>();
840 out_data[i] = *input_ptr;
841 }
842 });
843 }
844
845 bool inputsAreScalars(const ProcessedNodeInputWrapper& inputs) {
846 // All stack inputs should have the same size, so we only check
847 // the first one. If this isn't true, an exception will be thrown
848 // in the VarStack implementation
849 const auto& first_tensor = inputs[0];
850 return first_tensor.sizes()[0] == 1 && first_tensor.dim() == 1;
851 }
852
853 void varStackOut(ProcessedNode& pnode, int64_t dim) {
854 const auto num_inputs = pnode.num_inputs();
855 TORCH_CHECK(num_inputs > 1, "stack expects a non-empty list of tensors");
856 dim = c10::maybe_wrap_dim(dim, pnode.Input(0).toTensor().dim() + 1);
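// Stacking adds a new dimension, so negative dims are wrapped against
// (input rank + 1); e.g. dim == -1 with 1-D inputs wraps to 1.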
857
858 auto inputs = ProcessedNodeInputWrapper(pnode);
859 auto& output = pnode.Output(0).toTensor();
860
861 if (output.is_contiguous() && inputsAreScalars(inputs)) {
862 varStackFastOut(output, dim, inputs);
863 return;
864 }
865
866 bool can_use_serial = at::native::detail::CanUseNativeSerialStack<
867 ProcessedNodeInputWrapper,
868 /*skip_overlap_check*/ true>::call(output, inputs, dim);
869
870 if (can_use_serial) {
871 varStackSerialOut(output, dim, inputs);
872 return;
873 }
874 varstackNonserialOut(output, dim, inputs);
875 }
876
877 } // namespace
878
879 // Split out into a function to appease MSVC's pre-processor
880 static SROperator aten_stack(Node* n) {
881 if (!n->matches(torch::schema(
882 "aten::stack(Tensor[] tensors, int dim=0) -> Tensor"))) {
883 LogAndDumpSchema(n);
884 return nullptr;
885 }
886 return [](ProcessedNode* p_node) {
887 const auto inputs = p_node->Input(0).toTensorVector();
888 TORCH_CHECK(!inputs.empty(), "stack expects non-empty tensor list");
889 const auto dim = p_node->Input(1).toInt();
890 if (p_node->Output(0).isNone()) {
891 p_node->Output(0) = at::native::_stack_cpu(inputs, dim);
892 return;
893 }
894 auto& out_t = p_node->Output(0).toTensor();
895 fastResizeToZero(out_t);
896 at::native::_stack_out_cpu(inputs, dim, out_t);
897 };
898 }
899
900 REGISTER_OPERATOR_FUNCTOR(aten::stack, aten_stack, aten_stack);
901
902 REGISTER_OPERATOR_FUNCTOR(
903 prim::VarStack,
904 prim_VarStack,
905 [](Node* n) -> SROperator {
906 if (!sr_schema_check_kind(n, prim::VarStack)) {
907 return nullptr;
908 }
909 return [](ProcessedNode* p_node) {
910 const size_t num_inputs = p_node->num_inputs();
911 const auto dim = p_node->Input(num_inputs - 1).toInt();
912
913 if (p_node->Output(0).isNone()) {
914 p_node->Output(0) = create_empty_from(p_node->Input(0).toTensor());
915 }
916 varStackOut(*p_node, dim);
917 };
918 });
919
920 REGISTER_OPERATOR_FUNCTOR(aten::leaky_relu, aten_leaky_relu, [](Node* n) -> SROperator {
921 if (!n->matches(torch::schema(
922 "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"))) {
923 LogAndDumpSchema(n);
924 return nullptr;
925 }
926 return [](ProcessedNode* p_node) {
927 const auto& in0_t = p_node->Input(0).toTensor();
928 const auto in1_s = p_node->Input(1).toScalar();
929 if (p_node->Output(0).isNone()) {
930 p_node->Output(0) = at::cpu::leaky_relu(in0_t, in1_s);
931 return;
932 }
933 auto& out_t = p_node->Output(0).toTensor();
934 at::cpu::leaky_relu_out(out_t, in0_t, in1_s);
935 };
936 });
937
938 REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
939 if (!n->matches(torch::schema("aten::relu(Tensor self) -> Tensor"))) {
940 LogAndDumpSchema(n);
941 return nullptr;
942 }
943 auto te = createRelu();
944 return [te](ProcessedNode* p_node) {
945 const auto& in0_t = p_node->Input(0).toTensor();
946 if (p_node->Output(0).isNone()) {
947 p_node->Output(0) = create_empty_from(in0_t);
948 }
949 auto& out_t = p_node->Output(0).toTensor();
950 if (!te->checkInput<float>(in0_t)) {
951 fastResizeToZero(out_t);
952 at::cpu::threshold_out(out_t, in0_t, 0, 0);
953 return;
954 }
955 at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
956 int64_t nn = in0_t.numel();
957 te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
958 };
959 });
960
961 REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator {
962 if (!n->matches(torch::schema("aten::tanh(Tensor self) -> Tensor"))) {
963 LogAndDumpSchema(n);
964 return nullptr;
965 }
966 auto te = createTanh();
967 return [te](ProcessedNode* p_node) {
968 const auto& in0_t = p_node->Input(0).toTensor();
969 if (p_node->Output(0).isNone()) {
970 p_node->Output(0) = create_empty_from(in0_t);
971 }
972 auto& out_t = p_node->Output(0).toTensor();
973 if (!te->checkInput<float>(in0_t)) {
974 fastResizeToZero(out_t);
975 at::cpu::tanh_out(out_t, in0_t);
976 return;
977 }
978 at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
979 int64_t nn = in0_t.numel();
980 te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
981 };
982 });
983
984 REGISTER_OPERATOR_FUNCTOR(
985 prim::TensorExprDynamicGroup,
986 prim_TensorExprDynamicGroup,
987 [](Node* n) -> SROperator {
988 if (!sr_schema_check_kind(n, prim::TensorExprDynamicGroup)) {
989 return nullptr;
990 }
991 auto graph = n->g(attr::Subgraph);
992 Code code(graph, "");
993 return [code](ProcessedNode* p_node) {
994 auto num_outputs = p_node->num_outputs();
995 Stack stack;
996 if (p_node->Output(0).isNone()) {
997 stack.reserve(p_node->num_inputs());
998 } else {
999 stack.reserve(p_node->num_inputs() + num_outputs);
1000 for (const auto& o : p_node->outputs()) {
1001 stack.emplace_back(o);
1002 }
1003 }
1004 for (auto i : c10::irange(p_node->num_inputs())) {
1005 stack.emplace_back(p_node->Input(i));
1006 }
1007 runTensorExprDynamicGroup(code, stack);
1008 if (p_node->Output(0).isNone()) {
1009 TORCH_INTERNAL_ASSERT(
1010 stack.size() == num_outputs,
1011 "Unexpected # of outputs on stack after executing TensorExprDynamicGroup");
1012 for (auto i : c10::irange(num_outputs)) {
1013 p_node->Output(i) = std::move(stack[i]);
1014 }
1015 }
1016 };
1017 });
1018
1019 REGISTER_OPERATOR_FUNCTOR(
1020 aten::sigmoid,
1021 aten_sigmoid,
1022 [](Node* n) -> SROperator {
1023 if (!n->matches(torch::schema("aten::sigmoid(Tensor self) -> Tensor"))) {
1024 LogAndDumpSchema(n);
1025 return nullptr;
1026 }
1027 auto te = createSigmoid();
1028 return [te](ProcessedNode* p_node) {
1029 const auto& in0_t = p_node->Input(0).toTensor();
1030 if (p_node->Output(0).isNone()) {
1031 p_node->Output(0) = create_empty_from(in0_t);
1032 }
1033 auto& out_t = p_node->Output(0).toTensor();
1034 if (!te->checkInput<float>(in0_t)) {
1035 fastResizeToZero(out_t);
1036 at::cpu::sigmoid_out(out_t, in0_t);
1037 return;
1038 }
1039 at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
1040 int64_t nn = in0_t.numel();
1041 te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
1042 };
1043 });
1044
1045 REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator {
1046 if (!n->matches(torch::schema(
1047 "aten::logit(Tensor self, float? eps=None) -> Tensor"))) {
1048 LogAndDumpSchema(n);
1049 return nullptr;
1050 }
1051 std::optional<float> clamp = std::nullopt;
1052 if (n->inputs()[1]->node()->kind() == prim::Constant) {
1053 auto clamp_d = toIValue(n->inputs()[1])->toOptional<double>();
1054 clamp = clamp_d
1055 ? std::make_optional<float>(static_cast<float>(clamp_d.value()))
1056 : std::nullopt;
1057 }
1058 auto te = clamp ? createLogit() : nullptr;
1059 float clamp_value = clamp ? *clamp : 0.0f;
1060 return [te, clamp_value](ProcessedNode* p_node) {
1061 const auto& in0_t = p_node->Input(0).toTensor();
1062 if (p_node->Output(0).isNone()) {
1063 p_node->Output(0) = create_empty_from(in0_t);
1064 }
1065 auto& out_t = p_node->Output(0).toTensor();
1066 if (!te || !te->checkInput<float>(in0_t)) {
1067 const auto& in0_t = p_node->Input(0).toTensor();
1068 const auto in1_d = p_node->Input(1).toOptional<double>();
1069 fastResizeToZero(out_t);
1070 at::native::logit_out(in0_t, in1_d, out_t);
1071 return;
1072 }
1073 at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
1074 int64_t nn = in0_t.numel();
1075 float c = clamp_value;
1076 te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn, &c});
1077 };
1078 });
1079
1080 REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
1081 if (!n->matches(torch::schema(
1082 "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) ->Tensor"))) {
1083 LogAndDumpSchema(n);
1084 return nullptr;
1085 }
1086 return [](ProcessedNode* p_node) {
1087 const auto& src = p_node->Input(0).toTensor();
1088 const auto& optional_memory_format =
1089 p_node->Input(1).toOptional<c10::MemoryFormat>();
1090 auto memory_format =
1091 optional_memory_format.value_or(c10::MemoryFormat::Preserve);
1092 /*
1093 Disable the out variant of clone for the case with stride = 0 and for
1094 memory formats other than Preserve. Perform dynamic allocation
1095 instead of memory reuse for a simpler implementation. We could,
1096 in principle, figure out how to copy the strides.
1097 */
1098 if ((at::has_internal_overlap(src.unsafeGetTensorImpl()) ==
1099 at::MemOverlap::Yes) ||
1100 (memory_format != c10::MemoryFormat::Preserve)) {
1101 p_node->Output(0) = at::native::clone(src, memory_format);
1102 return;
1103 }
1104 if (p_node->Output(0).isNone()) {
1105 if (src.is_non_overlapping_and_dense()) {
1106 // Copy all strides
1107 p_node->Output(0) =
1108 at::empty_strided(src.sizes(), src.strides(), src.options());
1109 } else {
1110 memory_format = src.suggest_memory_format();
1111 p_node->Output(0) = create_empty_from(src, memory_format);
1112 }
1113 }
1114 auto& out_t = p_node->Output(0).toTensor();
1115 at::native::resize_impl_cpu_(
1116 out_t.unsafeGetTensorImpl(), src.sizes(), src.strides());
1117 at::native::copy_(out_t, src, false);
1118 };
1119 });
1120
1121 REGISTER_OPERATOR_FUNCTOR(
1122 quantized::embedding_bag_byte_rowwise_offsets,
1123 quantized_embedding_bag_byte_rowwise_offsets,
1124 [](Node* n) -> SROperator {
1125 if (!n->matches(torch::schema(
1126 "quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"))) {
1127 LogAndDumpSchema(n);
1128 return nullptr;
1129 }
1130 return [](ProcessedNode* p_node) {
1131 const auto& weight = p_node->Input(0).toTensor();
1132 const auto& indices = p_node->Input(1).toTensor();
1133 const auto offsets = p_node->Input(2).toOptional<at::Tensor>();
1134 const auto pruned_weights = p_node->Input(5).toBool();
1135 const auto per_sample_weights =
1136 p_node->Input(6).toOptional<at::Tensor>();
1137 const auto compressed_indices_mapping =
1138 p_node->Input(7).toOptional<at::Tensor>();
1139 const auto include_last_offset = p_node->Input(8).toBool();
1140 if (p_node->Output(0).isNone()) {
1141 p_node->Output(0) = create_empty_from(weight, at::kFloat);
1142 }
1143 auto& out_t = p_node->Output(0).toTensor();
1144 fastResizeToZero(out_t);
1145 return at::native::embedding_bag_byte_rowwise_offsets_out(
1146 out_t,
1147 weight,
1148 indices,
1149 offsets,
1150 false, // unused scale_grad_by_freq
1151 0, // unused mode
1152 pruned_weights,
1153 per_sample_weights,
1154 compressed_indices_mapping,
1155 include_last_offset);
1156 };
1157 });
1158
1159 REGISTER_OPERATOR_FUNCTOR(
1160 quantized::embedding_bag_4bit_rowwise_offsets,
1161 embedding_bag_4bit_rowwise_offsets,
1162 [](Node* n) -> SROperator {
1163 if (!n->matches(torch::schema(
1164 "quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"))) {
1165 LogAndDumpSchema(n);
1166 return nullptr;
1167 }
1168 return [](ProcessedNode* p_node) {
1169 const auto& weight = p_node->Input(0).toTensor();
1170 const auto& indices = p_node->Input(1).toTensor();
1171 const auto offsets = p_node->Input(2).toOptional<at::Tensor>();
1172 const auto pruned_weights = p_node->Input(5).toBool();
1173 const auto per_sample_weights =
1174 p_node->Input(6).toOptional<at::Tensor>();
1175 const auto compressed_indices_mapping =
1176 p_node->Input(7).toOptional<at::Tensor>();
1177 const auto include_last_offset = p_node->Input(8).toBool();
1178 if (p_node->Output(0).isNone()) {
1179 p_node->Output(0) = create_empty_from(weight, at::kFloat);
1180 }
1181 auto& out_t = p_node->Output(0).toTensor();
1182 fastResizeToZero(out_t);
1183 return at::native::embedding_bag_4bit_rowwise_offsets_out(
1184 out_t,
1185 weight,
1186 indices,
1187 offsets,
1188 false, // unused scale_grad_by_freq
1189 0, // unused mode
1190 pruned_weights,
1191 per_sample_weights,
1192 compressed_indices_mapping,
1193 include_last_offset);
1194 };
1195 });
1196
1197 REGISTER_OPERATOR_FUNCTOR(
1198 quantized::embedding_bag_byte_prepack,
1199 embedding_bag_byte_prepack,
1200 [](Node* n) -> SROperator {
1201 if (!n->matches(torch::schema(
1202 "quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"))) {
1203 LogAndDumpSchema(n);
1204 return nullptr;
1205 }
1206 return [](ProcessedNode* p_node) {
1207 const auto& weight = p_node->Input(0).toTensor();
1208 if (p_node->Output(0).isNone()) {
1209 p_node->Output(0) = at::native::qembeddingbag_byte_prepack(weight);
1210 return;
1211 }
1212 auto& out_t = p_node->Output(0).toTensor();
1213 fastResizeToZero(out_t);
1214 at::native::qembeddingbag_byte_prepack_out(out_t, weight);
1215 };
1216 });
1217
1218 // The out variant takes precedence over native
1219 REGISTER_OPERATOR_FUNCTOR(aten::narrow_copy, aten_narrow_copy, [](Node* n) -> SROperator {
1220 if (!n->matches(torch::schema(
1221 "aten::narrow_copy(Tensor self, int dim, int start, int length) -> Tensor"))) {
1222 LogAndDumpSchema(n);
1223 return nullptr;
1224 }
1225 return [](ProcessedNode* p_node) {
1226 const auto& self = p_node->Input(0).toTensor(); // self
1227 const auto dim = p_node->Input(1).toInt(); // dim
1228 int64_t start = 0;
1229 if (p_node->Input(2).isScalar()) {
1230 start = p_node->Input(2).toInt();
1231 } else {
1232 auto& t = p_node->Input(2).toTensor();
1233 start = t.item<int64_t>();
1234 }
1235 auto length = p_node->Input(3).toInt(); // length
1236
1237 if (p_node->Output(0).isNone()) {
1238 p_node->Output(0) =
1239 at::native::narrow_copy_dense_cpu(self, dim, start, length);
1240 return;
1241 }
1242 auto& output = p_node->Output(0).toTensor();
1243 fastResizeToZero(output);
1244 at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output);
1245 };
1246 });
1247 REGISTER_OPERATOR_FUNCTOR(aten::index, aten_index, [](Node* n) -> SROperator {
1248 if (!n->matches(torch::schema(
1249 "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"))) {
1250 LogAndDumpSchema(n);
1251 return nullptr;
1252 }
1253 return [](ProcessedNode* p_node) {
1254 const auto& in0_t = p_node->Input(0).toTensor();
1255 const auto in1_l =
1256 at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
1257 if (p_node->Output(0).isNone()) {
1258 p_node->Output(0) = at::cpu::index(in0_t, in1_l);
1259 return;
1260 }
1261 auto& out_t = p_node->Output(0).toTensor();
1262 fastResizeToZero(out_t);
1263 at::cpu::index_out(out_t, in0_t, in1_l);
1264 };
1265 });
1266
1267 REGISTER_OPERATOR_FUNCTOR(
1268 aten::index_select,
1269 aten_index_select,
1270 [](Node* n) -> SROperator {
1271 if (!n->matches(torch::schema(
1272 "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor"))) {
1273 LogAndDumpSchema(n);
1274 return nullptr;
1275 }
1276 return [](ProcessedNode* p_node) {
1277 const auto& self = p_node->Input(0).toTensor();
1278 const auto dim = p_node->Input(1).toInt();
1279 const auto& index = p_node->Input(2).toTensor();
1280 if (p_node->Output(0).isNone()) {
1281 p_node->Output(0) = at::native::index_select_cpu_(self, dim, index);
1282 return;
1283 }
1284 auto& out = p_node->Output(0).toTensor();
1285 fastResizeToZero(out);
1286 at::native::index_select_out_cpu_(self, dim, index, out);
1287 };
1288 });
1289
1290 REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator {
1291 if (n->matches(torch::schema(
1292 "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor"))) {
1293 return [](ProcessedNode* p_node) {
1294 if (p_node->Output(0).isNone()) {
1295 const auto& in0_t = p_node->Input(0).toTensor();
1296 auto dtype =
1297 at::native::result_type(in0_t, p_node->Input(1).toTensor());
1298 p_node->Output(0) = create_empty_from(in0_t, dtype);
1299 }
1300 auto& out_t = p_node->Output(0).toTensor();
1301 fastResizeToZero(out_t);
1302 at::cpu::pow_out(
1303 out_t, p_node->Input(0).toTensor(), p_node->Input(1).toTensor());
1304 };
1305 }
1306 if (n->matches(torch::schema(
1307 "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor"))) {
1308 return [](ProcessedNode* p_node) {
1309 if (p_node->Output(0).isNone()) {
1310 const auto& in1_t = p_node->Input(1).toTensor();
1311 auto dtype =
1312 at::native::result_type(p_node->Input(0).toScalar(), in1_t);
1313 p_node->Output(0) = at::native::empty_like(
1314 in1_t,
1315 dtype,
1316 in1_t.options().layout_opt(),
1317 in1_t.options().device_opt(),
1318 in1_t.options().pinned_memory_opt(),
1319 at::MemoryFormat::Preserve);
1320 }
1321 auto& out_t = p_node->Output(0).toTensor();
1322 fastResizeToZero(out_t);
1323 at::cpu::pow_out(
1324 out_t, p_node->Input(0).toScalar(), p_node->Input(1).toTensor());
1325 };
1326 }
1327 if (n->matches(torch::schema(
1328 "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor"))) {
1329 return [](ProcessedNode* p_node) {
1330 if (p_node->Output(0).isNone()) {
1331 const auto& in0_t = p_node->Input(0).toTensor();
1332 auto dtype =
1333 at::native::result_type(in0_t, p_node->Input(1).toScalar());
1334 p_node->Output(0) = at::native::empty_like(
1335 in0_t,
1336 dtype,
1337 in0_t.options().layout_opt(),
1338 in0_t.options().device_opt(),
1339 in0_t.options().pinned_memory_opt(),
1340 at::MemoryFormat::Preserve);
1341 }
1342 auto& out_t = p_node->Output(0).toTensor();
1343 fastResizeToZero(out_t);
1344 at::cpu::pow_out(
1345 out_t, p_node->Input(0).toTensor(), p_node->Input(1).toScalar());
1346 };
1347 }
1348 LogAndDumpSchema(n);
1349 return nullptr;
1350 });
1351
1352 namespace {
1353
1354 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1355 struct ToArgs {
1356 std::optional<at::ScalarType> dtype;
1357 c10::Layout layout;
1358 bool know_to_will_alias = false;
1359 std::optional<c10::MemoryFormat> memory_format;
1360 };
1361
1362 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1363 ToArgs extract_to_args(ProcessedNode* p_node) {
1364 ToArgs result;
1365 if (!has_constant_non_tensor_dtype_and_flags && p_node->Input(1).isTensor()) {
1366 const auto& other = p_node->Input(1).toTensor();
1367 result.dtype = other.scalar_type();
1368 result.layout = other.layout();
1369 TORCH_DCHECK_EQ(other.device().type(), c10::DeviceType::CPU);
1370 } else {
1371 const auto& self = p_node->Input(0).toTensor();
1372 result.dtype = p_node->Input(1).toOptional<at::ScalarType>();
1373 result.layout = self.layout();
1374 // Static runtime only works with CPU tensors; don't need to read this.
1375 TORCH_DCHECK_EQ(self.device().type(), c10::DeviceType::CPU);
1376 result.know_to_will_alias = has_constant_non_tensor_dtype_and_flags &&
1377 (!result.dtype.has_value() ||
1378 result.dtype.value() == self.dtype().toScalarType());
1379 }
1380 if (has_memory_format) {
1381 TORCH_DCHECK_EQ(p_node->num_inputs(), 5);
1382 result.memory_format = p_node->Input(4).toOptional<c10::MemoryFormat>();
1383 result.know_to_will_alias = result.know_to_will_alias &&
1384 (result.memory_format.value_or(c10::MemoryFormat::Preserve) ==
1385 c10::MemoryFormat::Preserve);
1386 }
1387
1388 return result;
1389 }
1390
1391 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1392 struct CheckToWillAlias {
1393 static bool call(
1394 ProcessedNode* p_node,
1395 const at::Tensor& self,
1396 const ToArgs& to_args) {
1397 // The to_maybe_copy_out operator functor should have detected a
1398 // constant true `copy` argument and used to_copy instead.
1399 bool copy = false;
1400 if (has_constant_non_tensor_dtype_and_flags) {
1401 DCHECK(!p_node->Input(3).toBool());
1402 copy = false;
1403 } else {
1404 copy = p_node->Input(3).toBool();
1405 }
1406 return !copy &&
1407 (to_args.know_to_will_alias ||
1408 at::native::to_will_alias(
1409 self,
1410 to_args.dtype,
1411 to_args.layout,
1412 c10::Device{c10::DeviceType::CPU},
1413 copy,
1414 has_memory_format ? to_args.memory_format
1415 : c10::MemoryFormat::Preserve));
1416 }
1417 };
1418
1419 template <>
1420 struct CheckToWillAlias<true, false> {
1421 // Special case! First, there is no memory format to check. Second,
1422 // we know that the layout and device will match self, so we only
1423 // need to check the dtype.
1424 static bool call(ProcessedNode* p_node, const at::Tensor& self) {
1425 DCHECK(!p_node->Input(3).toBool()); // !copy
1426 const auto dtype_opt = p_node->Input(1).toOptional<at::ScalarType>();
1427 return !dtype_opt.has_value() || *dtype_opt == self.dtype().toScalarType();
1428 }
1429 };
1430
1431 // Force inlining so we don't have to branch on whether args is null
1432 // at runtime.
1433 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1434 C10_ALWAYS_INLINE void to_copy_functor_impl(
1435 ProcessedNode* p_node,
1436 const ToArgs* args) {
1437 const auto& self = p_node->Input(0).toTensor();
1438 // ignore input 3 (copy)
1439 auto non_blocking = p_node->Input(2).toBool(); // non_blocking
1440 // handle memory format
1441 bool copy_strides = false;
1442
1443 std::optional<c10::MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
1444 std::optional<ToArgs> my_args;
1445 if (!args) {
1446 my_args = extract_to_args<
1447 has_constant_non_tensor_dtype_and_flags,
1448 has_memory_format>(p_node);
1449 args = &my_args.value();
1450 }
1451 if (has_memory_format) {
1452 memory_format = args->memory_format.value_or(c10::MemoryFormat::Preserve);
1453 }
1454
1455 if (memory_format == c10::MemoryFormat::Preserve) {
1456 if (self.is_non_overlapping_and_dense()) {
1457 memory_format = std::nullopt;
1458 copy_strides = true;
1459 } else {
1460 memory_format = self.suggest_memory_format();
1461 }
1462 }
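// At this point a Preserve request has been resolved: either self's strides
// will be copied verbatim (copy_strides == true) or a concrete memory
// format suggested by self will be used for the output below.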
1463
1464 bool need_to_allocate_output = true;
1465 if (p_node->Output(0).isTensor()) {
1466 const auto& existing_output = p_node->Output(0).toTensor();
1467 if ((!has_constant_non_tensor_dtype_and_flags &&
1468 (existing_output.dtype() != args->dtype ||
1469 existing_output.layout() != args->layout ||
1470 existing_output.device() != self.device())) ||
1471 (has_memory_format &&
1472 !existing_output.is_contiguous(
1473 memory_format.value_or(c10::MemoryFormat::Contiguous)))) {
1474 need_to_allocate_output = true;
1475 } else {
1476 need_to_allocate_output = false;
1477 }
1478 }
1479
1480 // See Note [Explicit nullopt MemoryFormat argument]
1481 // Can't use size {0} if memory_format is ChannelsLast
1482 if (need_to_allocate_output) {
1483 p_node->Output(0) = at::detail::empty_cpu(
1484 self.sizes(),
1485 args->dtype,
1486 args->layout,
1487 self.device(),
1488 std::nullopt,
1489 memory_format);
1490 } else {
1491 if (has_memory_format) {
1492 memory_format = p_node->Input(4).toOptional<c10::MemoryFormat>().value_or(
1493 c10::MemoryFormat::Preserve);
1494 } else {
1495 memory_format = c10::MemoryFormat::Preserve;
1496 }
1497 }
1498
1499 copy_strides = copy_strides ||
1500 (memory_format == c10::MemoryFormat::Preserve &&
1501 self.is_non_overlapping_and_dense());
1502
1503 auto& out_t = p_node->Output(0).toTensor();
1504 fastResizeToZero(out_t);
1505 at::native::to_copy_out(
1506 out_t, self, non_blocking, copy_strides, memory_format);
1507 }
1508
1509 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1510 void to_copy_functor(ProcessedNode* p_node) {
1511 to_copy_functor_impl<
1512 has_constant_non_tensor_dtype_and_flags,
1513 has_memory_format>(p_node, nullptr);
1514 }
1515
1516 template <bool has_constant_non_tensor_dtype_and_flags, bool has_memory_format>
1517 void to_maybe_copy_out_functor(ProcessedNode* p_node) {
1518 // It would be great if we could avoid checking our arguments every
1519 // time. However, we need to account for the possibility that
1520 // the dtype (and layout, memory format, etc.) of self changed
1521 // between iterations.
1522 ToArgs args = extract_to_args<
1523 has_constant_non_tensor_dtype_and_flags,
1524 has_memory_format>(p_node);
1525 const auto& self = p_node->Input(0).toTensor();
1526 if (CheckToWillAlias<
1527 has_constant_non_tensor_dtype_and_flags,
1528 has_memory_format>::call(p_node, self, args)) {
1529 // Don't write our Tensor output. This leaves it None if it
1530 // was never allocated (and there is a special case in the
1531 // memory planner to not start managing in this case), but
1532 // if we are oscillating between aliasing and needing to
1533 // copy, we should just leave our output in place so as not
1534 // to confuse the memory planner.
1535 p_node->Output(1) = false;
1536 } else {
1537 p_node->Output(1) = true;
1538 to_copy_functor_impl<
1539 has_constant_non_tensor_dtype_and_flags,
1540 has_memory_format>(p_node, &args);
1541 }
1542 }
1543
1544 // Take advantage of CheckToWillAlias not requiring the args in this
1545 // case.
1546 template <>
1547 void to_maybe_copy_out_functor<true, false>(ProcessedNode* p_node) {
1548 const auto& self = p_node->Input(0).toTensor();
1549 if (CheckToWillAlias<true, false>::call(p_node, self)) {
1550 p_node->Output(1) = false;
1551 } else {
1552 p_node->Output(1) = true;
1553 auto args = extract_to_args<true, false>(p_node);
1554 to_copy_functor_impl<true, false>(p_node, &args);
1555 }
1556 }
1557
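// True when inputs 1-3 (dtype, non_blocking, copy) are all non-Tensor
// constants, which lets the to_copy/to_maybe_copy functors below be
// specialized at registration time.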
1558 bool node_has_constant_non_tensor_dtype_and_flags(Node* n) {
1559 const auto* input1 = n->inputs()[1];
1560 return input1->type()->kind() != TypeKind::TensorType &&
1561 input1->node()->kind() == prim::Constant &&
1562 n->inputs()[2]->node()->kind() == prim::Constant &&
1563 n->inputs()[3]->node()->kind() == prim::Constant;
1564 }
1565
1566 auto get_to_copy_functor(
1567 bool has_constant_non_tensor_dtype_and_flags,
1568 bool has_memory_format) {
1569 if (has_constant_non_tensor_dtype_and_flags) {
1570 if (has_memory_format) {
1571 return to_copy_functor<true, true>;
1572 } else {
1573 return to_copy_functor<true, false>;
1574 }
1575 } else {
1576 if (has_memory_format) {
1577 return to_copy_functor<false, true>;
1578 } else {
1579 return to_copy_functor<false, false>;
1580 }
1581 }
1582 }
1583
1584 } // namespace
1585
1586 REGISTER_OPERATOR_FUNCTOR(
1587 static_runtime::to_maybe_copy_out,
1588 aten_to_maybe_copy,
1589 [](Node* n) -> SROperator {
1590 // support 4- or 5-arg for adindexer/adfinder models
1591 // Keep TORCH_CHECK here because there is no alternative for fallback
1592 if (!sr_schema_check(
1593 n,
1594 "static_runtime::to_maybe_copy_out.prim_dtype(Tensor self, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor, bool)",
1595 "static_runtime::to_maybe_copy_out.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> (Tensor, bool)",
1596 "static_runtime::to_maybe_copy_out.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> (Tensor, bool)")) {
1597 return nullptr;
1598 }
1599 TORCH_CHECK(n->inputs().size() == 4 || n->inputs().size() == 5);
1600 const bool has_constant_non_tensor_dtype_and_flags =
1601 node_has_constant_non_tensor_dtype_and_flags(n);
1602 const bool has_memory_format = n->inputs().size() == 5;
1603
1604 // If copy=True we will always copy, so just use the to_copy path;
1605 // the to_maybe_copy path can then assume this never happens.
1606 if (has_constant_non_tensor_dtype_and_flags) {
1607 const auto copyArg =
1608 torch::jit::constant_as<bool>(n->inputs()[3]->node()->output());
1609 DCHECK(copyArg.has_value());
1610 if (*copyArg) {
1611 return get_to_copy_functor(
1612 has_constant_non_tensor_dtype_and_flags, has_memory_format);
1613 }
1614 }
1615 if (has_constant_non_tensor_dtype_and_flags) {
1616 if (has_memory_format) {
1617 return to_maybe_copy_out_functor<true, true>;
1618 } else {
1619 return to_maybe_copy_out_functor<true, false>;
1620 }
1621 } else {
1622 if (has_memory_format) {
1623 return to_maybe_copy_out_functor<false, true>;
1624 } else {
1625 return to_maybe_copy_out_functor<false, false>;
1626 }
1627 }
1628 });
1629
1630 // out variant takes precedence over native
1631 // NB: This impl doesn't work for cpu->cuda copy/cast or vice versa.
1632 REGISTER_OPERATOR_FUNCTOR(
1633 static_runtime::to_copy,
1634 static_runtime_to_copy,
1635 [](Node* n) -> SROperator {
1636 // support 4- or 5-arg for adindexer/adfinder models
1637 // Keep TORCH_CHECK here because there is no alternative for fallback
1638 if (!sr_schema_check(
1639 n,
1640 "static_runtime::to_copy.prim_dtype(Tensor self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor",
1641 "static_runtime::to_copy.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor",
1642 "static_runtime::to_copy.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor")) {
1643 return nullptr;
1644 }
1645 TORCH_CHECK(n->inputs().size() == 4 || n->inputs().size() == 5);
1646 const bool has_constant_non_tensor_dtype_and_flags =
1647 node_has_constant_non_tensor_dtype_and_flags(n);
1648 const bool has_memory_format = n->inputs().size() == 5;
1649 return get_to_copy_functor(
1650 has_constant_non_tensor_dtype_and_flags, has_memory_format);
1651 });
1652
1653 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
1654 REGISTER_OPERATOR_FUNCTOR(
1655 static_runtime::dequantize_copy,
1656 aten_dequantize_copy,
1657 [](Node* n) -> SROperator {
1658 if (!n->matches(torch::schema(
1659 "static_runtime::dequantize_copy.self(Tensor self) -> Tensor"))) {
1660 // TODO: implement static runtime support for aten::dequantize with
1661 // TensorList inputs
1662 LogAndDumpSchema(n);
1663 return nullptr;
1664 }
1665 return [](ProcessedNode* p_node) {
1666 const auto& self = p_node->Input(0).toTensor();
1667 if (p_node->Output(0).isNone()) {
1668 p_node->Output(0) =
1669 create_empty_from(self, at::kFloat, self.suggest_memory_format());
1670 }
1671
1672 auto& out_t = p_node->Output(0).toTensor();
1673 fastResizeToZero(out_t);
1674 at::native::dequantize_copy_out(out_t, self);
1675 };
1676 });
1677
1678 // Out variants for view ops are registered to a separate registry because
1679 // their outputs (views) can't participate in memory reuse.
1680 REGISTER_OPERATOR_FUNCTOR(
1681 static_runtime::reshape_copy,
1682 aten_reshape,
1683 [](Node* n) -> SROperator {
1684 if (!sr_schema_check(
1685 n,
1686 "static_runtime::reshape_copy(Tensor self, int[] shape) -> Tensor")) {
1687 return nullptr;
1688 }
1689 TORCH_CHECK(n->inputs().size() == 2);
1690 return [](ProcessedNode* p_node) {
1691 const auto& self = p_node->Input(0).toTensor(); // self
1692 const auto proposed_shape = p_node->Input(1).toDimVector(); // shape
1693
1694 if (p_node->Output(0).isNone()) {
1695 p_node->Output(0) = create_empty_from(self);
1696 }
1697 auto& out = p_node->Output(0).toTensor();
1698 at::native::reshape_copy_out(out, self, proposed_shape, true);
1699 };
1700 });
1701
1702 REGISTER_OPERATOR_FUNCTOR(
1703 static_runtime::flatten_copy,
1704 aten_flatten,
1705 [](Node* n) -> SROperator {
1706 if (!sr_schema_check(
1707 n,
1708 "static_runtime::flatten_copy.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor")) {
1709 return nullptr;
1710 }
1711 TORCH_CHECK(n->inputs().size() == 3);
1712 return [](ProcessedNode* p_node) {
1713 const auto& self = p_node->Input(0).toTensor();
1714 const auto start_dim = p_node->Input(1).toInt();
1715 const auto end_dim = p_node->Input(2).toInt();
1716
1717 if (p_node->Output(0).isNone()) {
1718 p_node->Output(0) = create_empty_from(self);
1719 }
1720 auto& out = p_node->Output(0).toTensor();
1721 at::native::flatten_copy_out(out, self, start_dim, end_dim);
1722 };
1723 });
1724
1725 REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
1726 if (n->inputs().size() != 2 && n->inputs().size() != 4) {
1727 return nullptr;
1728 }
1729 if (n->matches(torch::schema(
1730 "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"))) {
1731 return [](ProcessedNode* p_node) {
1732 const at::Tensor& self = p_node->Input(0).toTensor();
1733 auto dtype = p_node->Input(1).toOptional<at::ScalarType>();
1734 std::vector<int64_t> dim = {};
1735 bool keepdim = false;
1736 if (p_node->Output(0).isNone()) {
1737 p_node->Output(0) = at::cpu::sum(self, dim, keepdim, dtype);
1738 } else {
1739 auto& output = p_node->Output(0).toTensor();
1740 fastResizeToZero(output);
1741 at::cpu::sum_out(output, self, dim, keepdim, dtype);
1742 }
1743 };
1744 }
1745 if (n->matches(torch::schema(
1746 "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
1747 return [](ProcessedNode* p_node) {
1748 const at::Tensor& self = p_node->Input(0).toTensor();
1749 auto dim = p_node->Input(1).toDimVector();
1750 auto keepdim = p_node->Input(2).toBool();
1751 auto dtype = p_node->Input(3).toOptional<at::ScalarType>();
1752 if (p_node->Output(0).isNone()) {
1753 p_node->Output(0) = at::cpu::sum(self, dim, keepdim, dtype);
1754 } else {
1755 auto& output = p_node->Output(0).toTensor();
1756 fastResizeToZero(output);
1757 at::cpu::sum_out(output, self, dim, keepdim, dtype);
1758 }
1759 };
1760 }
1761 LogAndDumpSchema(n);
1762 return nullptr;
1763 });
1764
1765 REGISTER_OPERATOR_FUNCTOR(aten::mean, aten_mean, [](Node* n) -> SROperator {
1766 if (n->matches(torch::schema(
1767 "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
1768 return [](ProcessedNode* p_node) {
1769 const auto& self = p_node->Input(0).toTensor();
1770 const auto dim = p_node->Input(1).toDimVector();
1771 const bool keepdim = p_node->Input(2).toBool();
1772 const auto dtype = p_node->Input(3).toOptional<at::ScalarType>();
1773 if (p_node->Output(0).isNone()) {
1774 p_node->Output(0) = create_empty_from(
1775 self, dtype.value_or(self.dtype().toScalarType()));
1776 }
1777 auto& output = p_node->Output(0).toTensor();
1778 fastResizeToZero(output);
1779 at::cpu::mean_out(output, self, dim, keepdim, dtype);
1780 };
1781 }
1782
1783 if (n->matches(torch::schema(
1784 "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor"))) {
1785 return [](ProcessedNode* p_node) {
1786 const auto& self = p_node->Input(0).toTensor();
1787 const auto dtype = p_node->Input(1).toOptional<at::ScalarType>();
1788 if (p_node->Output(0).isNone()) {
1789 p_node->Output(0) = create_empty_from(
1790 self, dtype.value_or(self.dtype().toScalarType()));
1791 }
1792 auto& output = p_node->Output(0).toTensor();
1793 fastResizeToZero(output);
1794 at::cpu::mean_out(output, self, /*dim=*/{}, /*keepdim=*/false, dtype);
1795 };
1796 }
1797
1798 LogAndDumpSchema(n);
1799 return nullptr;
1800 });
1801
1802 REGISTER_OPERATOR_FUNCTOR(aten::repeat, aten_repeat, [](Node* n) -> SROperator {
1803 if (!n->matches(torch::schema(
1804 "aten::repeat(Tensor self, int[] repeats) -> Tensor"))) {
1805 LogAndDumpSchema(n);
1806 return nullptr;
1807 }
1808 return [](ProcessedNode* p_node) {
1809 const auto& self = p_node->Input(0).toTensor();
1810 const auto repeats = p_node->Input(1).toDimVector();
1811
1812 if (p_node->Output(0).isNone()) {
1813 p_node->Output(0) = at::native::repeat(self, repeats);
1814 return;
1815 }
1816 at::Tensor& output = p_node->Output(0).toTensor();
1817 at::native::repeat_out(output, self, repeats);
1818 };
1819 });
1820
1821 REGISTER_OPERATOR_FUNCTOR(aten::max, aten_max, [](Node* n) -> SROperator {
1822 if (n->matches(torch::schema(
1823 "aten::max.other(Tensor self, Tensor other) -> Tensor"))) {
1824 return [](ProcessedNode* p_node) {
1825 const auto& self = p_node->Input(0).toTensor();
1826 const auto& other = p_node->Input(1).toTensor();
1827 if (p_node->Output(0).isNone()) {
1828 p_node->Output(0) = at::native::max(self, other);
1829 return;
1830 }
1831 auto& out = p_node->Output(0).toTensor();
1832 fastResizeToZero(out);
1833 at::native::max_out(self, other, out);
1834 };
1835 }
1836
1837 if (n->matches(torch::schema(
1838 "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"))) {
1839 return [](ProcessedNode* p_node) {
1840 const auto& self = p_node->Input(0).toTensor();
1841 auto dim = p_node->Input(1).toInt();
1842 const auto keepdim = p_node->Input(2).toBool();
1843
1844 if (p_node->Output(0).isNone()) {
1845 p_node->Output(0) = create_empty_from(self);
1846 }
1847
1848 if (p_node->Output(1).isNone()) {
1849 p_node->Output(1) = create_empty_from(self, at::kLong);
1850 }
1851
1852 auto& values = p_node->Output(0).toTensor();
1853 auto& indices = p_node->Output(1).toTensor();
1854 fastResizeToZero(values);
1855 fastResizeToZero(indices);
1856 at::cpu::max_out(values, indices, self, dim, keepdim);
1857 };
1858 }
1859
1860 if (n->matches(torch::schema("aten::max(Tensor self) -> Tensor"))) {
1861 return [](ProcessedNode* p_node) {
1862 const auto& self = p_node->Input(0).toTensor();
1863 if (p_node->Output(0).isNone()) {
1864 p_node->Output(0) = create_empty_from(self);
1865 }
1866 auto& value = p_node->Output(0).toTensor();
1867 fastResizeToZero(value);
1868 at::cpu::amax_out(value, self);
1869 };
1870 }
1871
1872 LogAndDumpSchema(n);
1873 return nullptr;
1874 });
1875
1876 REGISTER_OPERATOR_FUNCTOR(aten::sign, aten_sign, [](Node* n) -> SROperator {
1877 if (!n->matches(torch::schema("aten::sign.Tensor(Tensor input) -> Tensor"))) {
1878 LogAndDumpSchema(n);
1879 return nullptr;
1880 }
1881 return [](ProcessedNode* p_node) {
1882 const auto& in0_t = p_node->Input(0).toTensor();
1883 if (p_node->Output(0).isNone()) {
1884 p_node->Output(0) = at::cpu::sign(in0_t);
1885 return;
1886 }
1887 auto& out_t = p_node->Output(0).toTensor();
1888 fastResizeToZero(out_t);
1889 at::cpu::sign_out(out_t, in0_t);
1890 };
1891 });
1892
1893 REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
1894 if (!n->matches(torch::schema(
1895 "aten::div.Tensor(Tensor self, Tensor other) -> Tensor")) &&
1896 !n->matches(torch::schema(
1897 "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor")) &&
1898 !n->matches(torch::schema(
1899 "aten::div.Scalar(Tensor self, Scalar other) -> Tensor")) &&
1900 !n->matches(torch::schema(
1901 "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor"))) {
1902 LogAndDumpSchema(n);
1903 return nullptr;
1904 }
1905
1906 return [te = createDiv()](ProcessedNode* p_node) {
1907 const auto& in0_t = p_node->Input(0).toTensor();
1908 std::optional<c10::string_view> rounding_mode = std::nullopt;
1909 if (p_node->num_inputs() > 2) {
1910 rounding_mode = p_node->Input(2).toOptional<c10::string_view>();
1911 }
1912 const auto& in1_t = p_node->Input(1).isTensor()
1913 ? p_node->Input(1).toTensor()
1914 : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
1915
1916 if (p_node->Output(0).isNone()) {
1917 p_node->Output(0) = create_empty_from(in0_t);
1918 }
1919 auto& out_t = p_node->Output(0).toTensor();
1920
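// Fast path: the precompiled tensor-expression kernel handles only
// contiguous float inputs with identical sizes and strides; anything
// else falls back to at::cpu::div_out below.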
1921 if (in0_t.sizes() == in1_t.sizes() &&
1922 in0_t.scalar_type() == in1_t.scalar_type() &&
1923 in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() &&
1924 in0_t.scalar_type() == at::kFloat) {
1925 int64_t dim = in0_t.numel();
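// The kernel encodes the rounding mode as an int:
// 0 = plain division, 1 = trunc, 2 = floor.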
1926 int i_rounding_mode = 0;
1927 if (rounding_mode && !rounding_mode.value().empty()) {
1928 const char peek_rounding_mode = rounding_mode.value().at(0);
1929 if (peek_rounding_mode == 't') {
1930 // trunc after div
1931 i_rounding_mode = 1;
1932 } else if (peek_rounding_mode == 'f') {
1933 // floor after div
1934 i_rounding_mode = 2;
1935 }
1936 }
1937 at::native::resize_(out_t, in0_t.sizes());
1938 te->call(
1939 {out_t.data_ptr(),
1940 in0_t.data_ptr(),
1941 in1_t.data_ptr(),
1942 &i_rounding_mode,
1943 &dim});
1944 } else {
1945 fastResizeToZero(out_t);
1946 at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
1947 }
1948 };
1949 });
1950
1951 REGISTER_OPERATOR_FUNCTOR(aten::log, aten_log, [](Node* n) -> SROperator {
1952 if (!n->matches(torch::schema("aten::log.Tensor(Tensor input) -> Tensor"))) {
1953 LogAndDumpSchema(n);
1954 return nullptr;
1955 }
1956 return [](ProcessedNode* p_node) {
1957 const auto& in0_t = p_node->Input(0).toTensor();
1958 if (p_node->Output(0).isNone()) {
1959 p_node->Output(0) = at::cpu::log(in0_t);
1960 return;
1961 }
1962 auto& out_t = p_node->Output(0).toTensor();
1963 fastResizeToZero(out_t);
1964 at::cpu::log_out(out_t, in0_t);
1965 };
1966 });
1967
1968 REGISTER_OPERATOR_FUNCTOR(aten::sub, aten_sub, [](Node* n) -> SROperator {
1969 if (n->matches(torch::schema(
1970 "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"))) {
1971 return [](ProcessedNode* p_node) {
1972 const auto& in0_t = p_node->Input(0).toTensor();
1973 const auto& in1_t = p_node->Input(1).toTensor();
1974 const auto alpha = p_node->Input(2).toScalar();
1975 if (p_node->Output(0).isNone()) {
1976 p_node->Output(0) = at::cpu::sub(in0_t, in1_t, alpha);
1977 return;
1978 }
1979 auto& out_t = p_node->Output(0).toTensor();
1980 fastResizeToZero(out_t);
1981 at::cpu::sub_out(out_t, in0_t, in1_t, alpha);
1982 };
1983 }
1984 if (n->matches(torch::schema(
1985 "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"))) {
1986 return [](ProcessedNode* p_node) {
1987 const auto& in0_t = p_node->Input(0).toTensor();
1988 const auto& in1_t =
1989 at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
1990 const auto alpha = p_node->Input(2).toScalar();
1991 if (p_node->Output(0).isNone()) {
1992 p_node->Output(0) = at::cpu::sub(in0_t, in1_t, alpha);
1993 return;
1994 }
1995 auto& out_t = p_node->Output(0).toTensor();
1996 fastResizeToZero(out_t);
1997 at::cpu::sub_out(out_t, in0_t, in1_t, alpha);
1998 };
1999 }
2000 LogAndDumpSchema(n);
2001 return nullptr;
2002 });
2003
2004 // TODO: support clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
2005 REGISTER_OPERATOR_FUNCTOR(
2006 aten::clamp_min,
2007 aten_clamp_min,
2008 [](Node* n) -> SROperator {
2009 if (!n->matches(torch::schema(
2010 "aten::clamp_min(Tensor self, Scalar min) -> Tensor"))) {
2011 LogAndDumpSchema(n);
2012 return nullptr;
2013 }
2014 return [](ProcessedNode* p_node) {
2015 const auto& in0_t = p_node->Input(0).toTensor();
2016 const auto in1_s = p_node->Input(1).toScalar();
2017 if (p_node->Output(0).isNone()) {
2018 p_node->Output(0) = at::cpu::clamp_min(in0_t, in1_s);
2019 return;
2020 }
2021 auto& out_t = p_node->Output(0).toTensor();
2022 fastResizeToZero(out_t);
2023 at::cpu::clamp_min_out(out_t, in0_t, in1_s);
2024 };
2025 });
2026
2027 REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator {
2028 if (!n->matches(torch::schema(
2029 "aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"))) {
2030 LogAndDumpSchema(n);
2031 return nullptr;
2032 }
2033 return [](ProcessedNode* p_node) {
2034 const auto& in0_t = p_node->Input(0).toTensor();
2035 const auto dim = p_node->Input(1).toOptional<int64_t>();
2036 const auto keepdim = p_node->Input(2).toBool();
2037 if (p_node->Output(0).isNone()) {
2038 p_node->Output(0) = at::cpu::argmin(in0_t, dim, keepdim);
2039 return;
2040 }
2041 auto& out_t = p_node->Output(0).toTensor();
2042 fastResizeToZero(out_t);
2043 if (in0_t.is_contiguous() && dim.has_value()) {
2044 at::native::c2_argmin_out(out_t, in0_t, dim.value(), keepdim);
2045 return;
2046 }
2047 at::cpu::argmin_out(out_t, in0_t, dim, keepdim);
2048 };
2049 });
2050
2051 REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator {
2052 if (!n->matches(torch::schema(
2053 "aten::softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) {
2054 LogAndDumpSchema(n);
2055 return nullptr;
2056 }
2057 return [](ProcessedNode* p_node) {
2058 const auto& in_t = p_node->Input(0).toTensor();
2059 const auto& dim = p_node->Input(1).toInt();
2060 const auto& dtype = p_node->Input(2).toOptional<c10::ScalarType>();
2061 if (p_node->Output(0).isNone()) {
2062 p_node->Output(0) = at::native::softmax(in_t, dim, dtype);
2063 return;
2064 }
2065 auto& out_t = p_node->Output(0).toTensor();
2066 fastResizeToZero(out_t);
2067 auto half_to_float = in_t.scalar_type() == at::ScalarType::Half &&
2068 dtype == at::ScalarType::Float;
2069 at::cpu::_softmax_out(out_t, in_t, dim, half_to_float);
2070 };
2071 });
2072
2073 namespace {
2074
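// Borrow the tensor stored in an optional-Tensor IValue if present;
// otherwise return an owned, default-constructed (undefined) Tensor.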
2075 c10::MaybeOwned<at::Tensor> borrow_from_optional_tensor_ivalue(
2076 const IValue& iv) {
2077 if (iv.isNone()) {
2078 return c10::MaybeOwned<at::Tensor>::owned(std::in_place);
2079 }
2080 return c10::MaybeOwned<at::Tensor>::borrowed(iv.toTensor());
2081 }
2082
2083 } // namespace
2084
2085 REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROperator {
2086 if (!sr_schema_check(
2087 n,
2088 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")) {
2089 return nullptr;
2090 }
2091 return [](ProcessedNode* p_node) {
2092 // ignore Input(5): `bool cudnn_enable=True`
2093 const auto& input = p_node->Input(0).toTensor();
2094 const auto normalized_shape = p_node->Input(1).toDimVector();
2095 float eps = p_node->Input(4).toDouble();
2096
2097 c10::MaybeOwned<at::Tensor> weight_maybe_owned =
2098 borrow_from_optional_tensor_ivalue(p_node->Input(2));
2099 const at::Tensor& weight = *weight_maybe_owned;
2100 c10::MaybeOwned<at::Tensor> bias_maybe_owned =
2101 borrow_from_optional_tensor_ivalue(p_node->Input(3));
2102 const at::Tensor& bias = *bias_maybe_owned;
2103
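// _check_layer_norm_inputs returns M (the number of rows to normalize
// independently, i.e. the product of the leading dims) and N (the number
// of elements per row, i.e. the product of normalized_shape).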
2104 auto M_N = at::native::_check_layer_norm_inputs(
2105 input, normalized_shape, weight, bias);
2106 auto M = M_N.first;
2107 auto N = M_N.second;
2108 auto X = input.expect_contiguous();
2109 auto gamma = weight.expect_contiguous();
2110 auto beta = bias.expect_contiguous();
2111
2112 if (p_node->Output(0).isNone()) {
2113 p_node->Output(0) = at::native::empty_like(
2114 *X,
2115 std::nullopt /* dtype */,
2116 std::nullopt /* layout */,
2117 std::nullopt /* device */,
2118 std::nullopt /* pin_memory */,
2119 at::MemoryFormat::Contiguous);
2120 } else {
2121 at::native::resize_(
2122 p_node->Output(0).toTensor(), X->sizes(), std::nullopt);
2123 }
2124 at::Tensor& output = p_node->Output(0).toTensor();
2125 at::native::layer_norm_cpu_out(output, *X, *gamma, *beta, eps, M, N);
2126 };
2127 });
2128
2129 REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
2130 if (n->matches(torch::schema(
2131 "aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor"))) {
2132 return [](ProcessedNode* p_node) {
2133 const auto& in0_t = p_node->Input(0).toTensor();
2134 if (p_node->Output(0).isNone()) {
2135 p_node->Output(0) = create_empty_from(in0_t);
2136 }
2137 auto& out_t = p_node->Output(0).toTensor();
2138 fastResizeToZero(out_t);
2139 const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
2140 at::cpu::norm_outf(
2141 in0_t,
2142 in1_s,
2143 c10::IntArrayRef{},
2144 false,
2145 p_node->Input(2).toScalarType(),
2146 out_t);
2147 };
2148 }
2149 if (n->matches(torch::schema(
2150 "aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"))) {
2151 return [](ProcessedNode* p_node) {
2152 const auto& in0_t = p_node->Input(0).toTensor();
2153
2154 if (p_node->Output(0).isNone()) {
2155 p_node->Output(0) = create_empty_from(in0_t);
2156 }
2157 auto& out_t = p_node->Output(0).toTensor();
2158 fastResizeToZero(out_t);
2159
2160 const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
2161 at::cpu::norm_outf(
2162 in0_t,
2163 in1_s,
2164 p_node->Input(2).toDimVector(), // dim
2165 p_node->Input(3).toBool(), // keepdim
2166 p_node->Input(4).toScalarType(), // dtype
2167 out_t);
2168 };
2169 }
2170 if (n->matches(torch::schema(
2171 "aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor"))) {
2172 return [](ProcessedNode* p_node) {
2173 const auto& in0_t = p_node->Input(0).toTensor();
2174
2175 if (p_node->Output(0).isNone()) {
2176 p_node->Output(0) = create_empty_from(in0_t);
2177 }
2178 auto& out_t = p_node->Output(0).toTensor();
2179 fastResizeToZero(out_t);
2180
2181 const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
2182 at::cpu::norm_outf(
2183 in0_t,
2184 in1_s,
2185 p_node->Input(2).toDimVector(), // dim
2186 p_node->Input(3).toBool(), // keepdim
2187 out_t);
2188 };
2189 }
2190 LogAndDumpSchema(n);
2191 return nullptr;
2192 });
2193
2194 REGISTER_OPERATOR_FUNCTOR(aten::matmul, aten_matmul, [](Node* n) -> SROperator {
2195 if (!n->matches(
2196 torch::schema("aten::matmul(Tensor self, Tensor other) -> Tensor"))) {
2197 LogAndDumpSchema(n);
2198 return nullptr;
2199 }
2200 return [](ProcessedNode* p_node) {
2201 const auto& in0_t = p_node->Input(0).toTensor();
2202 const auto& in1_t = p_node->Input(1).toTensor();
2203
2204 if (p_node->Output(0).isNone()) {
2205 p_node->Output(0) = at::native::matmul(in0_t, in1_t);
2206 return;
2207 }
2208 auto& out_t = p_node->Output(0).toTensor();
2209 fastResizeToZero(out_t);
2210 at::native::matmul_out(in0_t, in1_t, out_t);
2211 };
2212 });
2213
2214 REGISTER_OPERATOR_FUNCTOR(quantized::linear, quantized_linear, [](Node* n) -> SROperator {
2215 if (!n->matches(torch::schema(
2216 "quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"))) {
2217 LogAndDumpSchema(n);
2218 return nullptr;
2219 }
2220 const auto w = toIValue(n->inputs()[1]);
2221 c10::intrusive_ptr<LinearPackedParamsBase> packed_weight;
2222 if (w) {
2223 packed_weight = w->toCustomClass<LinearPackedParamsBase>();
2224 }
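// If the prepacked weight is a graph constant, capture it once here;
// otherwise the lambda below fetches it from Input(1) on every call.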
2225 return [packed_weight](ProcessedNode* p_node) {
2226 const auto& input = p_node->Input(0).toTensor();
2227 const auto output_scale = p_node->Input(2).toDouble();
2228 const auto output_zero_point = p_node->Input(3).toInt();
2229
2230 if (p_node->Output(0).isNone()) {
2231 p_node->Output(0) = at::native::empty_affine_quantized(
2232 {0},
2233 c10::kQUInt8,
2234 std::nullopt,
2235 c10::kCPU,
2236 false,
2237 output_scale,
2238 output_zero_point,
2239 std::nullopt);
2240 }
2241 auto& out_t = p_node->Output(0).toTensor();
2242 fastResizeToZero(out_t);
2243
2244 if (packed_weight) {
2245 packed_weight->apply_out(input, output_scale, output_zero_point, out_t);
2246 } else {
2247 // Weights could be quantized on the fly
2248 auto packed_weight_tmp =
2249 p_node->Input(1).toCustomClass<LinearPackedParamsBase>();
2250 packed_weight_tmp->apply_out(
2251 input, output_scale, output_zero_point, out_t);
2252 }
2253 };
2254 });
2255
2256 REGISTER_OPERATOR_FUNCTOR(
2257 fb::quantized_linear,
2258 fb_quantized_linear,
2259 [](Node* n) -> SROperator {
2260 if (!n->matches(torch::schema(
2261 "fb::quantized_linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase w_prepack, Tensor Y_scale_i, Tensor Y_zero_point_i) -> Tensor"))) {
2262 LogAndDumpSchema(n);
2263 return nullptr;
2264 }
2265 const auto w = toIValue(n->inputs()[1]);
2266 c10::intrusive_ptr<LinearPackedParamsBase> packed_weight;
2267 if (w) {
2268 packed_weight = w->toCustomClass<LinearPackedParamsBase>();
2269 }
2270 return [packed_weight](ProcessedNode* p_node) {
2271 const auto& input = p_node->Input(0).toTensor();
2272 const auto output_scale = p_node->Input(2).toTensor().item().toFloat();
2273 const auto output_zero_point =
2274 p_node->Input(3).toTensor().item().toLong();
2275
2276 if (p_node->Output(0).isNone()) {
2277 p_node->Output(0) = at::native::empty_affine_quantized(
2278 {0},
2279 c10::kQUInt8,
2280 std::nullopt,
2281 c10::kCPU,
2282 false,
2283 output_scale,
2284 output_zero_point,
2285 std::nullopt);
2286 }
2287 auto& out_t = p_node->Output(0).toTensor();
2288 fastResizeToZero(out_t);
2289
2290 if (packed_weight) {
2291 packed_weight->apply_out(
2292 input, output_scale, output_zero_point, out_t);
2293 } else {
2294 // Weights could be quantized on the fly
2295 auto packed_weight_tmp =
2296 p_node->Input(1).toCustomClass<LinearPackedParamsBase>();
2297 packed_weight_tmp->apply_out(
2298 input, output_scale, output_zero_point, out_t);
2299 }
2300 };
2301 });
2302
2303 namespace {
2304
2305 template <bool has_relu>
2306 void apply_dynamic_out_functor(
2307 c10::intrusive_ptr<LinearPackedParamsBase> packed_weight,
2308 const at::Tensor& input,
2309 at::Tensor& out,
2310 bool reduce_range);
2311
2312 template <>
2313 void apply_dynamic_out_functor<false>(
2314 c10::intrusive_ptr<LinearPackedParamsBase> packed_weight,
2315 const at::Tensor& input,
2316 at::Tensor& out,
2317 bool reduce_range) {
2318 packed_weight->apply_dynamic_out(input, out, reduce_range);
2319 }
2320
2321 template <>
2322 void apply_dynamic_out_functor<true>(
2323 c10::intrusive_ptr<LinearPackedParamsBase> packed_weight,
2324 const at::Tensor& input,
2325 at::Tensor& out,
2326 bool reduce_range) {
2327 // The implementation of PackedLinearWeightFp16::apply_dynamic_impl does not
2328 // handle relu. Currently, it ignores the `ReluFused` template parameter.
2329 // So, we explicitly do the relu here.
2330 packed_weight->apply_dynamic_out(input, out, reduce_range);
2331 out.relu_();
2332 }
2333
2334 template <bool has_relu>
2335 SROperator quantized_linear_dynamic_fp16_impl(Node* n) {
2336 const auto weight = toIValue(n->inputs()[1]);
2337 c10::intrusive_ptr<LinearPackedParamsBase> packed_weight;
2338 if (weight) {
2339 packed_weight = weight->toCustomClass<LinearPackedParamsBase>();
2340 }
2341 if (packed_weight) {
2342 return [packed_weight](ProcessedNode* p_node) {
2343 const auto& input = p_node->Input(0).toTensor();
2344 if (p_node->Output(0).isNone()) {
2345 p_node->Output(0) = create_empty_from(input, at::kFloat);
2346 }
2347 auto& out_t = p_node->Output(0).toTensor();
2348 fastResizeToZero(out_t);
2349 apply_dynamic_out_functor<has_relu>(packed_weight, input, out_t, false);
2350 };
2351 } else {
2352 return [](ProcessedNode* p_node) {
2353 const auto& input = p_node->Input(0).toTensor();
2354 if (p_node->Output(0).isNone()) {
2355 p_node->Output(0) = create_empty_from(input, at::kFloat);
2356 }
2357 auto& out_t = p_node->Output(0).toTensor();
2358 fastResizeToZero(out_t);
2359 // Weights could be quantized on the fly
2360 auto packed_weight_tmp =
2361 p_node->Input(1).toCustomClass<LinearPackedParamsBase>();
2362 apply_dynamic_out_functor<has_relu>(
2363 packed_weight_tmp, input, out_t, false);
2364 };
2365 }
2366 }
2367
2368 } // namespace
2369
2370 REGISTER_OPERATOR_FUNCTOR(
2371 quantized::linear_dynamic_fp16,
2372 quantized_linear_dynamic_fp16,
2373 [](Node* n) -> SROperator {
2374 if (!n->matches(torch::schema(
2375 "quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes."
2376 "quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"))) {
2377 LogAndDumpSchema(n);
2378 return nullptr;
2379 }
2380 return quantized_linear_dynamic_fp16_impl<false>(n);
2381 });
2382
2383 REGISTER_OPERATOR_FUNCTOR(
2384 quantized::linear_relu_dynamic_fp16,
2385 quantized_linear_relu_dynamic_fp16,
2386 [](Node* n) -> SROperator {
2387 if (!n->matches(torch::schema(
2388 "quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes."
2389 "quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"))) {
2390 LogAndDumpSchema(n);
2391 return nullptr;
2392 }
2393 return quantized_linear_dynamic_fp16_impl<true>(n);
2394 });
2395
2396 // device & pin_memory matter only when CUDA is enabled.
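// Returns true if `ivalue` already holds a tensor whose dtype and layout
// match the requested options, so the existing output can be reused
// instead of allocating a new one.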
2397 static bool hasTensorWithOptions(
2398 const IValue& ivalue,
2399 std::optional<c10::ScalarType> dtype,
2400 std::optional<c10::Layout> layout) {
2401 if (!ivalue.isTensor()) {
2402 return false;
2403 }
2404 const auto& tensor = ivalue.toTensor();
2405 if (dtype == tensor.dtype().toScalarType() &&
2406 layout == tensor.options().layout_opt()) {
2407 return true;
2408 }
2409 VLOG(1) << "tensor exists, but tensor options were different";
2410 return false;
2411 }
2412
2413 static bool hasTensorWithOptions(
2414 const IValue& ivalue,
2415 std::optional<c10::ScalarType> dtype,
2416 std::optional<c10::Layout> layout,
2417 std::optional<c10::MemoryFormat> memory_format) {
2418 return hasTensorWithOptions(ivalue, dtype, layout) &&
2419 (memory_format == ivalue.toTensor().options().memory_format_opt());
2420 }
2421
2422 REGISTER_OPERATOR_FUNCTOR(aten::full, aten_full, [](Node* n) -> SROperator {
2423 if (!n->matches(torch::schema(
2424 "aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) {
2425 LogAndDumpSchema(n);
2426 return nullptr;
2427 }
2428 return [](ProcessedNode* p_node) {
2429 const auto& size = p_node->Input(0).toDimVector();
2430 const auto fill_value = p_node->Input(1).toScalar();
2431 const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
2432 const auto layout = p_node->Input(3).toOptional<c10::Layout>();
2433 if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
2434 const auto device = p_node->Input(4).toOptional<c10::Device>();
2435 const auto pin_memory = p_node->Input(5).toOptional<bool>();
2436 p_node->Output(0) =
2437 at::native::full(size, fill_value, dtype, layout, device, pin_memory);
2438 return;
2439 }
2440 p_node->Output(0) =
2441 at::native::full_out(size, fill_value, p_node->Output(0).toTensor());
2442 };
2443 });
2444
2445 REGISTER_OPERATOR_FUNCTOR(aten::full_like, aten_full_like, [](Node* n) -> SROperator {
2446 if (!n->matches(torch::schema(
2447 "aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"))) {
2448 LogAndDumpSchema(n);
2449 return nullptr;
2450 }
2451 return [](ProcessedNode* p_node) {
2452 const auto in1_s = p_node->Input(1).toScalar();
2453 const auto& in0_t = p_node->Input(0).toTensor();
2454 const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
2455 const auto layout = p_node->Input(3).toOptional<c10::Layout>();
2456 if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
2457 const auto device = p_node->Input(4).toOptional<c10::Device>();
2458 const auto pin_memory = p_node->Input(5).toOptional<bool>();
2459 const auto memory_format =
2460 p_node->Input(6).toOptional<c10::MemoryFormat>();
2461
2462 p_node->Output(0) = at::native::empty_like(
2463 in0_t, dtype, layout, device, pin_memory, memory_format);
2464 }
2465 auto& out_t = p_node->Output(0).toTensor();
2466 at::native::resize_(out_t, in0_t.sizes(), std::nullopt);
2467 at::native::fill_out(out_t, in1_s);
2468 };
2469 });
2470
2471 REGISTER_OPERATOR_FUNCTOR(aten::ones, aten_ones, [](Node* n) -> SROperator {
2472 if (!n->matches(torch::schema(
2473 "aten::ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) {
2474 LogAndDumpSchema(n);
2475 return nullptr;
2476 }
2477 return [](ProcessedNode* p_node) {
2478 const auto size = p_node->Input(0).toDimVector();
2479 if (p_node->Output(0).isNone()) {
2480 const auto dtype = p_node->Input(1).toOptional<c10::ScalarType>();
2481 const auto layout = p_node->Input(2).toOptional<c10::Layout>();
2482 const auto device = p_node->Input(3).toOptional<c10::Device>();
2483 const auto pin_memory = p_node->Input(4).toOptional<bool>();
2484 p_node->Output(0) =
2485 at::native::ones(size, dtype, layout, device, pin_memory);
2486 return;
2487 }
2488 auto& out_t = p_node->Output(0).toTensor();
2489 fastResizeToZero(out_t);
2490 at::native::ones_out(size, out_t);
2491 };
2492 });
2493
2494 REGISTER_OPERATOR_FUNCTOR(aten::ones_like, aten_ones_like, [](Node* n) -> SROperator {
2495 if (!n->matches(torch::schema(
2496 "aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"))) {
2497 LogAndDumpSchema(n);
2498 return nullptr;
2499 }
2500 return [](ProcessedNode* p_node) {
2501 const auto& self = p_node->Input(0).toTensor();
2502 const auto dtype = p_node->Input(1).toOptional<c10::ScalarType>();
2503 const auto layout = p_node->Input(2).toOptional<c10::Layout>();
2504 const auto device = p_node->Input(3).toOptional<c10::Device>();
2505 const auto pin_memory = p_node->Input(4).toOptional<bool>();
2506 const auto memory_format = p_node->Input(5).toOptional<c10::MemoryFormat>();
2507 if (!hasTensorWithOptions(
2508 p_node->Output(0), dtype, layout, memory_format)) {
2509 p_node->Output(0) = at::native::ones_like(
2510 self, dtype, layout, device, pin_memory, memory_format);
2511 return;
2512 }
2513 auto& out_t = p_node->Output(0).toTensor();
2514 fastResizeToZero(out_t);
2515 at::native::ones_out(self.sizes(), out_t);
2516 };
2517 });
2518
2519 REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator {
2520 if (!n->matches(torch::schema(
2521 "aten::zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) {
2522 LogAndDumpSchema(n);
2523 return nullptr;
2524 }
2525 return [](ProcessedNode* p_node) {
2526 const auto size = p_node->Input(0).toDimVector();
2527 const auto dtype = p_node->Input(1).toOptional<c10::ScalarType>();
2528 const auto layout = p_node->Input(2).toOptional<c10::Layout>();
2529 if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) {
2530 p_node->Output(0) = at::compositeexplicitautograd::zeros(
2531 size, dtype, layout, std::nullopt, std::nullopt);
2532 return;
2533 }
2534 auto& out_t = p_node->Output(0).toTensor();
2535 fastResizeToZero(out_t);
2536 at::compositeexplicitautograd::zeros_out(out_t, size);
2537 };
2538 });
2539
2540 REGISTER_OPERATOR_FUNCTOR(aten::linear, aten_linear, [](Node* n) -> SROperator {
2541 if (!n->matches(torch::schema(
2542 "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"))) {
2543 LogAndDumpSchema(n);
2544 return nullptr;
2545 }
2546
2547 return [](ProcessedNode* p_node) {
2548 const auto& in0_t = p_node->Input(0).toTensor();
2549 const auto& in1_t = p_node->Input(1).toTensor();
2550 auto in2_t = p_node->Input(2).toOptional<at::Tensor>();
2551
2552 if (p_node->Output(0).isNone()) {
2553 p_node->Output(0) = at::native::linear(in0_t, in1_t, in2_t);
2554 return;
2555 }
2556 auto& out_t = p_node->Output(0).toTensor();
2557 fastResizeToZero(out_t);
2558 at::native::linear_out(out_t, in0_t, in1_t, in2_t);
2559 };
2560 });
2561
2562 REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SROperator {
2563 if (n->matches(torch::schema(
2564 "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
2565 return [](ProcessedNode* p_node) {
2566 const auto& input = p_node->Input(0).toTensor();
2567 const auto dim = p_node->Input(2).toDimVector();
2568 const auto keepdim = p_node->Input(3).toBool();
2569 const auto dtype = p_node->Input(4).toOptional<c10::ScalarType>();
2570 if (p_node->Output(0).isNone()) {
2571 p_node->Output(0) = at::native::linalg_norm(
2572 input,
2573 p_node->Input(1).toOptional<at::Scalar>(),
2574 dim,
2575 keepdim,
2576 dtype);
2577 return;
2578 }
2579 auto& output = p_node->Output(0).toTensor();
2580 fastResizeToZero(output);
2581 at::native::linalg_norm_out(
2582 input,
2583 p_node->Input(1).toOptional<at::Scalar>(),
2584 dim,
2585 keepdim,
2586 dtype,
2587 output);
2588 };
2589 }
2590 if (n->matches(torch::schema(
2591 "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
2592 return [](ProcessedNode* p_node) {
2593 const auto& input = p_node->Input(0).toTensor();
2594 const auto dim = p_node->Input(2).toDimVector();
2595 const auto keepdim = p_node->Input(3).toBool();
2596 const auto dtype = p_node->Input(4).toOptional<c10::ScalarType>();
2597 if (p_node->Output(0).isNone()) {
2598 p_node->Output(0) = at::native::linalg_norm(
2599 input, p_node->Input(1).toStringView(), dim, keepdim, dtype);
2600 return;
2601 }
2602 auto& output = p_node->Output(0).toTensor();
2603 fastResizeToZero(output);
2604 at::native::linalg_norm_out(
2605 input, p_node->Input(1).toStringRef(), dim, keepdim, dtype, output);
2606 };
2607 }
2608 LogAndDumpSchema(n);
2609 return nullptr;
2610 });
2611
2612 REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
2613 if (!n->matches(
2614 torch::schema("aten::cat(Tensor[] tensors, int dim=0) -> Tensor"))) {
2615 LogAndDumpSchema(n);
2616 return nullptr;
2617 }
2618 return [](ProcessedNode* p_node) {
2619 const auto inputs = p_node->Input(0).toTensorVector();
2620 TORCH_CHECK(!inputs.empty(), "concat expects non-empty tensor list");
2621 const auto dim = p_node->Input(1).toInt();
2622 if (p_node->Output(0).isNone()) {
2623 p_node->Output(0) = at::cpu::cat(inputs, dim);
2624 return;
2625 }
2626 auto& output = p_node->Output(0).toTensor();
2627 fastResizeToZero(output);
2628 at::cpu::cat_outf(inputs, dim, output);
2629 };
2630 });
2631
2632 REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator {
2633 if (!n->matches(torch::schema(
2634 "aten::cumsum(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) {
2635 LogAndDumpSchema(n);
2636 return nullptr;
2637 }
2638 return [](ProcessedNode* p_node) {
2639 const auto& input = p_node->Input(0).toTensor();
2640 const auto dim = p_node->Input(1).toInt();
2641 const auto dtype = p_node->Input(2).toOptional<c10::ScalarType>();
2642 if (p_node->Output(0).isNone()) {
2643 p_node->Output(0) = at::cpu::cumsum(input, dim, dtype);
2644 return;
2645 }
2646 auto& output = p_node->Output(0).toTensor();
2647 fastResizeToZero(output);
2648 at::cpu::cumsum_out(output, input, dim, dtype);
2649 };
2650 });
2651
2652 REGISTER_OPERATOR_FUNCTOR(
2653 aten::nonzero,
2654 aten_nonzero,
2655 [](Node* n) -> SROperator {
2656 if (!n->matches(torch::schema("aten::nonzero(Tensor self) -> Tensor"))) {
2657 LogAndDumpSchema(n);
2658 return nullptr;
2659 }
2660 return [](ProcessedNode* p_node) {
2661 const auto& input = p_node->Input(0).toTensor();
2662 if (p_node->Output(0).isNone()) {
2663 p_node->Output(0) = at::native::nonzero_cpu(input);
2664 return;
2665 }
2666 auto& output = p_node->Output(0).toTensor();
2667 fastResizeToZero(output);
2668 at::native::nonzero_out_cpu(input, output);
2669 };
2670 });
2671
2672 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2673 REGISTER_OPERATOR_FUNCTOR(
2674 prim::VarConcat,
2675 prim_VarConcat,
2676 [](Node* n) -> SROperator {
2677 if (!sr_schema_check_kind(n, prim::VarConcat)) {
2678 return nullptr;
2679 }
2680 return [](ProcessedNode* p_node) {
2681 const size_t num_inputs = p_node->num_inputs();
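// prim::VarConcat is variadic: the first num_inputs - 1 inputs are the
// tensors to concatenate and the last input is the concat dimension.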
2682 std::vector<at::Tensor> inputs(num_inputs - 1);
2683 for (const auto i : c10::irange(num_inputs - 1)) {
2684 inputs[i] = p_node->Input(i).toTensor();
2685 }
2686 auto dim = p_node->Input(num_inputs - 1).toInt();
2687 if (p_node->Output(0).isNone()) {
2688 p_node->Output(0) = at::cpu::cat(inputs, dim);
2689 return;
2690 }
2691 auto& out_t = p_node->Output(0).toTensor();
2692 fastResizeToZero(out_t);
2693 at::cpu::cat_outf(inputs, dim, out_t);
2694 };
2695 });
2696
2697 namespace {
2698 // This template and its specialization help us avoid compiler warnings
2699 // about taking the absolute value of an unsigned type in signed_log1p
2700 template <class T>
2701 T abs_if_signed(T val) {
2702 return std::abs(val);
2703 }
2704
2705 template <>
2706 unsigned char abs_if_signed<unsigned char>(unsigned char val) {
2707 return val;
2708 }
2709
2710 // Computes f(x) = sign(x) * ln(1 + |x|) for each x in the input tensor
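// e.g. f(3) = ln(4) and f(-3) = -ln(4); the result is an odd function of x.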
2711 void signed_log1p_out(at::Tensor& out, const at::Tensor& input) {
2712 at::native::resize_(out, input.sizes(), std::nullopt);
2713
2714 const auto input_contig = input.expect_contiguous();
2715 auto output_contig = out.expect_contiguous();
2716
2717 AT_DISPATCH_ALL_TYPES(input.scalar_type(), "signed_log1p_kernel", [&]() {
2718 const auto input_data = input_contig->const_data_ptr<scalar_t>();
2719 auto output_data = output_contig->mutable_data_ptr<float>();
2720 const auto N = input.numel();
2721
2722 for (const auto i : c10::irange(N)) {
2723 const int sign = input_data[i] < 0 ? -1 : 1;
2724 output_data[i] = std::log1p(abs_if_signed(input_data[i])) * sign;
2725 }
2726 });
2727 }
2728
2729 } // namespace
2730
2731 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2732 REGISTER_OPERATOR_FUNCTOR(
2733 static_runtime::signed_log1p,
2734 static_runtime_signed_log1p,
2735 [](Node* n) -> SROperator {
2736 if (!n->matches(torch::schema(
2737 "static_runtime::signed_log1p(Tensor x) -> Tensor"))) {
2738 LogAndDumpSchema(n);
2739 return nullptr;
2740 }
2741 auto te = createSignedLog1p();
2742 return [te](ProcessedNode* p_node) {
2743 const auto& input = p_node->Input(0).toTensor();
2744 if (p_node->Output(0).isNone()) {
2745 p_node->Output(0) = create_empty_from(input);
2746 }
2747 auto& out = p_node->Output(0).toTensor();
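// Use the tensor-expression kernel only when it is available and accepts
// this input; otherwise fall back to the reference signed_log1p_out above.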
2748 if (!te || !te->checkInput<float>(input)) {
2749 fastResizeToZero(out);
2750 signed_log1p_out(out, input);
2751 return;
2752 }
2753 at::native::resize_(out, input.sizes(), std::nullopt);
2754 int64_t nn = input.numel();
2755 te->call({out.data_ptr(), input.data_ptr(), &nn});
2756 };
2757 });
2758
2759 REGISTER_OPERATOR_FUNCTOR(
2760 aten::remainder,
2761 aten_remainder,
2762 [](Node* n) -> SROperator {
2763 if (n->matches(torch::schema(
2764 "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor"))) {
2765 return [](ProcessedNode* p_node) {
2766 const auto& self = p_node->Input(0).toTensor();
2767 if (p_node->Output(0).isNone()) {
2768 p_node->Output(0) =
2769 at::cpu::remainder(self, p_node->Input(1).toTensor());
2770 return;
2771 }
2772 auto& out = p_node->Output(0).toTensor();
2773 fastResizeToZero(out);
2774 at::cpu::remainder_out(out, self, p_node->Input(1).toTensor());
2775 };
2776 }
2777 if (n->matches(torch::schema(
2778 "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor"))) {
2779 return [](ProcessedNode* p_node) {
2780 const auto& self = p_node->Input(0).toTensor();
2781 if (p_node->Output(0).isNone()) {
2782 p_node->Output(0) =
2783 at::native::remainder(self, p_node->Input(1).toScalar());
2784 return;
2785 }
2786 auto& out = p_node->Output(0).toTensor();
2787 fastResizeToZero(out);
2788 at::native::remainder_out(self, p_node->Input(1).toScalar(), out);
2789 };
2790 }
2791
2792 // Unrecognized overload
2793 LogAndDumpSchema(n);
2794 return nullptr;
2795 });
2796
2797 REGISTER_OPERATOR_FUNCTOR(aten::where, aten_where, [](Node* n) -> SROperator {
2798 if (n->matches(torch::schema(
2799 "aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"))) {
2800 return [](ProcessedNode* p_node) {
2801 const auto& cond = p_node->Input(0).toTensor();
2802 const auto& self = p_node->Input(1).toTensor();
2803 const auto& other = p_node->Input(2).toTensor();
2804
2805 if (p_node->Output(0).isNone()) {
2806 p_node->Output(0) = create_empty_from(self);
2807 }
2808 auto& out = p_node->Output(0).toTensor();
2809 fastResizeToZero(out);
2810 at::native::where_self_out(cond, self, other, out);
2811 };
2812 }
2813
2814 LogAndDumpSchema(n);
2815 return nullptr;
2816 });
2817
2818 REGISTER_OPERATOR_FUNCTOR(
2819 prim::NumToTensor,
2820 prim_NumToTensor,
2821 [](Node* n) -> SROperator {
2822 if (n->matches(
2823 torch::schema("prim::NumToTensor.Scalar(Scalar s) -> Tensor")) ||
2824 n->matches(
2825 torch::schema("prim::NumToTensor.bool(bool a) -> Tensor"))) {
2826 return [](ProcessedNode* pnode) {
2827 const auto scalar = pnode->Input(0).toScalar();
2828 if (pnode->Output(0).isNone()) {
2829 pnode->Output(0) = at::scalar_to_tensor(scalar);
2830 return;
2831 }
2832 auto& out = pnode->Output(0).toTensor();
2833 at::detail::scalar_fill(out, scalar);
2834 };
2835 }
2836 LogAndDumpSchema(n);
2837 return nullptr;
2838 });
2839
2840 REGISTER_OPERATOR_FUNCTOR(
2841 quantized::embedding_bag_byte_unpack,
2842 quantized_embedding_bag_byte_unpack,
2843 [](Node* n) -> SROperator {
2844 if (!sr_schema_check(
2845 n,
2846 "quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor")) {
2847 return nullptr;
2848 }
2849 return [](ProcessedNode* pnode) {
2850 auto& weight = pnode->Input(0).toTensor();
2851 if (pnode->Output(0).isNone()) {
2852 pnode->Output(0) = at::empty(
2853 {},
2854 weight.options().dtype(at::kFloat),
2855 weight.suggest_memory_format());
2856 }
2857 auto& out = pnode->Output(0).toTensor();
2858 at::native::qembeddingbag_byte_unpack_out(out, weight);
2859 };
2860 });
2861
2862 } // namespace torch::jit
2863