#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>

#include <ATen/core/ivalue.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/jit/frontend/tracer.h>

#include <algorithm>
#include <bitset>
#include <cctype>
#include <cmath>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <ostream>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch::jit {

namespace {

RegisterOperators reg({
    Operator(
        prim::profile,
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            AT_ERROR(
                "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT
          };
        },
        aliasAnalysisSpecialCase()),
    Operator(
        prim::profile_ivalue,
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            AT_ERROR(
                "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT
          };
        },
        aliasAnalysisSpecialCase()),
    Operator(
        prim::FusionGroup,
        [](const Node* node) -> Operation {
          const auto key = registerFusion(node);
          return [key](Stack& stack) {
            RECORD_FUNCTION("FusionGroup", std::vector<c10::IValue>());
            runFusion(key, stack);
          };
        },
        aliasAnalysisSpecialCase()),
    Operator(
        prim::RequiresGradCheck /* (...)  -> (..., bool) */,
        [](const Node* node) -> Operation {
          std::vector<bool> rg_props =
              fmap(node->tys(attr::types), [](const TypePtr& t) {
                // If a requires_grad property changes, we assume the tensor
                // does require gradients; this is set in
                // `guardDifferentiableGraph`.
                TORCH_INTERNAL_ASSERT(
                    t->castRaw<TensorType>()->requiresGrad().has_value());
                return *t->castRaw<TensorType>()->requiresGrad();
              });
          return [rg_props](Stack& stack) {
            auto num_inputs = rg_props.size();
            // Check every input's requires_grad against the profiled
            // (expected) value.
            for (const auto i : c10::irange(num_inputs)) {
              auto& input = peek(stack, i, num_inputs);
              const auto& t = input.toTensor();
              if (rg_props[i] != t.requires_grad()) {
                push(stack, false);
                return;
              }
            }

            push(stack, true);
          };
        },
        aliasAnalysisSpecialCase()),
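    // Illustrative sketch (not part of the registration): the check leaves all
    // N tensor inputs on the stack and appends a single bool. With
    // hypothetical profiled properties {true, false} for two inputs, it
    // behaves roughly like:
    //
    //   std::vector<bool> rg_props = {true, false};
    //   bool ok = true;
    //   for (size_t i = 0; i < rg_props.size(); ++i) {
    //     ok = ok &&
    //         (rg_props[i] ==
    //          peek(stack, i, rg_props.size()).toTensor().requires_grad());
    //   }
    //   push(stack, ok); // inputs stay on the stack for the guarded graph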
    Operator(
        prim::ConstantChunk,
        [](const Node* node) -> Operation {
          int64_t chunks = node->i(attr::chunks);
          int64_t dim = node->i(attr::dim);
          auto outputs_used = fmap(node->outputs(), [](const Value* v) {
            return !v->uses().empty();
          });
          return [=](Stack& stack) {
            RECORD_FUNCTION("chunk", last(stack, 1));

            at::Tensor t;
            pop(stack, t);
            auto result = at::chunk(t, chunks, dim);
            stack.insert(
                stack.end(),
                std::make_move_iterator(result.begin()),
                std::make_move_iterator(result.end()));
            // NB: Chunk can sometimes return a smaller number of outputs.
            int64_t num_results = result.size();
            if (num_results != chunks) {
              if (num_results > chunks) {
                TORCH_CHECK(
                    num_results == chunks,
                    "Expected chunk to return ",
                    chunks,
                    " outputs, but got ",
                    num_results);
              }
              for (const auto i : c10::irange(num_results, chunks)) {
                TORCH_CHECK(
                    !outputs_used[i],
                    "Expected chunk to return at least ",
                    chunks,
                    " outputs, but got only ",
                    num_results);
                // We know that the output is unused, so it's ok to push
                // anything on the stack.
                stack.emplace_back();
              }
            }
          };
        },
        aliasAnalysisSpecialCase()),
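    // Worked example (illustrative only): chunking a tensor of size [2] into
    // chunks=3 along dim=0 makes at::chunk return just 2 tensors, so the loop
    // above pads the stack with one empty IValue for the unused third output;
    // the TORCH_CHECK only fires if that missing output is actually used.
    //
    //   auto parts = at::chunk(torch::arange(2), /*chunks=*/3, /*dim=*/0);
    //   // parts.size() == 2, while the ConstantChunk node has 3 outputs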
    Operator(
        prim::ChunkSizes,
        [](const Node* node) -> Operation {
          int64_t raw_dim = node->i(attr::dim);
          int64_t chunks = node->i(attr::chunks);
          return [raw_dim, chunks](Stack& stack) {
            c10::List<int64_t> shape = pop(stack).toIntList();
            c10::List<int64_t> regular_shape = shape.copy();
            c10::List<int64_t> last_shape = shape.copy();
            int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
            TORCH_CHECK(
                dim < (int64_t)regular_shape.size(),
                "Dimension out of range for chunk");
            int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
            regular_shape[dim] = split_size;
            if (shape[dim] % chunks == 0) {
              last_shape[dim] = split_size;
            } else {
              int64_t num_splits = std::max<int64_t>(
                  (shape[dim] + split_size - 1) / split_size, 1);
              last_shape[dim] =
                  split_size - (split_size * num_splits - shape[dim]);
              AT_ASSERT(last_shape[dim] >= 0);
            }
            push(stack, std::move(regular_shape));
            push(stack, std::move(last_shape));
          };
        },
        aliasAnalysisSpecialCase()),
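    // Worked example (illustrative only): for shape = [10], chunks = 3,
    // dim = 0: split_size = (10 + 3 - 1) / 3 = 4, so regular_shape = [4];
    // since 10 % 3 != 0, num_splits = (10 + 4 - 1) / 4 = 3 and
    // last_shape = [4 - (4 * 3 - 10)] = [2], matching at::chunk's 4, 4, 2
    // split of a length-10 dimension.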
    Operator(
        "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
        [](Stack& stack) {
          RECORD_FUNCTION("_grad_sum_to_size", std::vector<c10::IValue>());
          IValue self, size;
          pop(stack, self, size);
          if (size.isNone()) {
            push(stack, std::move(self));
          } else {
            push(stack, at::sum_to(self.toTensor(), size.toDimVector()));
          }
        },
        aliasAnalysisFromSchema()),
    // This operator is generated inside the compiler for indexing into
    // ModuleDict without a statically determinable key. Accordingly,
    // self must be a ModuleType and the output must be an InterfaceType.
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "prim::ModuleContainerIndex.dict(Any self, str ind) -> Any"),
        [](Stack& stack) {
          IValue ind = pop(stack);
          IValue module_dict = pop(stack);
          push(stack, module_dict.toModule().attr(ind.toStringRef()));
        },
        aliasAnalysisFromSchema()),
    Operator(
        prim::TypeCheck /* (...)  -> (..., bool) */,
        [](const Node* /* node */) -> Operation {
          return [](Stack& /* stack */) {
            AT_ERROR("prim::TypeCheck not yet implemented"); // NOLINT
          };
        },
        aliasAnalysisSpecialCase()),
    Operator(
        prim::FallbackGraph,
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            AT_ERROR(
                "Must be converted to prim::FunctionCall by replaceFallbackGraphWithFallbackFunction"); // NOLINT
          };
        },
        aliasAnalysisSpecialCase()),
    Operator(
        "prim::Guard(Tensor(a) t) -> Tensor(a)",
        [](Stack& stack) { AT_ERROR("Should be replaced by prim::BailOut"); },
        aliasAnalysisFromSchema()),
    Operator(
        "prim::BailOut(...) -> Tensor(a)",
        [](Stack& /* stack */) {
          AT_ERROR("prim::BailOut not yet implemented"); // NOLINT
        },
        aliasAnalysisFromSchema()),
    Operator(
        "prim::BailoutTemplate() -> int",
        [](Stack& stack) {
          // TODO: today, we put a single bailout template at the front to
          // carry the un-optimized graph for bailout nodes to use. Ideally
          // this should never run, but we haven't written the code to remove
          // it yet.
          // TORCH_INTERNAL_ASSERT(false);

          // Returns an int so that we have an easy way to do graph traversal
          push(stack, 1);
        },
        aliasAnalysisFromSchema()),
    Operator(
        "aten::grad(Tensor[] outputs, Tensor[] inputs, Tensor?[]? grad_outputs=None, bool? retain_graph=None, bool create_graph=False, bool allow_unused=False) -> Tensor?[]",
        [](Stack& stack) {
          bool allow_unused = pop(stack).toBool();
          bool create_graph = pop(stack).toBool();
          auto retain_graph = pop(stack).toOptional<bool>();
          auto grad_outputs = pop(stack);
          auto inputs = pop(stack).toTensorList();
          auto outputs = pop(stack).toTensorList();
          std::vector<torch::autograd::Variable> input_vars(
              inputs.begin(), inputs.end());
          std::vector<torch::autograd::Variable> output_vars(
              outputs.begin(), outputs.end());
          std::vector<torch::autograd::Variable> gradients;

          if (!grad_outputs.isNone()) {
            for (const IValue& v : grad_outputs.toListRef()) {
              gradients.emplace_back(v.isNone() ? at::Tensor() : v.toTensor());
            }
          }

          auto res = torch::autograd::grad(
              output_vars,
              input_vars,
              gradients,
              retain_graph,
              create_graph,
              allow_unused);

          c10::impl::GenericList res_list{OptionalType::ofTensor()};
          for (const at::Tensor& t : res) {
            res_list.emplace_back(t.defined() ? t : IValue());
          }
          push(stack, res_list);
        },
        aliasAnalysisFromSchema()),
    // NB: the backward op might write to every input tensor in the graph, and
    // it's much more expensive to analyze the leaves; with create_graph=True
    // it can also retain the whole gradient in every tensor of the autograd
    // graph, so we use aliasAnalysisConservative for these two ops.
    Operator(
        "aten::backward.TensorList(Tensor[] tensors, Tensor?[]? grad_tensors=None, bool? retain_graph=None, bool create_graph=False) -> ()",
        [](Stack& stack) {
          bool create_graph = pop(stack).toBool();
          auto retain_graph = pop(stack).toOptional<bool>();
          auto grad_tensors = pop(stack);
          auto outputs = pop(stack).toTensorList();
          std::vector<torch::autograd::Variable> output_vars(
              outputs.begin(), outputs.end());
          std::vector<torch::autograd::Variable> gradients;

          if (!grad_tensors.isNone()) {
            for (const IValue& v : grad_tensors.toListRef()) {
              gradients.emplace_back(v.isNone() ? at::Tensor() : v.toTensor());
            }
          }

          torch::autograd::backward(
              output_vars, gradients, retain_graph, create_graph);
        },
        aliasAnalysisConservative()),
    Operator(
        "aten::save(t item, str filename) -> ()",
        [](Stack& stack) {
          auto filename = pop(stack).toStringRef();
          auto ivalue = pop(stack);

          // Pickle the tensor
          auto data = jit::pickle_save(ivalue);

          // Write file
          std::fstream output(filename, std::ios::out | std::ios::binary);
          output.write(data.data(), data.size());
        },
        aliasAnalysisFromSchema()),
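    // Usage sketch (assumption, not asserted by this file): in TorchScript,
    // torch.save(tensor, "out.pt") is expected to route through this op; the
    // written file is a pickle-based archive that can typically be read back
    // in eager Python with torch.load("out.pt").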
    Operator(
        "prim::IgnoredPythonOp(...) -> None",
        [](Stack& stack) {
          throw JITException(
              "This Python function is annotated to be ignored"
              " and cannot be and has not been included in the exported"
              " binary, meaning that it cannot be executed now."
              " Make sure that ignored operations are never executed after"
              " import");
        },
        aliasAnalysisFromSchema()),
    Operator(
        "aten::wait(Future(t) self) -> t",
        [](Stack& stack) {
          TORCH_CHECK(false, "wait is implemented directly in the interpreter");
        },
        aliasAnalysisSpecialCase()),
    Operator(
        "prim::awaitable_wait(Await(t) self) -> t",
        [](Stack& stack) {
          auto aw = stack.back().toAwait();
          aw->wait();
          stack.pop_back();
          stack.emplace_back(aw->value());
        },
        aliasAnalysisSpecialCase()),
    Operator(
        "prim::awaitable_nowait(t self) -> Await(t)",
        [](Stack& stack) {
          auto aw =
              c10::make_intrusive<c10::ivalue::Await>(stack.back().type());
          aw->markCompleted(pop(stack));
          push(stack, std::move(aw));
        },
        aliasAnalysisSpecialCase()),
});

RegisterOperators logging_operators(
    {Operator(
         "prim::AddStatValue(str key, int val) -> ()",
         [](Stack& stack) {
           auto val = pop(stack).toInt();
           auto key = pop(stack).toString();

           auto schema =
               parseSchema("prim::AddStatValue(str key, int val) -> ()");
           // TODO: remove this custom tracing code once the custom op bugfix
           // lands
           if (jit::tracer::isTracing()) {
             const auto& graph = tracer::getTracingState()->graph;
             Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
             tracer::recordSourceLocation(node);
             node->addInput(insertConstant(*graph, key));
             tracer::addInputs(node, "val", val);
             graph->insertNode(node);
           }
           torch::jit::logging::getLogger()->addStatValue(*key, val);
         },
         aliasAnalysisFromSchema()),
     Operator(
         "prim::TimePoint() -> int",
         [](Stack& stack) {
           auto schema = parseSchema("prim::TimePoint() -> int");
           Node* node = nullptr;
           // TODO: remove this custom tracing code once the custom op bugfix
           // lands
           if (jit::tracer::isTracing()) {
             const auto& graph = tracer::getTracingState()->graph;
             // Assign to the outer `node` so the traced node is visible when
             // recording the output below.
             node = graph->create(prim::TimePoint, /*num_outputs=*/0);
             tracer::recordSourceLocation(node);
             graph->insertNode(node);
           }
           auto output = c10::getTime(/*allow_monotonic=*/true);
           push(stack, output);
           if (jit::tracer::isTracing()) {
             jit::tracer::addOutput(node, output);
           }
         },
         aliasAnalysisFromSchema())});

C10_UNUSED void hashValue(Stack& stack) {
  auto value = pop(stack);
  push(stack, value.hash());
}

// reference: _output_size in torch/nn/functional.py
// size can be none, int or intlist
// scale_factors can be none, float, or floatlist
std::vector<int64_t> _output_size(
    const at::Tensor& input,
    size_t dim,
    const IValue& size,
    const IValue& scale_factors) {
  if (!size.isNone()) {
    if (size.isInt()) {
      std::vector<int64_t> repeated(dim, size.toInt());
      return repeated;
    } else {
      return size.toIntVector();
    }
  }
  std::vector<double> scale_repeated;
  if (scale_factors.isDouble()) {
    scale_repeated = std::vector<double>(dim, scale_factors.toDouble());
  } else {
    scale_repeated = scale_factors.toDoubleVector();
  }
  std::vector<int64_t> ret;
  for (const auto i : c10::irange(dim)) {
    ret.push_back(std::floor(input.size(i + 2) * scale_repeated[i]));
  }
  return ret;
}
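
// Worked example (illustrative only): for a 4-D NCHW input of size
// [1, 3, 8, 8] with dim = 2, size = None and scale_factors = 1.5, the loop
// reads the spatial sizes starting at index 2 and returns
// {floor(8 * 1.5), floor(8 * 1.5)} = {12, 12}.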

// return true if v is a real float
// and false if it is an integer
bool _is_floating_value(double v) {
  return std::floor(v) != v;
}

// reference: interpolate in torch/nn/functional.py
// size can be none, int or intlist
// scale_factors can be none, float, or floatlist
at::Tensor interpolate(
    const at::Tensor& input,
    const IValue& size,
    const IValue& scale_factors,
    const std::string& mode,
    std::optional<bool> align_corners,
    std::optional<bool> recompute_scale_factor) {
  if (mode == "nearest" || mode == "area") {
    if (align_corners != std::nullopt) {
      throw std::runtime_error(
          "align_corners option can only be set with the "
          "interpolating modes: linear | bilinear | bicubic | trilinear");
    }
  } else {
    if (align_corners == std::nullopt) {
      TORCH_WARN(
          "Default upsampling behavior when mode=",
          mode,
          " is changed "
          "to align_corners=False since 0.4.0. Please specify align_corners=True "
          "if the old behavior is desired. See the documentation of nn.Upsample for details");
      align_corners = false;
    }
  }

  double scale_factors_1 = -1.0;
  double scale_factors_2 = -1.0;
  double scale_factors_3 = -1.0;

  if (!scale_factors.isNone() && recompute_scale_factor == std::nullopt) {
    recompute_scale_factor = true;
    bool warn_recompute_scale_factor = false;

    if (scale_factors.isDouble()) {
      // only warn when the scales have floating values since
      // the result for ints is the same with/without recompute_scale_factor
      if (_is_floating_value(scale_factors.toDouble())) {
        warn_recompute_scale_factor = true;
      }
    } else if (scale_factors.isDoubleList()) {
      auto scale_factors_list = scale_factors.toDoubleList();

      for (const auto& scales : scale_factors_list) {
        // only warn when the scales have floating values since
        // the result for ints is the same with/without recompute_scale_factor
        if (_is_floating_value(scales)) {
          warn_recompute_scale_factor = true;
          break;
        }
      }
    }

    if (warn_recompute_scale_factor) {
      TORCH_WARN(
          "The default behavior for interpolate/upsample with float scale_factor will change "
          "in 1.5.0 to align with other frameworks/libraries, and use scale_factor directly, "
          "instead of relying on the computed output size. "
          "If you wish to keep the old behavior, please set recompute_scale_factor=True. "
          "See the documentation of nn.Upsample for details.");
    }
  }

  if (recompute_scale_factor == false) {
    if (scale_factors.isDouble()) {
      scale_factors_1 = scale_factors.toDouble();
      scale_factors_2 = scale_factors.toDouble();
      scale_factors_3 = scale_factors.toDouble();
    } else if (scale_factors.isDoubleList()) {
      auto scale_factors_list = scale_factors.toDoubleList();
      scale_factors_1 = scale_factors_list[0];
      if (scale_factors_list.size() >= 2) {
        scale_factors_2 = scale_factors_list[1];
        if (scale_factors_list.size() >= 3) {
          scale_factors_3 = scale_factors_list[2];
        }
      }
    }
  }

  const auto dim1d = 3;
  const auto dim2d = 4;
  const auto dim3d = 5;

  auto input_dim = input.dim();
  if (input_dim == dim1d && mode == "nearest")
    return at::upsample_nearest1d(
        input,
        _output_size(input, 1, size, scale_factors),
        std::make_optional(scale_factors_1));
  if (input_dim == dim2d && mode == "nearest")
    return at::upsample_nearest2d(
        input,
        _output_size(input, 2, size, scale_factors),
        scale_factors_1,
        scale_factors_2);
  if (input_dim == dim3d && mode == "nearest")
    return at::upsample_nearest3d(
        input,
        _output_size(input, 3, size, scale_factors),
        scale_factors_1,
        scale_factors_2,
        scale_factors_3);
  if (input_dim == dim1d && mode == "area")
    return at::adaptive_avg_pool1d(
        input, _output_size(input, 1, size, scale_factors));
  if (input_dim == dim2d && mode == "area")
    return at::adaptive_avg_pool2d(
        input, _output_size(input, 2, size, scale_factors));
  if (input_dim == dim3d && mode == "area")
    return at::adaptive_avg_pool3d(
        input, _output_size(input, 3, size, scale_factors));
  if (input_dim == dim1d && mode == "linear")
    return at::upsample_linear1d(
        input,
        _output_size(input, 1, size, scale_factors),
        *align_corners,
        std::make_optional(scale_factors_1));
  if (input_dim == dim1d && mode == "bilinear")
    throw std::runtime_error("Got 3D input, but bilinear mode needs 4D input");
  if (input_dim == dim1d && mode == "bicubic")
    throw std::runtime_error("Got 3D input, but bicubic mode needs 4D input");
  if (input_dim == dim1d && mode == "trilinear")
    throw std::runtime_error("Got 3D input, but trilinear mode needs 5D input");
  if (input_dim == dim2d && mode == "linear")
    throw std::runtime_error("Got 4D input, but linear mode needs 3D input");
  if (input_dim == dim2d && mode == "bilinear")
    return at::upsample_bilinear2d(
        input,
        _output_size(input, 2, size, scale_factors),
        *align_corners,
        scale_factors_1,
        scale_factors_2);
  if (input_dim == dim2d && mode == "bicubic")
    return at::upsample_bicubic2d(
        input,
        _output_size(input, 2, size, scale_factors),
        *align_corners,
        scale_factors_1,
        scale_factors_2);
  if (input_dim == dim2d && mode == "trilinear")
    throw std::runtime_error("Got 4D input, but trilinear mode needs 5D input");
  if (input_dim == dim3d && mode == "linear")
    throw std::runtime_error("Got 5D input, but linear mode needs 3D input");
  if (input_dim == dim3d && mode == "bilinear")
    throw std::runtime_error("Got 5D input, but bilinear mode needs 4D input");
  if (input_dim == dim3d && mode == "bicubic")
    throw std::runtime_error("Got 5D input, but bicubic mode needs 4D input");
  if (input_dim == dim3d && mode == "trilinear")
    return at::upsample_trilinear3d(
        input,
        _output_size(input, 3, size, scale_factors),
        *align_corners,
        scale_factors_1,
        scale_factors_2,
        scale_factors_3);

  AT_ERROR(
      "Input Error: Only 3D, 4D and 5D input Tensors supported",
      " (got ",
      input_dim,
      "D) for the modes: nearest | linear | bilinear | trilinear",
      " (got ",
      mode,
      ") ");
}
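
// Dispatch example (illustrative only): a 4-D NCHW input with mode ==
// "bilinear" and align_corners = false takes the dim2d branch above, i.e.
// roughly:
//
//   at::upsample_bilinear2d(
//       input,
//       _output_size(input, 2, size, scale_factors),
//       /*align_corners=*/false,
//       scale_factors_1,
//       scale_factors_2);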

void interpolate_op(Stack& stack) {
  at::Tensor input;
  IValue size;
  IValue scale_factors;
  std::string mode;
  IValue align_corners;
  IValue recompute_scale_factor;
  bool antialias = false;
  pop(stack,
      input,
      size,
      scale_factors,
      mode,
      align_corners,
      recompute_scale_factor,
      antialias);
  if (antialias) {
    throw std::runtime_error("Antialias is not yet supported");
  }
  at::Tensor res = interpolate(
      input,
      size,
      scale_factors,
      mode,
      align_corners.toOptional<bool>(),
      recompute_scale_factor.toOptional<bool>());
  push(stack, std::move(res));
}

// interpolate takes in float & float[] for scale factor
// upsample takes in int & int[], so convert the ints to floats before
// passing on to the interpolate op
IValue convert_scale_factor_to_double(const IValue& int_ivalue) {
  IValue scale_factor_double;
  if (int_ivalue.isInt()) {
    scale_factor_double = static_cast<double>(int_ivalue.toInt());
  } else if (int_ivalue.isIntList()) {
    auto int_list = int_ivalue.toDimVector();
    std::vector<double> double_vec(int_list.begin(), int_list.end());
    scale_factor_double = double_vec;
  } else if (int_ivalue.isNone()) {
    return IValue();
  } else {
    std::stringstream ss;
    ss << "Expecting optional int or int list arg for scale factor, got "
       << int_ivalue;
    throw std::runtime_error(ss.str());
  }
  return scale_factor_double;
}
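
// Example (illustrative only): an int scale factor of 2 becomes 2.0, an int
// list [2, 3] becomes the double list [2.0, 3.0], and None passes through
// unchanged.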

void upsample_nearest_op(Stack& stack) {
  at::Tensor input;
  IValue size;
  IValue scale_factor_int;
  pop(stack, input, size, scale_factor_int);
  IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
  at::Tensor res = interpolate(
      input, size, scale_factor_double, "nearest", std::nullopt, std::nullopt);
  push(stack, std::move(res));
}

void upsample_op(Stack& stack) {
  at::Tensor input;
  IValue size;
  IValue scale_factor_int;
  std::string mode;
  IValue align_corners;
  pop(stack, input, size, scale_factor_int, mode, align_corners);
  IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
  at::Tensor res = interpolate(
      input,
      size,
      scale_factor_double,
      mode,
      align_corners.toOptional<bool>(),
      std::nullopt);
  push(stack, std::move(res));
}

void upsample_bilinear_op(Stack& stack) {
  at::Tensor input;
  IValue size;
  IValue scale_factor_int;
  pop(stack, input, size, scale_factor_int);
  IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
  at::Tensor res = interpolate(
      input, size, scale_factor_double, "bilinear", true, std::nullopt);
  push(stack, std::move(res));
}

// These ops are no longer generated, but remain here for BC
RegisterOperators reg3({
    Operator(
        "aten::__interpolate.scale_list(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None, bool antialias = False) -> Tensor",
        interpolate_op,
        aliasAnalysisFromSchema()),
    Operator(
        "aten::__interpolate.size_list_scale_list(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None, bool antialias = False) -> Tensor",
        interpolate_op,
        aliasAnalysisFromSchema()),
    Operator(
        "aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None, bool antialias = False) -> Tensor",
        interpolate_op,
        aliasAnalysisFromSchema()),
    Operator(
        "aten::__interpolate.size_list(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None, bool antialias = False) -> Tensor",
        interpolate_op,
        aliasAnalysisFromSchema()),

    Operator(
        "aten::__upsample_nearest(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
        upsample_nearest_op,
        aliasAnalysisFromSchema()),
    Operator(
        "aten::__upsample_nearest.size_list(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
        upsample_nearest_op,
        aliasAnalysisFromSchema()),

    Operator(
        "aten::__upsample(Tensor input, int? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        upsample_op,
        aliasAnalysisFromSchema()),
    Operator(
        "aten::__upsample.size_list(Tensor input, int[]? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        upsample_op,
        aliasAnalysisFromSchema()),

    Operator(
        "aten::__upsample_bilinear(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
        upsample_bilinear_op,
        aliasAnalysisFromSchema()),
    Operator(
        "aten::__upsample_bilinear.size_list(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
        upsample_bilinear_op,
        aliasAnalysisFromSchema()),
    Operator(
        "aten::__upsample_bilinear.scale_list(Tensor input, int? size = None, int[]? scale_factor = None) -> Tensor",
        upsample_bilinear_op,
        aliasAnalysisFromSchema()),
    Operator(
        "aten::__upsample_bilinear.size_list_scale_list(Tensor input, int[]? size = None, int[]? scale_factor = None) -> Tensor",
        upsample_bilinear_op,
        aliasAnalysisFromSchema()),

});

at::Tensor leaky_relu(const at::Tensor& tensor, double scalar) {
  return at::leaky_relu(tensor, scalar);
}
at::Tensor cat(const c10::List<at::Tensor>& tensors) {
  return at::cat(tensors.vec());
}

std::string get_first(const c10::List<c10::List<std::string>>& strings) {
  return strings.get(0).get(0);
}

static auto reg4 =
    torch::RegisterOperators()
        .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor",
            &leaky_relu)
        .op("_test::cat(Tensor[] inputs) -> Tensor", &cat)
        .op("_test::get_first", &get_first);

} // namespace
} // namespace torch::jit