#include <ATen/core/dispatch/OperatorOptions.h>
#include <c10/core/ScalarType.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <stdexcept>

#include "deep_wide_pt.h"
#include "test_utils.h"

using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;
using c10::IValue;

/*
 When adding a test for an operator implemented in static runtime, there are
 several things that you need to pay attention to:

 1) If the op is an out variant, then in the test script of the op,
 instead of:
    def forward(self, input):
      return myop(input)

  do:
    def forward(self, input):
      return myop(input).clone()

 This makes sure that the output of myop is managed by the memory planner and
 exercises the code path in the op impl that otherwise doesn't get exercised.
 The output of the model is not managed by the memory planner, because it
 needs to be returned to the client.

 2) The memory planner rounds the size of each Tensor's storage up to a
 multiple of 64 bytes (the alignment requirement on AVX512). Make sure the
 sizes of the input tensors in args2 are big enough to trigger resizing.

 3) For view ops such as aten::reshape or aten::to, if you want them to be
 replaced by their copy versions via the ReplaceWithCopy pass in passes.h,
 also make sure their outputs are not returned as model outputs. The reason
 is that ReplaceWithCopy only replaces ops whose outputs are not aliases of
 the model output.
*/

C10_DECLARE_bool(static_runtime_enable_fast_math);

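// A minimal sketch applying the conventions above; the op and shapes here are
// illustrative only. The .clone() keeps the out variant's output managed by
// the memory planner (rule 1), and args2 is large enough to force the planner
// to resize its managed buffer on the second run (rule 2).
TEST(StaticRuntime, ConventionsExample) {
  const auto example_script = R"JIT(
    def forward(self, input):
        return torch.relu(input).clone()
  )JIT";

  std::vector<IValue> args{at::randn({2, 2})};
  std::vector<IValue> args2{at::randn({8, 16})};

  testStaticRuntime(example_script, args);
  testStaticRuntime(example_script, args, args2);
}
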
TEST(StaticRuntime, UnaryOps) {
  const auto aten_sum = R"JIT(
    def forward(self, input):
        return torch.sum(input).clone()
  )JIT";

  const auto aten_sum_0 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0).clone()
  )JIT";

  const auto aten_sum_1 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1).clone()
  )JIT";

  const auto aten_sum_0_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0, True).clone()
  )JIT";

  const auto aten_sum_1_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1, True).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 3, 6});

  std::vector<IValue> args{a}, args2{b};

  // sum
  testStaticRuntime(aten_sum, args);
  testStaticRuntime(aten_sum_0, args);
  testStaticRuntime(aten_sum_1, args);
  testStaticRuntime(aten_sum_0_true, args);
  testStaticRuntime(aten_sum_1_true, args);

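  // The trailing booleans below are use_allclose, use_equalnan, check_resize
  // (testStaticRuntime's defaults live in test_utils.h). check_resize is
  // disabled for the full reduction since its 0-dim output never grows
  // between args and args2.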
  testStaticRuntime(aten_sum, args, args2, false, false, false);
  testStaticRuntime(aten_sum_0, args, args2);
  testStaticRuntime(aten_sum_1, args, args2);
  testStaticRuntime(aten_sum_0_true, args, args2);
  testStaticRuntime(aten_sum_1_true, args, args2);
}

TEST(StaticRuntime, Max) {
  auto src_max_reduce = R"JIT(
    def forward(self, input):
        return torch.max(input).clone()
  )JIT";

  auto src_max_dim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim)
        return values.clone(), indices.clone()
  )JIT";

  auto src_max_dim_keepdim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim, keepdim=True)
        return values.clone(), indices.clone()
  )JIT";

  auto src_max_pointwise = R"JIT(
    def forward(self, input, other):
        return torch.max(input, other).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto input_other = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 9, 10});
  auto large_input_other = at::randn({8, 9, 10});

  testStaticRuntime(src_max_reduce, {input});
  testStaticRuntime(src_max_dim, {input, 1});
  testStaticRuntime(src_max_dim, {input, 1}, {large_input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0}, {large_input, 2});
  testStaticRuntime(src_max_pointwise, {input, input_other});
  testStaticRuntime(src_max_pointwise, {input, input_other}, {large_input, large_input_other});
}

TEST(StaticRuntime, Mean) {
  const auto src_default = R"JIT(
    def forward(self, input):
        return torch.mean(input).clone()
  )JIT";
  const auto src_dtype = R"JIT(
    def forward(self, input, dtype: int):
        return torch.mean(input, dtype=dtype).clone()
  )JIT";
  const auto src_dim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim).clone()
  )JIT";
  const auto src_dim_keepdim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim, keepdim=True).clone()
  )JIT";
  const auto src_dim_dtype = R"JIT(
    def forward(self, input, dim: List[int], dtype: int):
        return torch.mean(input, dim, dtype=dtype).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 7, 6, 8});

  std::vector<IValue> args_default = {input};
  std::vector<IValue> args_dtype = {input, torch::kFloat};
  std::vector<IValue> args_dim = {input, c10::List<int64_t>{0, 2}};
  std::vector<IValue> args_dim_keepdim = {input, c10::List<int64_t>{1, 2}};
  std::vector<IValue> args_dim_dtype = {input, c10::List<int64_t>{0, 1}, torch::kBFloat16};

  testStaticRuntime(src_default, args_default);
  testStaticRuntime(src_dtype, args_dtype);
  testStaticRuntime(src_dim, args_dim);
  testStaticRuntime(src_dim_keepdim, args_dim_keepdim);
  testStaticRuntime(src_dim_dtype, args_dim_dtype);

  std::vector<IValue> large_args_dim = {large_input, c10::List<int64_t>{0, 3}};
  std::vector<IValue> large_args_dim_keepdim = {large_input, c10::List<int64_t>{1, 2}};
  std::vector<IValue> large_args_dim_dtype = {large_input, c10::List<int64_t>{1, 3}, torch::kBFloat16};

  testStaticRuntime(src_dim, args_dim, large_args_dim);
  testStaticRuntime(src_dim_keepdim, args_dim_keepdim, large_args_dim_keepdim);
  testStaticRuntime(src_dim_dtype, args_dim_dtype, large_args_dim_dtype);
}

TEST(StaticRuntime, Sigmoid) {
  const auto sigmoid_script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        return (b)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  std::vector<IValue> args{a}, args2{b};

  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, {args2}, /*use_allclose=*/true);

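  // Re-run with fast math disabled so both the approximate and the precise
  // sigmoid implementations are exercised.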
  FLAGS_static_runtime_enable_fast_math = false;
  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, {args2}, /*use_allclose=*/true);
  FLAGS_static_runtime_enable_fast_math = true;
}

TEST(StaticRuntime, Clone) {
  /*
  Clone is called twice to trigger the memory planner for the output of the
  first clone. The output of the last op (the second clone) is not managed by
  the memory planner, since it needs to be returned to the client and cannot
  be reused by the planner.
  */
  const auto clone_script_0 = R"JIT(
    def forward(self, input):
        a = torch.clone(input).clone()
        return (a * a)
  )JIT";

  // Case: clone with different memory_formats
  const auto clone_script_1 = R"JIT(
    def forward(self, input: Tensor, memory_format: int):
        a = torch.clone(input, memory_format=memory_format).clone()
        return (a * a)
  )JIT";

  /*
  Case: input with stride 0 (due to the expand op);
  calls the native clone instead of the out variant
  */
  const auto clone_script_2 = R"JIT(
    def forward(self, input: Tensor, other: Tensor):
        a = input.expand_as(other)
        return a.clone().clone()
  )JIT";

  /*
  Case: sliced input tensor, to exercise
  non-contiguous tensor storage
  */
  const auto clone_script_3 = R"JIT(
    def forward(self, input: Tensor):
        a = input[:, 0:10:2]
        return a.clone().clone()
  )JIT";

  auto a = at::randn({2, 3});
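  // b and b_larger use swapped (column-major) strides to exercise clone on
  // non-contiguous inputs.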
  auto b = at::randn({3, 2}).as_strided({3, 2}, {1, 3});
  auto b_larger = at::randn({30, 20}).as_strided({30, 20}, {1, 3});
  auto c = at::randn({1, 20, 13, 8});
  auto d = at::randn({1, 0, 3, 4});
  auto e = at::randn({2, 1});
  auto f = at::randn({2, 10});
  auto g = at::randn({3, 20});
  std::vector<IValue> args_0{b, c10::MemoryFormat::Contiguous};
  std::vector<IValue> args_1{b_larger, c10::MemoryFormat::Preserve};
  std::vector<IValue> args_2{c, c10::MemoryFormat::ChannelsLast};
  std::vector<IValue> args_3{d, c10::MemoryFormat::ChannelsLast};
  std::vector<IValue> args_4{e, a};
  std::vector<IValue> args_5{e, f};

  testStaticRuntime(clone_script_0, {a});
  testStaticRuntime(clone_script_0, {a}, {b_larger});

  testStaticRuntime(clone_script_1, args_0);
  testStaticRuntime(clone_script_1, args_1);
  testStaticRuntime(clone_script_1, args_2);
  testStaticRuntime(clone_script_1, args_3);
  testStaticRuntime(clone_script_1, args_0, args_1);
  testStaticRuntime(clone_script_1, args_3, args_2);

  testStaticRuntime(clone_script_2, args_4);
  testStaticRuntime(clone_script_2, args_4, args_5);

  testStaticRuntime(clone_script_3, {f});
  testStaticRuntime(clone_script_3, {f}, {g});
}

TEST(StaticRuntime, Clamp) {
  const auto clamp_script_1 = R"JIT(
    def forward(self, inp: Tensor, min: int, max: int):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";

  const auto clamp_script_2 = R"JIT(
    def forward(self, inp: Tensor, min: Tensor, max: Tensor):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto max_t = at::full_like(a, 1);
  auto min_t = at::full_like(a, -1);

  auto b = at::randn({4, 3, 2});
  auto max_t1 = at::full_like(b, 1);
  auto min_t1 = at::full_like(b, -1);

  testStaticRuntime(clamp_script_1, {a, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t});

  testStaticRuntime(clamp_script_1, {a, -1, 1}, {b, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, min_t1, max_t1});
}

TEST(StaticRuntime, ClampMinOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float):
        a = torch.clamp(inp, min, None).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});
  testStaticRuntime(src, {a, 0.5});
  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}

TEST(StaticRuntime, ClampMaxOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, max: float):
        a = torch.clamp(inp, None, max).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});
  testStaticRuntime(src, {a, 0.5});
  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}

TEST(StaticRuntime, ClampIntTensor) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float, max: float):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  auto a = at::randint(0, 20, {2, 3}, at::kFloat);
  auto b = at::randint(0, 20, {4, 3, 2}, at::kFloat);
  auto min = 5.0f;
  auto max = 5.0f;
  testStaticRuntime(src, {a, min, max});
  testStaticRuntime(src, {a, min, max}, {b, min, max});
}

TEST(StaticRuntime, LenWithTuple) {
  const auto src = R"IR(
    graph(%input : int[]):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {c10::List<int64_t>(4)});
}

TEST(StaticRuntime, LenWithTensor) {
  const auto src = R"IR(
    graph(%input : Tensor):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {at::randn({2, 2, 2})});
}

TEST(StaticRuntime, LenWithStr) {
  const auto src = R"IR(
    graph(%input : str):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {"static_runtime"});
}

TEST(StaticRuntime, LenWithDict_str) {
  const auto script = R"JIT(
    def forward(self, input: Dict[str, str]):
        return len(input)
  )JIT";

  c10::Dict<std::string, std::string> dict;
  dict.insert("abc", "123");
  dict.insert("def", "456");
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_int) {
  const auto script = R"JIT(
    def forward(self, input: Dict[int, int]):
        return len(input)
  )JIT";

  c10::Dict<int64_t, int64_t> dict;
  dict.insert(0, 1);
  dict.insert(2, 3);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_bool) {
  const auto script = R"JIT(
    def forward(self, input: Dict[bool, bool]):
        return len(input)
  )JIT";

  c10::Dict<bool, bool> dict;
  dict.insert(true, false);
  dict.insert(false, true);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_float) {
  const auto script = R"JIT(
    def forward(self, input: Dict[float, float]):
        return len(input)
  )JIT";

  c10::Dict<double, double> dict;
  dict.insert(0.1, 0.9);
  dict.insert(0.8, 0.18);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_complex) {
  const auto script = R"JIT(
    def forward(self, input: Dict[complex, complex]):
        return len(input)
  )JIT";

  c10::Dict<c10::complex<double>, c10::complex<double>> dict;
  dict.insert(0.1, 0.4);
  dict.insert(0.9, 0.45);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_Tensor) {
  const auto script = R"JIT(
    def forward(self, input: Dict[Tensor, Tensor]):
        return len(input)
  )JIT";

  c10::Dict<at::Tensor, at::Tensor> dict;
  dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, Logit) {
  // no nnc
  const auto logit_script_1 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp).clone()
        return (a)
  )JIT";

  // with nnc
  const auto logit_script_2 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp, 1e-6).clone()
        return (a)
  )JIT";

  // no nnc
  const auto logit_script_3 = R"JIT(
    def forward(self, inp: Tensor, eps: float):
        a = torch.logit(inp, eps).clone()
        return (a)
  )JIT";
  auto a = at::ones({2, 3});
  double b = 1e-6;
  std::vector<IValue> args_1{a};
  std::vector<IValue> args_2({a, b});

  auto c = at::ones({4, 3, 2});

  // logit
  testStaticRuntime(logit_script_1, args_1);
  testStaticRuntime(logit_script_2, args_1);
  testStaticRuntime(logit_script_3, args_2);

  testStaticRuntime(logit_script_1, args_1, {c});
  testStaticRuntime(logit_script_2, args_1, {c});
  testStaticRuntime(logit_script_3, args_2, {c, b});
}

TEST(StaticRuntime, EmbeddingBag) {
  const std::string embedding_bag_default = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_mean = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_max = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 2)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_sum_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 0, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_mean_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_max_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 2, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

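  // aten::embedding_bag returns (output, offset2bag, bag_size, max_indices);
  // the scripts above clone all four, so none of the outputs can be pruned.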
  at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2, 4});
  std::vector<IValue> args{weight, input, offset};
  testStaticRuntime(embedding_bag_default, args);
  testStaticRuntime(embedding_bag_mean, args);
  testStaticRuntime(embedding_bag_max, args);
  testStaticRuntime(embedding_bag_sum_last_offset, args);
  testStaticRuntime(embedding_bag_mean_last_offset, args);
  testStaticRuntime(embedding_bag_max_last_offset, args);

  at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};
  testStaticRuntime(embedding_bag_default, args, args2);
  testStaticRuntime(embedding_bag_mean, args, args2);
  testStaticRuntime(embedding_bag_max, args, args2);
  testStaticRuntime(embedding_bag_sum_last_offset, args, args2);
  testStaticRuntime(embedding_bag_mean_last_offset, args, args2);
  testStaticRuntime(embedding_bag_max_last_offset, args, args2);
}

TEST(StaticRuntime, EmbeddingBagWithManagedOutput) {
  const std::string embedding_bag_managed_output = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        # The outputs of embedding_bag become intermediate tensors
        # since they are not directly returned from the graph.
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return x + x
  )JIT";

  at::Tensor weight = torch::randn({3, 8}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2});
  std::vector<IValue> args{weight, input, offset};

  at::Tensor weight2 = torch::randn({6, 8}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 3, 4});
  at::Tensor offset2 = torch::tensor({0, 2, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};

  testStaticRuntime(embedding_bag_managed_output, args);
  testStaticRuntime(embedding_bag_managed_output, args, args2);
}

TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) {
  const std::string embedding_bag_default_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  auto graph = getGraphFromIR(embedding_bag_default_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_mean_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=1]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  graph = getGraphFromIR(embedding_bag_mean_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_max_last_offset_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=2]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=1]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  graph = getGraphFromIR(embedding_bag_max_last_offset_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_normal_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res0 : Tensor = aten::clone(%y0, %none)
        %res1 : Tensor = aten::clone(%y1, %none)
        %res2 : Tensor = aten::clone(%y2, %none)
        %res3 : Tensor = aten::clone(%y3, %none)
        return (%res0, %res1, %res2, %res3)
  )IR";
  graph = getGraphFromIR(embedding_bag_normal_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check_not("static_runtime::embedding_bag")
      ->run(*graph);

  at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2, 4});
  std::vector<IValue> args{weight, input, offset};
  testStaticRuntime(embedding_bag_default_ir, args);
  testStaticRuntime(embedding_bag_mean_ir, args);
  testStaticRuntime(embedding_bag_max_last_offset_ir, args);

  at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};
  testStaticRuntime(embedding_bag_default_ir, args, args2);
  testStaticRuntime(embedding_bag_mean_ir, args, args2);
  testStaticRuntime(embedding_bag_max_last_offset_ir, args, args2);
}

TEST(StaticRuntime, EmbeddingBagWithMixedInt32Int64Input) {
  const std::string embedding_bag_default = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";
  auto weight = torch::randn({3, 11}, at::ScalarType::Float);
  auto input = torch::tensor({0, 1, 0, 2}, at::ScalarType::Long);
  auto offset = torch::tensor({0, 2, 4}, at::ScalarType::Int);
  std::vector<IValue> args{weight, input, offset};
  testStaticRuntime(embedding_bag_default, args);
}

TEST(StaticRuntime, LayerNorm) {
  const std::string layer_norm_with_weights = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
        return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
  )JIT";

  const std::string layer_norm_without_weights = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int]):
        return torch.layer_norm(input, normalized_shape, None, None, 1e-05, False).clone()
  )JIT";

  const std::string layer_norm_with_noncontiguous_input = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
        input = torch.transpose(input, 1, 2)
        return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
  )JIT";

  const auto a = torch::rand({1, 2, 2, 2});
  const auto b = torch::rand({3, 2, 2, 2});
  for (int normalized_size : {2, 3}) {
    std::vector<int64_t> normalized_shape(normalized_size, 2);
    const auto weight = torch::rand(normalized_shape);
    const auto bias = torch::rand(normalized_shape);

    std::vector<IValue> args{a, normalized_shape, weight, bias};
    std::vector<IValue> args1{b, normalized_shape, weight, bias};
    testStaticRuntime(layer_norm_with_weights, args);
    testStaticRuntime(layer_norm_with_weights, args, args1);
    testStaticRuntime(layer_norm_with_noncontiguous_input, args);

    args = {a, normalized_shape};
    testStaticRuntime(layer_norm_without_weights, args);
    testStaticRuntime(layer_norm_without_weights, args, {b, normalized_shape});
  }
}

TEST(StaticRuntime, Bmm) {
  const auto bmm_script = R"JIT(
    def forward(self, inp: Tensor, mat2: Tensor):
      return torch.bmm(inp, mat2).clone()
  )JIT";

  auto a = at::randn({10, 4, 5});
  auto b = at::randn({10, 5, 6});

  auto c = at::randn({12, 5, 6});
  auto d = at::randn({12, 6, 7});

  std::vector<IValue> args{a, b};
  std::vector<IValue> args1{c, d};
  testStaticRuntime(bmm_script, args);
  testStaticRuntime(bmm_script, args1);
  testStaticRuntime(bmm_script, args, args1);
}

TEST(StaticRuntime, Addmm) {
  const auto addmm_script = R"JIT(
    def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float):
      return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone()
  )JIT";
  auto inp1 = at::randn({5});
  auto mat1 = at::randn({3, 4});
  auto mat2 = at::randn({4, 5});

  auto inp2 = at::randn({3, 7});
  auto mat3 = at::randn({3, 6});
  auto mat4 = at::randn({6, 7});

  std::vector<IValue> args{inp1, mat1, mat2, 1.0, 2.0};
  std::vector<IValue> args1{inp2, mat3, mat4, 2.0, 1.0};
  testStaticRuntime(addmm_script, args);
  testStaticRuntime(addmm_script, args1);
  testStaticRuntime(addmm_script, args, args1);
}

TEST(StaticRuntime, Abs) {
  const auto abs_script = R"JIT(
    def forward(self, a):
      return a.abs().clone()
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 2, 3});
  std::vector<IValue> args{a};
  std::vector<IValue> args2{b};
  testStaticRuntime(abs_script, args);
  testStaticRuntime(abs_script, args, args2);
}

TEST(StaticRuntime, Binary) {
  const auto add_script = R"JIT(
    def forward(self, a, b):
        c = a + b
        return (c.clone())
  )JIT";

  const auto add_script_ints = R"JIT(
    def forward(self, a: int, b: int):
        c = a + b
        d = c + 1
        return d
  )JIT";

  const auto add_list_script = R"JIT(
    def forward(self, a: List[int], b: List[int]):
        c = a + b
        return c[::]
  )JIT";

  const auto list_construct_script = R"JIT(
    def forward(self, a, b):
      return [a, b]
  )JIT";

  const auto list_construct_script_2 = R"JIT(
    def forward(self, a, b):
      c = a + a
      return [c, c]
  )JIT";

  const auto list_construct_script_3 = R"JIT(
    def forward(self, a, b):
      c = a + a
      return [c, c.flatten()]
  )JIT";

  const auto list_unpack_script = R"JIT(
    def forward(self, a, b):
      c = [a, b]
      x, y = c
      z = x + y
      return z.clone()
  )JIT";

  const auto list_unpack_script_2 = R"JIT(
    def forward(self, a, b):
      c = [a, b]
      x, y = c
      z = (x, y)
      return z
  )JIT";

  const auto tuple_construct_script = R"JIT(
    def forward(self, a, b):
      return (a, b)
  )JIT";

  const auto tuple_construct_script_2 = R"JIT(
    def forward(self, a, b):
      return (a.flatten(), b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::ones({2, 3});

  auto c = at::randn({4, 2, 3});
  auto d = at::ones({4, 2, 3});

  std::vector<IValue> args{a, b};

  testStaticRuntime(add_script, args);
  testStaticRuntime(add_script_ints, {1, 2});
  testStaticRuntime(add_script, args, {c, d});
  testStaticRuntime(list_construct_script, args);
  testStaticRuntime(list_construct_script_2, args);
  testStaticRuntime(list_construct_script_3, args);
  testStaticRuntime(list_unpack_script, args);
  testStaticRuntime(list_unpack_script_2, args);
  testStaticRuntime(tuple_construct_script, args);
  testStaticRuntime(tuple_construct_script_2, args);

  std::vector<IValue> list_args{
      c10::List<int64_t>{1, 2, 3}, c10::List<int64_t>{4, 5, 6}};
  testStaticRuntime(add_list_script, list_args);
}

TEST(StaticRuntime, MatMul) {
  const auto aten_matmul = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.matmul(a, b).clone()
  )JIT";

  // 1-D, 1-D
  std::vector<IValue> args{at::randn({3}), at::randn({3})};
  testStaticRuntime(aten_matmul, args);
  // 2-D, 2-D
  std::vector<IValue> args1 = {at::randn({3, 2}), at::randn({2, 3})};
  testStaticRuntime(aten_matmul, args1);
  // 1-D, 2-D
  std::vector<IValue> args2 = {at::randn({3}), at::randn({3, 5})};
  testStaticRuntime(aten_matmul, args2);
  // 2-D, 1-D
  std::vector<IValue> args3 = {at::randn({3, 5}), at::randn({5})};
  testStaticRuntime(aten_matmul, args3);
  // > 2-D, > 2-D
  std::vector<IValue> args4 = {at::randn({3, 1, 4, 5}), at::randn({2, 5, 6})};
  testStaticRuntime(aten_matmul, args4);

  testStaticRuntime(aten_matmul, args3, args4);
}

TEST(StaticRuntime, Sign) {
  const auto sign_tensor = R"JIT(
    def forward(self, input: Tensor):
        return torch.sign(input).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  std::vector<IValue> args{a};
  testStaticRuntime(sign_tensor, args);
  testStaticRuntime(sign_tensor, args, {b});
}

TEST(StaticRuntime, Div) {
  const auto div_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.div(a, b).clone()
  )JIT";

  const auto div_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.div(a, b).clone()
  )JIT";

  const auto div_tensor_mode = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: str):
        return torch.div(a, b, rounding_mode=c).clone()
  )JIT";

  const auto div_scalar_mode = R"JIT(
    def forward(self, a: Tensor, b: float, c: str):
        return torch.div(a, b, rounding_mode=c).clone()
  )JIT";

  const auto div_strided = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        a_strided = torch.transpose(a, 0, 1)
        b_strided = torch.transpose(b, 0, 1)
        return torch.div(a_strided, b_strided).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto bs = at::randn({3, 2}).transpose(0, 1);
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});
  auto ds = at::randn({3, 4, 2}).transpose(0, 1);

  std::vector<IValue> args0{a, b};
  testStaticRuntime(div_tensor, args0);
  testStaticRuntime(div_tensor, args0, {c, d});

  testStaticRuntime(div_strided, args0);
  testStaticRuntime(div_strided, args0, {c, d});

  testStaticRuntime(div_tensor, {a, bs});
  testStaticRuntime(div_tensor, {a, bs}, {c, ds});

  std::vector<IValue> args1{a, 3};
  testStaticRuntime(div_scalar, args1);
  testStaticRuntime(div_scalar, args1, {c, 4});

  std::vector<IValue> args2{a, b, "floor"};
  testStaticRuntime(div_tensor_mode, args2);
  testStaticRuntime(div_tensor_mode, args2, {c, d, "floor"});

  std::vector<IValue> args3{a, 2.3, "trunc"};
  testStaticRuntime(div_scalar_mode, args3);
  testStaticRuntime(div_scalar_mode, args3, {c, 1.5, "trunc"});
}

TEST(StaticRuntime, Mul) {
  const auto mul_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.mul(a, b).clone()
  )JIT";

  const auto mul_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.mul(a, b).clone()
  )JIT";

  const auto mul_list = R"JIT(
    def forward(self, a: List[int], n: int):
        b = a * n
        return b[::]
  )JIT";

  auto a = at::randn({3, 3});
  auto b = at::randn({3, 3});
  auto c = at::randn({3, 3, 3});
  auto d = at::randn({3, 3, 3});

  std::vector<IValue> tensor_args1{a, b};
  std::vector<IValue> tensor_args2{c, d};

  testStaticRuntime(mul_tensor, tensor_args1);
  testStaticRuntime(mul_tensor, tensor_args1, tensor_args2);

  std::vector<IValue> scalar_args1{a, 42};
  std::vector<IValue> scalar_args2{c, 42};

  testStaticRuntime(mul_scalar, scalar_args1);
  testStaticRuntime(mul_scalar, scalar_args1, scalar_args2);

  std::vector<IValue> list_args{c10::List<int64_t>{1, 2}, 3};
  testStaticRuntime(mul_list, list_args);
}

TEST(StaticRuntime, Log) {
  const auto log_tensor = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.log(inp).clone()
        return (a)
  )JIT";

  // Ensure that the input values are valid.
  auto a = at::abs(at::randn({2, 3}));
  auto b = at::abs(at::randn({4, 3, 2}));

  std::vector<IValue> args{a};
  testStaticRuntime(log_tensor, args);
  testStaticRuntime(log_tensor, args, {b});
}

TEST(StaticRuntime, Sub) {
  const auto sub_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.sub(a, b).clone()
  )JIT";

  const auto sub_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.sub(a, b).clone()
  )JIT";

  const auto sub_tensor_alpha = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: float):
        return torch.sub(a, b, alpha=c).clone()
  )JIT";

  const auto sub_scalar_alpha = R"JIT(
    def forward(self, a: Tensor, b: float, c: int):
        return torch.sub(a, b, alpha=c).clone()
  )JIT";

  const auto sub_two_scalars = R"JIT(
    def forward(self, a: int, b: int):
        return (a - b - b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});

  std::vector<IValue> args0{a, b};
  testStaticRuntime(sub_tensor, args0);
  testStaticRuntime(sub_tensor, args0, {c, d});

  std::vector<IValue> args1{a, 3};
  testStaticRuntime(sub_scalar, args1);
  testStaticRuntime(sub_scalar, args1, {c, 4});

  std::vector<IValue> args2{a, b, 2.3};
  testStaticRuntime(sub_tensor_alpha, args2);
  testStaticRuntime(sub_tensor_alpha, {c, d, 3.1});

  std::vector<IValue> args3{a, 2.3, 4};
  testStaticRuntime(sub_scalar_alpha, args3);
  testStaticRuntime(sub_scalar_alpha, {c, 1.3, 2});

  std::vector<IValue> args4{1, 2};
  testStaticRuntime(sub_two_scalars, args4);
}

TEST(StaticRuntime, NanToNum) {
  const auto nan_to_num_script = R"JIT(
    def forward(self, a: Tensor, nan: float, posinf: float, neginf: float):
        return torch.nan_to_num(a, nan, posinf, neginf).clone()
  )JIT";

  const auto inf = std::numeric_limits<double>::infinity();
  const auto nan = std::numeric_limits<double>::quiet_NaN();

  auto a = torch::tensor({{1.0, nan}, {-inf, inf}});
  auto b = at::randn({3, 6});
  float* b_data = b.data_ptr<float>();
  b_data[0] = nan;
  b_data[4] = -inf;
  b_data[11] = inf;
  b_data[13] = nan;

  std::vector<IValue> args1{a, 1.0, 2.0, -2.0};
  std::vector<IValue> args2{b, 1.0, 2.0, -2.0};

  testStaticRuntime(
      nan_to_num_script,
      args1,
      /*args2*/ {},
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
  testStaticRuntime(
      nan_to_num_script,
      args1,
      args2,
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
}

TEST(StaticRuntime, Stack) {
  const auto stack_dim = R"JIT(
    def forward(self, a: Tensor, b: Tensor, dim: int):
        inputs = [a]
        inputs.append(b) # mutation to avoid using VarStack
        return torch.stack(inputs, dim=dim).clone()
  )JIT";

  const auto stack_three = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        inputs = [a, b]
        inputs.append(c) # mutation to avoid using VarStack
        return torch.stack(inputs).clone()
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({2, 2});
  auto c = at::randn({2, 2});

  auto d = at::randn({3, 3, 3});
  auto e = at::randn({3, 3, 3});
  auto f = at::randn({3, 3, 3});

  std::vector<IValue> args1_dim{a, b, 0};
  std::vector<IValue> args2_dim{d, e, 1};
  std::vector<IValue> args_dim_negative{d, e, -1};

  std::vector<IValue> args1_three_tensors{a, b, c};
  std::vector<IValue> args2_three_tensors{d, e, f};

  testStaticRuntime(stack_dim, args1_dim);
  testStaticRuntime(stack_dim, args1_dim, args2_dim);

  testStaticRuntime(stack_dim, args_dim_negative);

  testStaticRuntime(stack_three, args1_three_tensors);
  testStaticRuntime(stack_three, args1_three_tensors, args2_three_tensors);
}

TEST(StaticRuntime, ReLU) {
  const auto relu_script = R"JIT(
    def forward(self, a: Tensor):
        return torch.relu(a).clone()
  )JIT";
  auto a = at::randint(-10, 10, {2, 4});
  auto b = at::randint(-10, 10, {3, 6});

  std::vector<IValue> args1{a};
  std::vector<IValue> args2{b};

  testStaticRuntime(relu_script, args1);
  testStaticRuntime(relu_script, args1, args2);
}

TEST(StaticRuntime, Tanh) {
  const auto tanh_script = R"JIT(
    def forward(self, a):
        return torch.tanh(a).clone()
  )JIT";
  auto a = at::randn({2, 2});
  auto b = at::randn({3, 3, 3});

  std::vector<IValue> args1{a};
  std::vector<IValue> args2{b};

  testStaticRuntime(tanh_script, args1, /*args2*/ {}, /*use_allclose*/ true);
  testStaticRuntime(tanh_script, args1, args2, /*use_allclose*/ true);
}

TEST(StaticRuntime, Norm) {
  const auto norm_2arg = R"JIT(
    def forward(self, a: Tensor, p: int):
        return torch.norm(a, p).clone()
  )JIT";

  const auto norm_3arg = R"JIT(
    def forward(self, a: Tensor, p: int, dtype: int):
        return torch.norm(a, p, dtype=dtype).clone()
  )JIT";

  const auto norm_4arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool):
        return torch.norm(a, p, dim, keepdim).clone()
  )JIT";

  const auto norm_5arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool, dtype: int):
        return torch.norm(a, p, dim, keepdim, dtype=dtype).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 5});
  auto dim = std::vector<int64_t>({1});
  auto dtype = at::ScalarType::Float;

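  // check_resize is disabled for the two full reductions below: their 0-dim
  // outputs do not grow with the larger input.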
  std::vector<IValue> args2{a, 2};
  testStaticRuntime(norm_2arg, args2);
  testStaticRuntime(norm_2arg, args2, {b, 2}, false, false, false);

  std::vector<IValue> args3{a, 2, dtype};
  testStaticRuntime(norm_3arg, args3);
  testStaticRuntime(norm_3arg, args3, {b, 2, dtype}, false, false, false);

  std::vector<IValue> args4{a, 3, dim, false};
  testStaticRuntime(norm_4arg, args4);
  testStaticRuntime(norm_4arg, args4, {b, 3, dim, false});

  std::vector<IValue> args5{a, 4, dim, true, dtype};
  testStaticRuntime(norm_5arg, args5);
  testStaticRuntime(norm_5arg, args5, {b, 4, dim, true, dtype});
}

TEST(StaticRuntime, Reshape) {
  const auto reshape_script_1 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.reshape(shape)
        return b + b
  )JIT";

  const auto reshape_script_2 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        return b.reshape(shape)
  )JIT";

  const auto reshape_script_3 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape)
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return b.reshape(shape), g
  )JIT";

  // exercise reshape_copy and flatten_copy
  const auto reshape_script_4 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        k = inp + inp
        a = k + k
        b = a.reshape(shape)
        c = a.flatten().reshape(shape)
        return b + c
  )JIT";

  // exercise reshape_copy
  const auto reshape_script_5 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape).relu()
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return g
  )JIT";

  const auto reshape_inplace_script = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = b.sigmoid_()
        d = c + c
        e = a + a
        f = b + b
        return (d, e, f)
  )JIT";

  // b is non-contiguous
  const auto reshape_incontiguous_script = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        c = b.reshape(shape)
        c = c.relu()
        return (c)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = std::vector<int64_t>({3, 2});
  std::vector<IValue> args{a, b};

  auto c = at::randn({4, 5});
  auto d = std::vector<int64_t>({5, 1, 2, 2});
  std::vector<IValue> args1{c, d};

  testStaticRuntime(reshape_script_1, args);
  testStaticRuntime(reshape_script_2, args);
  testStaticRuntime(reshape_script_3, args);
  testStaticRuntime(reshape_script_4, args);
  testStaticRuntime(reshape_script_5, args);
  testStaticRuntime(reshape_inplace_script, args);
  testStaticRuntime(reshape_incontiguous_script, args);

  testStaticRuntime(reshape_script_1, args, args1);
  testStaticRuntime(reshape_script_2, args, args1);
  testStaticRuntime(reshape_script_3, args, args1);
  testStaticRuntime(reshape_script_4, args, args1);
  testStaticRuntime(reshape_script_5, args, args1);
  testStaticRuntime(reshape_inplace_script, args, args1);
  testStaticRuntime(reshape_incontiguous_script, args, args1);
}

TEST(StaticRuntime, Repeat) {
  const std::string repeat = R"JIT(
    def forward(self, a: Tensor, repeats: List[int]):
        return torch.repeat(a, repeats).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3});
  auto c = std::vector<int64_t>({1, 2});
  auto d = std::vector<int64_t>({2, 3});
  std::vector<IValue> args1{a, c};
  std::vector<IValue> args2{b, d};

  testStaticRuntime(repeat, args1);
  testStaticRuntime(repeat, args2);
  testStaticRuntime(repeat, args1, args2);
}

TEST(StaticRuntime, Flatten) {
  // exercise flatten_copy
  const auto flatten_script_1 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a * a
        c = torch.flatten(b, start_dim, end_dim)
        d = torch.relu(c)
        return d
  )JIT";

  const auto flatten_script_2 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a.transpose(0, 1)
        return torch.flatten(b, start_dim, end_dim).clone()
  )JIT";

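  // The second-run shape scales dim 0 by 6 so that the managed buffers must
  // be resized (rule 2 in the comment at the top of this file).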
  auto test_flatten =
      [&](std::vector<int64_t> shape, int64_t start_dim, int64_t end_dim) {
        std::vector<int64_t> shape1(shape);
        if (shape1.size() > 0) {
          shape1[0] *= 6;
        }
        auto a = at::randn(shape);
        auto b = at::randn(shape1);
        std::vector<IValue> args{a, start_dim, end_dim};
        bool check_resize = shape1.size() > 0;
        testStaticRuntime(flatten_script_1, args);
        testStaticRuntime(
            flatten_script_1,
            args,
            {b, start_dim, end_dim},
            false, /* use_allclose */
            false, /* use_equalnan */
            check_resize);
        if (shape.size() > 2) {
          testStaticRuntime(flatten_script_2, args);
          testStaticRuntime(flatten_script_2, args, {b, start_dim, end_dim});
        }
      };

  test_flatten({2, 3}, 0, 1);
  test_flatten({2, 1, 3}, 1, 2);
  test_flatten({0, 1, 3, 0}, 1, 2);
  test_flatten({2, 3}, 1, 1);
  test_flatten({}, 0, 0);
}

TEST(StaticRuntime, pow) {
  const auto pow_script_ten_sca = R"JIT(
    def forward(self, input : Tensor, exponent : int):
        return torch.pow(input, exponent).clone()
  )JIT";

  const auto pow_script_ten_ten = R"JIT(
    def forward(self, input : Tensor, exponent : Tensor):
        return torch.pow(input, exponent).clone()
  )JIT";

  const auto pow_script_sca_ten = R"JIT(
    def forward(self, input : int, exponent : Tensor):
        return torch.pow(input, exponent).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});

  std::vector<IValue> args0{a, 4};
  testStaticRuntime(pow_script_ten_sca, args0);
  testStaticRuntime(pow_script_ten_sca, args0, {c, 4});

  std::vector<IValue> args1{at::abs(a), b};
  testStaticRuntime(pow_script_ten_ten, args1);
  testStaticRuntime(pow_script_ten_ten, args1, {at::abs(c), d});

  std::vector<IValue> args2{5, b};
  testStaticRuntime(pow_script_sca_ten, args2);
  testStaticRuntime(pow_script_sca_ten, args2, {3, d});
}

TEST(StaticRuntime, to) {
  const auto to_script_dtype = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy, memory_format).clone()
  )JIT";

  const auto to_script_dtype_strided = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        b = input.permute(0, 2, 3, 1)
        return torch.to(b, dtype, non_blocking, copy, memory_format).clone()
  )JIT";

  const auto to_script_prim_dtype = R"JIT(
    def forward(self, input: Tensor, dtype: Optional[int], non_blocking: bool, copy: bool):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy).clone()
  )JIT";

  const auto to_script_other = R"JIT(
    def forward(self, input: Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, other, non_blocking, copy, memory_format).clone()
  )JIT";

  // If the input is a float tensor, b could be an alias of a.
  const auto to_script_alias = R"JIT(
    def forward(self, input: Tensor):
        a = input + input
        b = a.float()
        c = b * b
        return (c)
  )JIT";

  const auto to_script_fails_managed_output_check = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return e
  )JIT";

  const auto to_script_select_tensor_output_into_tuple = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return (d, e)
  )JIT";

  const auto to_script_memory_planning_fail = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float().relu()
        return e
  )JIT";

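  // test_to's parameters: b is the target dtype, c is non_blocking, d is
  // copy, and e is the memory_format.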
  auto test_to = [&](at::ScalarType b, bool c, bool d, c10::MemoryFormat e) {
    auto a = at::randn({4, 3, 1, 2});
    auto other = at::randn({4, 3, 1, 2}).to(b);
    auto a2 = at::randn({3, 2, 2, 4});
    auto a2_other = at::randn({3, 2, 2, 4}).to(b);

    std::vector<IValue> args0{a, b, c, d, e};
    std::vector<IValue> args1{a, b, c, d};
    std::vector<IValue> args2{a, other, c, d, e};
    std::vector<IValue> args2WithDifferentOtherType{
        a, at::randn({4, 3, 1, 2}, ScalarType::Double), c, d, e};
    std::vector<IValue> args3{a, std::nullopt, c, d};

    std::vector<IValue> args0WithInt{a, ScalarType::Int, c, d, e};
    testStaticRuntime(
        to_script_dtype,
        args0,
        args0WithInt,
        /* default for use_allclose */ false,
        /* default for use_equalnan */ false,
        /* check_resize */ false);
    testStaticRuntime(to_script_dtype_strided, args0);
    testStaticRuntime(to_script_prim_dtype, args1);
    if (!d) {
      testStaticRuntime(to_script_prim_dtype, args3);
    }
    // The second set of args tests the case where the `other` tensor's dtype
    // changes between iterations.
1463     testStaticRuntime(
1464         to_script_other,
1465         args2,
1466         args2WithDifferentOtherType,
1467         /* default for use_allclose */ false,
1468         /* default for use_equalnan */ false,
1469         /* check_resize */ false);
1470     testStaticRuntime(to_script_alias, {a});
1471 
1472     testStaticRuntime(to_script_memory_planning_fail, {a, a});
1473     testStaticRuntime(to_script_fails_managed_output_check, {a, a});
1474     testStaticRuntime(to_script_select_tensor_output_into_tuple, {a, a});
1475 
1476     // dynamic shapes
1477     testStaticRuntime(to_script_dtype, args0, {a2, b, c, d, e});
1478     testStaticRuntime(to_script_dtype_strided, args0, {a2, b, c, d, e});
1479     testStaticRuntime(to_script_prim_dtype, args1, {a2, b, c, d});
1480     if (!d) {
1481       testStaticRuntime(to_script_prim_dtype, args3, {a2, std::nullopt, c, d});
1482     }
1483     testStaticRuntime(to_script_other, args2, {a2, a2_other, c, d, e});
1484     testStaticRuntime(to_script_alias, {a}, {a2});
1485   };
1486   for (const bool non_blocking : {false, true}) {
1487     for (const bool copy : {false, true}) {
1488       // float->float, NCHW->NHWC
1489       test_to(
1490           at::ScalarType::Float,
1491           non_blocking,
1492           copy,
1493           c10::MemoryFormat::ChannelsLast);
1494       // float->half
1495       test_to(
1496           at::ScalarType::Half,
1497           non_blocking,
1498           copy,
1499           c10::MemoryFormat::Preserve);
1500       // float->float
1501       test_to(
1502           at::ScalarType::Float,
1503           non_blocking,
1504           copy,
1505           c10::MemoryFormat::Contiguous);
1506       test_to(
1507           at::ScalarType::Bool,
1508           non_blocking,
1509           copy,
1510           c10::MemoryFormat::Contiguous);
1511       // TODO: check if fbgemm is enabled properly in this case
1512       // half->float, NCHW->NHWC
1513       test_to(
1514           at::ScalarType::Half,
1515           non_blocking,
1516           copy,
1517           c10::MemoryFormat::ChannelsLast);
1518     }
1519   }
1520 }
1521 
1522 TEST(StaticRuntime, ExpandAs) {
1523   const auto expand_as_script = R"JIT(
1524     def forward(self, input: Tensor, other:Tensor):
1525         a = input.expand_as(other)
1526         return a.clone()
1527   )JIT";
1528 
1529   auto a = at::randn({3, 1});
1530   auto b = at::randn({3, 2});
1531   auto c = at::randn({4, 1});
1532   auto d = at::randn({4, 2});
1533   std::vector<IValue> args{a, b};
1534   std::vector<IValue> args2{c, d};
1535   testStaticRuntime(expand_as_script, args);
1536   testStaticRuntime(expand_as_script, args, args2);
1537 }
1538 
1539 TEST(StaticRuntime, Full) {
1540   const auto full_script = R"JIT(
1541     def forward(self,
1542                 size: List[int],
1543                 fill_value: int,
1544                 dtype: Optional[int],
1545                 layout: Optional[int],
1546                 device: Optional[Device],
1547                 pin_memory: Optional[bool]):
1548         a = torch.full(size,
1549                       fill_value,
1550                       dtype=dtype,
1551                       layout=layout,
1552                       device=device,
1553                       pin_memory=pin_memory)
1554         return (a.clone())
1555   )JIT";
1556 
1557   auto cpu = at::Device(DeviceType::CPU);
1558   c10::List<int64_t> size0{2, 5};
1559   std::vector<IValue> args{
1560       size0, 4, at::ScalarType::Int, at::kStrided, cpu, false};
1561   std::vector<IValue> args1{
1562       size0, 4, at::ScalarType::Float, at::kStrided, cpu, false};
1563   c10::List<int64_t> size1{5, 6};
1564   std::vector<IValue> args2{
1565       size1, 5, at::ScalarType::Float, at::kStrided, cpu, false};
1566   testStaticRuntime(full_script, args);
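  // Here (and in similar tests below) args and args1 share the same shape
  // and differ only in dtype, so there is no storage growth between runs;
  // the resize check is disabled for that pair.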
1567   testStaticRuntime(
1568       full_script,
1569       args,
1570       args1,
1571       /*use_allclose=*/false,
1572       /*use_equalnan=*/false,
1573       /*check_resize=*/false);
1574   testStaticRuntime(full_script, args, args2);
1575 }
1576 
1577 TEST(StaticRuntime, FullLike) {
1578   const auto full_like_script = R"JIT(
1579     def forward(self,
1580                 a: Tensor,
1581                 fill_value: int,
1582                 dtype: Optional[int],
1583                 layout: Optional[int],
1584                 device: Optional[Device],
1585                 pin_memory: Optional[bool],
1586                 memory_format: Optional[int]):
1587         b = torch.full_like(a,
1588                             fill_value,
1589                             dtype=dtype,
1590                             layout=layout,
1591                             device=device,
1592                             pin_memory=pin_memory,
1593                             memory_format=memory_format)
1594         return (b.clone())
1595   )JIT";
1596 
1597   auto a = at::randn({2, 3});
1598   auto b = at::randn({3, 4, 2});
1599   auto cpu = at::Device(DeviceType::CPU);
1600   std::vector<IValue> args{
1601       a,
1602       4,
1603       at::ScalarType::Int,
1604       at::kStrided,
1605       cpu,
1606       false,
1607       c10::MemoryFormat::Contiguous};
1608   std::vector<IValue> args1{
1609       a,
1610       4,
1611       at::ScalarType::Float,
1612       at::kStrided,
1613       cpu,
1614       false,
1615       c10::MemoryFormat::Contiguous};
1616   std::vector<IValue> args2{
1617       b,
1618       4,
1619       at::ScalarType::Float,
1620       at::kStrided,
1621       cpu,
1622       false,
1623       c10::MemoryFormat::Contiguous};
1624   testStaticRuntime(full_like_script, args);
1625   testStaticRuntime(
1626       full_like_script,
1627       args,
1628       args1,
1629       /*use_allclose=*/false,
1630       /*use_equalnan=*/false,
1631       /*check_resize=*/false);
1632   testStaticRuntime(full_like_script, args, args2);
1633 }
1634 
1635 TEST(StaticRuntime, Ones) {
1636   const auto script = R"JIT(
1637     def forward(self,
1638                 size: List[int],
1639                 dtype: Optional[int],
1640                 layout: Optional[int],
1641                 device: Optional[Device],
1642                 pin_memory: Optional[bool]):
1643         a = torch.ones(size,
1644                        dtype=dtype,
1645                        layout=layout,
1646                        device=device,
1647                        pin_memory=pin_memory)
1648         return (a.clone())
1649   )JIT";
1650 
1651   auto dtype = at::ScalarType::Int;
1652   auto cpu = at::Device(DeviceType::CPU);
1653   c10::List<int64_t> size0{2, 5};
1654   std::vector<IValue> args{size0, dtype, at::kStrided, cpu, false};
1655   c10::List<int64_t> size1{5, 6};
1656   std::vector<IValue> args2{size1, dtype, at::kStrided, cpu, false};
1657   testStaticRuntime(script, args);
1658   testStaticRuntime(script, args, args2);
1659 }
1660 
1661 TEST(StaticRuntime, OnesLike) {
1662   const auto script = R"JIT(
1663     def forward(self,
1664                 input: Tensor,
1665                 dtype: Optional[int],
1666                 layout: Optional[int],
1667                 device: Optional[Device],
1668                 pin_memory: Optional[bool],
1669                 memory_format: Optional[int]):
1670         a = torch.ones_like(input,
1671                             dtype=dtype,
1672                             layout=layout,
1673                             device=device,
1674                             pin_memory=pin_memory,
1675                             memory_format=memory_format)
1676         return (a.clone())
1677   )JIT";
1678 
1679   auto cpu = at::Device(DeviceType::CPU);
1680   auto input0 = at::randn({2, 5});
1681   std::vector<IValue> args{
1682       input0,
1683       at::ScalarType::Int,
1684       at::kStrided,
1685       cpu,
1686       false,
1687       c10::MemoryFormat::Contiguous};
1688   std::vector<IValue> args1{
1689       input0,
1690       at::ScalarType::Float,
1691       at::kStrided,
1692       cpu,
1693       false,
1694       c10::MemoryFormat::Contiguous};
1695   auto input1 = at::randn({5, 6});
1696   std::vector<IValue> args2{
1697       input1,
1698       at::ScalarType::Float,
1699       at::kStrided,
1700       cpu,
1701       false,
1702       c10::MemoryFormat::Contiguous};
1703   testStaticRuntime(script, args);
1704   testStaticRuntime(
1705       script,
1706       args,
1707       args1,
1708       /*use_allclose=*/false,
1709       /*use_equalnan=*/false,
1710       /*check_resize=*/false);
1711   testStaticRuntime(script, args, args2);
1712 }
1713 
1714 TEST(StaticRuntime, Zeros) {
1715   const auto script = R"JIT(
1716     def forward(self,
1717                 size: List[int],
1718                 dtype: Optional[int],
1719                 layout: Optional[int],
1720                 device: Optional[Device],
1721                 pin_memory: Optional[bool]):
1722         a = torch.zeros(size,
1723                        dtype=dtype,
1724                        layout=layout,
1725                        device=device,
1726                        pin_memory=pin_memory)
1727         return (a.clone())
1728   )JIT";
1729 
1730   auto cpu = at::Device(DeviceType::CPU);
1731   c10::List<int64_t> size0{2, 5};
1732   std::vector<IValue> args{
1733       size0, at::ScalarType::Int, at::kStrided, cpu, false};
1734   std::vector<IValue> args1{
1735       size0, at::ScalarType::Float, at::kStrided, cpu, false};
1736   c10::List<int64_t> size1{5, 6};
1737   std::vector<IValue> args2{
1738       size1, at::ScalarType::Float, at::kStrided, cpu, false};
1739   testStaticRuntime(script, args);
1740   testStaticRuntime(
1741       script,
1742       args,
1743       args1,
1744       /*use_allclose=*/false,
1745       /*use_equalnan=*/false,
1746       /*check_resize=*/false);
1747   testStaticRuntime(script, args, args2);
1748 }
1749 
1750 TEST(StaticRuntime, Linear) {
1751   const auto linear_script = R"JIT(
1752     def forward(self, inp: Tensor, weights: Tensor, bias: Optional[Tensor]) -> Tensor:
1753         return torch.linear(inp, weights, bias).clone()
1754   )JIT";
1755 
1756   auto input = at::randn({1, 2});
1757   auto weights = at::randn({1, 2});
1758   auto bias = at::randn({1, 1});
1759 
1760   std::vector<IValue> args{input, weights, bias};
1761   std::vector<IValue> args_no_bias{input, weights, std::nullopt};
1762 
1763   auto input2 = at::randn({6, 3});
1764   auto weights2 = at::randn({6, 3});
1765   auto bias2 = at::randn({6, 6});
1766 
1767   std::vector<IValue> args2{input2, weights2, bias2};
1768   std::vector<IValue> args2_no_bias{input2, weights2, std::nullopt};
1769 
1770   testStaticRuntime(linear_script, args);
1771   testStaticRuntime(linear_script, args_no_bias);
1772 
1773   testStaticRuntime(linear_script, args, args2);
1774   testStaticRuntime(linear_script, args, args2_no_bias);
1775 }
1776 
1777 TEST(StaticRuntime, VarCat) {
1778   const auto var_cat_script = R"JIT(
1779     def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
1780       return torch.cat([inp1, inp2], dim).clone()
1781   )JIT";
1782 
1783   // 2D tensors - cat dim = 0
1784   std::vector<IValue> args1 = {at::randn({4, 6}), at::randn({5, 6}), 0};
1785   testStaticRuntime(var_cat_script, args1);
1786 
1787   // 3D tensors - cat dim = 1
1788   std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 8, 6}), 1};
1789   testStaticRuntime(var_cat_script, args2);
1790 
1791   // 3D tensors - cat dim = 2
1792   std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2};
1793   testStaticRuntime(var_cat_script, args3);
1794 
1795   // Negative dim
1796   std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), -1};
1797   testStaticRuntime(var_cat_script, args4);
1798 
1799   testStaticRuntime(var_cat_script, args1, args2);
1800 }
1801 
1802 TEST(StaticRuntime, LeakyReLU) {
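  // Unlike most tests in this file, this one drives StaticModule directly
  // and compares against the JIT graph executor by hand rather than going
  // through the testStaticRuntime helper.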
1803   torch::jit::Module mod = getLeakyReLUConstScriptModel();
1804   auto inputs = torch::randn({2, 2});
1805 
1806   // run jit graph executor
1807   std::vector<at::IValue> input_ivalues({inputs});
1808   at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
1809 
1810   // run static runtime
1811   std::vector<c10::IValue> input_tensors({inputs});
1812   torch::jit::StaticModule smod(mod);
1813   at::Tensor output_2 = smod(input_tensors, {}).toTensor();
1814   smod.runtime().check_for_memory_leak();
1815   EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
1816 }
1817 
1818 static ProcessedNodeInputs createProcessedNodeInputs(
1819     c10::ArrayRef<uint16_t> inputs) {
1820   ProcessedNodeInputs result(inputs.size());
1821   for (const auto idx : c10::irange(inputs.size())) {
1822     result[idx] = inputs[idx];
1823   }
1824   return result;
1825 }
1826 
1827 static void checkProcessedNodeInputs(
1828     const ProcessedNodeInputs& io,
1829     c10::ArrayRef<uint16_t> inputs) {
1830   ASSERT_EQ(inputs.size(), io.size());
1831   for (const auto idx : c10::irange(inputs.size())) {
1832     EXPECT_EQ(inputs[idx], io[idx]);
1833   }
1834 }
1835 
1836 static void testProcessedNodeInputsRoundTrip(c10::ArrayRef<uint16_t> inputs) {
1837   auto io = createProcessedNodeInputs(inputs);
1838   checkProcessedNodeInputs(io, inputs);
1839 
1840   ProcessedNodeInputs copied(io);
1841   checkProcessedNodeInputs(copied, inputs);
1842   ProcessedNodeInputs moved(std::move(io));
1843   checkProcessedNodeInputs(moved, inputs);
1844 }
1845 
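// As the "inline"/"outline" cases below suggest, ProcessedNodeInputs
// appears to use a small-buffer optimization: up to 5 input indices are
// stored inline, and anything larger spills to a heap allocation.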
1846 TEST(ProcessedNodeInputs, Basic) {
1847   std::vector<std::vector<uint16_t>> testCases = {
1848       {}, // empty
1849       {0xABCD, 0x5a5a}, // inline
1850       {0x11, 0x22, 0x33, 0x44, 0x55}, // max inline size
1851       {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}, // minimum outline size
1852       std::vector<uint16_t>(100, 0x5a), // large outline size
1853   };
1854 
1855   for (const auto& values : testCases) {
1856     testProcessedNodeInputsRoundTrip(values);
1857     for (const auto& values2 : testCases) {
1858       auto from = createProcessedNodeInputs(values);
1859       auto to = createProcessedNodeInputs(values2);
1860 
1861       to = from;
1862       checkProcessedNodeInputs(to, values);
1863 
1864       auto toMoveInto = createProcessedNodeInputs(values2);
1865       toMoveInto = std::move(from);
1866       checkProcessedNodeInputs(toMoveInto, values);
1867     }
1868   }
1869 }
1870 
1871 TEST(StaticRuntime, isinstance) {
1872   const auto isinstance_int_script = R"JIT(
1873     def forward(self, a: Any):
1874         return isinstance(a, int)
1875   )JIT";
1876 
1877   const auto isinstance_tensor_script = R"JIT(
1878     def forward(self, a: Any):
1879         return isinstance(a, torch.Tensor)
1880   )JIT";
1881 
1882   const auto isinstance_many_types_script = R"JIT(
1883     def forward(self, a: Any):
1884         return isinstance(a, (bool, int))
1885   )JIT";
1886 
1887   auto a = at::randn({2, 2});
1888   auto b = at::randn({2, 2, 2});
1889 
1890   std::vector<at::IValue> args{a};
1891   std::vector<at::IValue> args2{b};
1892 
1893   testStaticRuntime(isinstance_int_script, args);
1894   testStaticRuntime(isinstance_int_script, args, args2);
1895 
1896   testStaticRuntime(isinstance_tensor_script, args);
1897   testStaticRuntime(isinstance_tensor_script, args, args2);
1898 
1899   testStaticRuntime(isinstance_many_types_script, args);
1900   testStaticRuntime(isinstance_many_types_script, args, args2);
1901 }
1902 
1903 TEST(StaticRuntime, TypeCheck) {
1904   const auto typecheck_ir = R"IR(
1905   graph(%a.1 : Tensor,
1906         %b.1 : Tensor):
1907     %t0 : Float(2, 2, strides=[2, 1], device=cpu), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu), Float(3, 3, strides=[3, 1])]](%a.1, %b.1)
1908     return (%t0, %t1, %type_matched)
1909   )IR";
1910 
1911   auto a = at::zeros({2, 2}, at::kFloat);
1912   a = a.to(at::kCPU);  // Tensor::to is not in-place; keep the result.
1913   auto b = at::ones({3, 3}, at::kFloat);
1914   auto c = at::ones({2, 2, 2}, at::kFloat);
1915 
1916   std::vector<IValue> args_correct = {a, b};
1917   std::vector<IValue> args_incorrect = {a, c};
1918 
1919   testStaticRuntime(typecheck_ir, args_correct);
1920   testStaticRuntime(typecheck_ir, args_correct, args_incorrect);
1921 }
1922 
1923 TEST(StaticRuntime, Index) {
1924   const auto index_without_none_script = R"JIT(
1925     def forward(self, a: Tensor, idx: Tensor):
1926         return a[idx].clone()
1927   )JIT";
1928 
1929   // Index with boolean mask
1930   auto a = at::arange(4, at::kFloat).view({2, 2});
1931   auto idx_a = torch::tensor({{0, 1}, {0, 0}}, at::kBool);
1932   std::vector<IValue> args_a{a, idx_a};
1933 
1934   // Index with tensor
1935   auto b = at::arange(27, at::kFloat).view({3, 3, 3});
1936   auto idx_b = torch::tensor({0, 1, 2}, at::kLong);
1937   std::vector<IValue> args_b{b, idx_b};
1938 
1939   testStaticRuntime(index_without_none_script, args_a);
1940   testStaticRuntime(index_without_none_script, args_a, args_b);
1941 
1942   const auto index_with_none_script = R"JIT(
1943     def forward(self, a: Tensor, idx: Tensor, none: Optional[Tensor]):
1944         return a[idx, none].clone()
1945   )JIT";
1946 
1947   // Index with None
1948   // When indexing with None, the shape of `f` becomes [2, 1, 2],
1949   // so the mask must be reshaped appropriately.
1950   auto f = at::arange(4, at::kFloat).view({2, 1, 2});
1951   auto idx_f_reshape = torch::tensor({{{0, 1}}, {{0, 0}}}, at::kBool);
1952   std::vector<IValue> args_f_with_none{f, idx_f_reshape};
1953   args_f_with_none.emplace_back();
1954 
1955   testStaticRuntime(index_with_none_script, args_f_with_none);
1956   testStaticRuntime(
1957       index_with_none_script,
1958       args_f_with_none,
1959       {IValue(b), IValue(idx_b), IValue()});
1960 
1961   const auto index_with_two_tensors_script = R"JIT(
1962     def forward(self, a: Tensor, idx_a: Tensor, idx_b: Tensor):
1963         return a[idx_a, idx_b].clone()
1964   )JIT";
1965 
1966   // Index with multiple tensors
1967   const auto& c = a; // 2x2 tensor
1968   auto idx_c1 = torch::tensor({0, 0}, at::kLong);
1969   auto idx_c2 = torch::tensor({0}, at::kLong);
1970   std::vector<IValue> args_c{c, idx_c1, idx_c2};
1971 
1972   const auto& d = b; // 3x3x3 tensor
1973   auto idx_d1 = torch::tensor({{0, 0, 2}, {0, 1, 1}}, at::kLong);
1974   auto idx_d2 = torch::tensor({{1, 1, 0}, {1, 0, 2}}, at::kLong);
1975   std::vector<IValue> args_d{d, idx_d1, idx_d2};
1976 
1977   testStaticRuntime(index_with_two_tensors_script, args_c, args_d);
1978 }
1979 
1980 TEST(StaticRuntime, IndexSelect) {
1981   const std::string script = R"IR(
1982     graph(%self: Tensor, %dim: int, %index: Tensor):
1983         %bias: None = prim::Constant()
1984         %ret = aten::index_select(%self, %dim, %index)
1985         %cloned = aten::clone(%ret, %bias)
1986         return (%cloned)
1987   )IR";
1988 
1989   auto self0 = at::rand({6});
1990   auto dim0 = 0;
1991   auto index0 = at::randint(0, 5, {6}, torch::kInt32);
1992   std::vector<IValue> args{self0, dim0, index0};
1993   testStaticRuntime(script, args);
1994 
1995   auto self1 = at::rand({128});
1996   auto dim1 = 0;
1997   auto index1 = at::randint(0, 127, {127}, torch::kInt32);
1998   std::vector<IValue> args2{self1, dim1, index1};
1999   testStaticRuntime(script, args, args2);
2000 }
2001 
2002 TEST(StaticRuntime, ClampMin) {
2003   const auto clamp_min_int_script = R"JIT(
2004     def forward(self, a: Tensor, b: int):
2005         return torch.clamp_min(a, b).clone()
2006   )JIT";
2007 
2008   const auto clamp_min_float_script = R"JIT(
2009     def forward(self, a: Tensor, b: float):
2010         return torch.clamp_min(a, b).clone()
2011   )JIT";
2012 
2013   auto a = at::randn({2, 2});
2014   auto b = at::randn({3, 3, 3});
2015   int scalar_int = 1;
2016   float scalar_float = 3.14f;
2017 
2018   std::vector<IValue> args_a_int{a, scalar_int};
2019   std::vector<IValue> args_b_int{b, scalar_int};
2020 
2021   testStaticRuntime(clamp_min_int_script, args_a_int);
2022   testStaticRuntime(clamp_min_int_script, args_a_int, args_b_int);
2023 
2024   std::vector<IValue> args_a_float{a, scalar_float};
2025   std::vector<IValue> args_b_float{b, scalar_float};
2026 
2027   testStaticRuntime(clamp_min_float_script, args_a_float);
2028   testStaticRuntime(clamp_min_float_script, args_a_float, args_b_float);
2029 }
2030 
2031 TEST(StaticRuntime, Argmin) {
2032   const auto argmin_script = R"JIT(
2033     def forward(self, a: Tensor):
2034         return torch.argmin(a).clone()
2035   )JIT";
2036 
2037   const auto argmin_with_dim_script = R"JIT(
2038     def forward(self, a: Tensor, dim: int):
2039         return torch.argmin(a, dim).clone()
2040   )JIT";
2041 
2042   const auto argmin_with_keep_dim_script = R"JIT(
2043     def forward(self, a: Tensor, dim: int):
2044         return torch.argmin(a, dim, True).clone()
2045   )JIT";
2046 
2047   auto a = at::randn({2, 2});
2048   auto b = at::randn({17, 2, 1});
2049 
2050   testStaticRuntime(argmin_script, {a});
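  // The overload without `dim` reduces to a 0-dim tensor, so there is no
  // output resizing to verify when switching to the larger input.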
2051   testStaticRuntime(
2052       argmin_script,
2053       {a},
2054       {b},
2055       /* use_allclose */ false,
2056       /* use_equalnan */ false,
2057       /* check_resize */ false);
2058 
2059   int dim_a = 0;
2060   int dim_b = 1;
2061 
2062   std::vector<IValue> args_a{a, dim_a};
2063   std::vector<IValue> args_b{b, dim_b};
2064 
2065   testStaticRuntime(argmin_with_dim_script, args_a);
2066   testStaticRuntime(argmin_with_dim_script, args_a, args_b);
2067 
2068   testStaticRuntime(argmin_with_keep_dim_script, args_a);
2069   testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b);
2070 }
2071 
2072 TEST(StaticRuntime, Softmax) {
2073   const auto softmax_script = R"JIT(
2074     def forward(self, a: Tensor, dim: int):
2075         return torch.softmax(a, dim).clone()
2076   )JIT";
2077 
2078   const auto softmax_script_with_dtype = R"JIT(
2079     def forward(self, a: Tensor, dim: int, dtype: int):
2080         return torch.softmax(a, dim, dtype=dtype).clone()
2081   )JIT";
2082 
2083   auto a = at::randn({2, 3});
2084   auto b = at::randn({3, 3, 3});
2085 
2086   testStaticRuntime(softmax_script, {a, 0});
2087   testStaticRuntime(softmax_script, {a, 1});
2088 
2089   testStaticRuntime(softmax_script, {b, 0});
2090   testStaticRuntime(softmax_script, {b, 1});
2091   testStaticRuntime(softmax_script, {b, 2});
2092 
2093   testStaticRuntime(softmax_script_with_dtype, {a, 1, at::ScalarType::Float});
2094   testStaticRuntime(softmax_script_with_dtype, {b, 1, at::ScalarType::Float});
2095 }
2096 
2097 TEST(StaticRuntime, GetItem_Dict) {
2098   const auto getitem_dict_tensor_script = R"JIT(
2099     def forward(self, key: Tensor):
2100         d = {key: 1}
2101         return d[key]
2102   )JIT";
2103 
2104   const auto getitem_dict_int_script = R"JIT(
2105     def forward(self, key: int):
2106         d = {key: 1}
2107         return d[key]
2108   )JIT";
2109 
2110   const auto getitem_dict_str_script = R"JIT(
2111     def forward(self, key: str):
2112         d = {key: 1}
2113         return d[key]
2114   )JIT";
2115 
2116   int int_key = 0;
2117   std::string str_key = "str";
2118 
2119   // No need to test these multiple times since the args are not tensors.
2120   testStaticRuntime(getitem_dict_int_script, {int_key});
2121   testStaticRuntime(getitem_dict_str_script, {str_key});
2122 
2123   auto a = torch::tensor({1});
2124   auto b = torch::tensor({1, 1});
2125 
2126   testStaticRuntime(getitem_dict_tensor_script, {a});
2127   testStaticRuntime(getitem_dict_tensor_script, {a}, {b});
2128 }
2129 
2130 TEST(StaticRuntime, GetItem_List) {
2131   const auto getitem_list_int_script = R"JIT(
2132     def forward(self, idx: int):
2133         lst = [1, 2, 3]
2134         return lst[idx]
2135   )JIT";
2136 
2137   const auto getitem_list_tensor_script = R"JIT(
2138     def forward(self, tensor: Tensor, idx: int):
2139         lst = [tensor, tensor]
2140         return lst[idx]
2141   )JIT";
2142 
2143   testStaticRuntime(getitem_list_int_script, {1});
2144   testStaticRuntime(getitem_list_int_script, {-1});
2145 
2146   auto a = torch::tensor({1});
2147   auto b = torch::tensor({1, 1});
2148 
2149   testStaticRuntime(getitem_list_tensor_script, {a, 1});
2150   testStaticRuntime(getitem_list_tensor_script, {a, 1}, {b, -1});
2151 }
2152 
2153 TEST(StaticRuntime, Transpose) {
2154   const auto transpose_script = R"JIT(
2155     def forward(self, a: Tensor, dim1: int, dim2: int):
2156         return torch.transpose(a, dim1, dim2).clone()
2157   )JIT";
2158 
2159   auto a = at::randn({2, 2});
2160   int dim1_a = 0;
2161   int dim2_a = 1;
2162   std::vector<IValue> args_a{a, dim1_a, dim2_a};
2163 
2164   auto b = at::randn({3, 3, 3});
2165   int dim1_b = 0;
2166   int dim2_b = 2;
2167   std::vector<IValue> args_b{b, dim1_b, dim2_b};
2168 
2169   testStaticRuntime(transpose_script, args_a);
2170   testStaticRuntime(transpose_script, args_a, args_b);
2171 }
2172 
2173 TEST(StaticRuntime, Permute) {
2174   auto permute_script = R"JIT(
2175     def forward(self, a: Tensor, dims: List[int]):
2176         return torch.permute(a, dims).clone()
2177   )JIT";
2178 
2179   auto a = at::randn({2, 2});
2180   c10::List<int64_t> dims_a{1, 0};
2181   std::vector<IValue> args_a{a, dims_a};
2182 
2183   auto b = at::randn({3, 3, 3});
2184   c10::List<int64_t> dims_b{0, 2, 1};
2185   std::vector<IValue> args_b{b, dims_b};
2186 
2187   testStaticRuntime(permute_script, args_a);
2188   testStaticRuntime(permute_script, args_a, args_b);
2189 
2190   permute_script = R"JIT(
2191     def forward(self, a: Tensor, dims: List[int], shape: List[int]):
2192         return torch.permute(a, dims).reshape(shape).clone()
2193   )JIT";
2194 
2195   a = at::randn({8, 16, 4});
2196   dims_a = {0, 2, 1};
2197   dims_b = {-1, 16};  // reused here as the reshape target shape
2198   testStaticRuntime(permute_script, {a, dims_a, dims_b});
2199 }
2200 
2201 TEST(StaticRuntime, Slice) {
2202   const auto slice_script = R"JIT(
2203     def forward(self, a: Tensor, dim: int, start: int, end: int, step: int):
2204       return a.slice(dim, start, end, step).clone()
2205   )JIT";
2206 
2207   auto a = at::randn({2, 2});
2208   int dim_a = 1;
2209   int start_a = 0;
2210   int end_a = 1;
2211   int step_a = 1;
2212   std::vector<IValue> args_a{a, dim_a, start_a, end_a, step_a};
2213 
2214   auto b = at::randn({3, 3, 3});
2215   int dim_b = 2;
2216   int start_b = 0;
2217   int end_b = 1;
2218   int step_b = 2;
2219   std::vector<IValue> args_b{b, dim_b, start_b, end_b, step_b};
2220 
2221   testStaticRuntime(slice_script, args_a);
2222   testStaticRuntime(slice_script, args_a, args_b);
2223 
2224   const auto slice_script2 = R"JIT(
2225     def forward(self, a: Tensor, dim: int, step: int):
2226       return a.slice(dim, None, None, step).clone()
2227   )JIT";
2228   std::vector<IValue> args_c{b, dim_b, step_b};
2229   testStaticRuntime(slice_script2, args_c);
2230 }
2231 
2232 TEST(StaticRuntime, Narrow) {
2233   const auto narrow_with_int_script = R"JIT(
2234     def forward(self, a: Tensor, dim: int, start: int, length: int):
2235         return a.narrow(dim, start, length).clone()
2236   )JIT";
2237 
2238   auto a = at::randn({5, 5});
2239   int dim_a = 0;
2240   int start_a_int = 3;
2241   int len_a = 2;
2242   std::vector<IValue> args_a{a, dim_a, start_a_int, len_a};
2243 
2244   auto b = at::randn({5, 5, 5});
2245   int dim_b = 1;
2246   int start_b_int = 2;
2247   int len_b = 3;
2248   std::vector<IValue> args_b{b, dim_b, start_b_int, len_b};
2249 
2250   testStaticRuntime(narrow_with_int_script, args_a);
2251   testStaticRuntime(narrow_with_int_script, args_a, args_b);
2252 }
2253 
2254 TEST(StaticRuntime, TupleUnpack) {
2255   const auto two_tuple_unpack_script = R"JIT(
2256     def forward(self, tup: Tuple[Tensor, Tensor]):
2257         a, b = tup
2258         return (a, b)
2259   )JIT";
2260 
2261   const auto three_tuple_unpack_script = R"JIT(
2262     def forward(self, tup: Tuple[Tensor, Tensor, Tensor]):
2263         a, b, c = tup
2264         return (a, b, c)
2265   )JIT";
2266 
2267   auto two_tup = c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})});
2268   auto two_tup_large =
2269       c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})});
2270 
2271   auto three_tup = c10::ivalue::Tuple::create(
2272       {at::randn({1}), at::randn({1}), at::randn({1})});
2273   auto three_tup_large = c10::ivalue::Tuple::create(
2274       {at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})});
2275 
2276   testStaticRuntime(two_tuple_unpack_script, {two_tup});
2277   testStaticRuntime(two_tuple_unpack_script, {two_tup}, {two_tup_large});
2278 
2279   testStaticRuntime(three_tuple_unpack_script, {three_tup});
2280   testStaticRuntime(three_tuple_unpack_script, {three_tup}, {three_tup_large});
2281 }
2282 
2283 TEST(StaticRuntime, Append) {
2284   const auto append_int_script = R"JIT(
2285     def forward(self, a: int):
2286         lst = [1, 2, 3]
2287         lst.append(a)
2288         return lst
2289   )JIT";
2290 
2291   const auto append_tensor_script = R"JIT(
2292     def forward(self, a: Tensor):
2293         lst = []
2294         lst.append(a)
2295         return lst
2296   )JIT";
2297 
2298   std::vector<IValue> args_int{1};
2299 
2300   testStaticRuntime(append_int_script, args_int);
2301 
2302   std::vector<IValue> args_tensor{at::randn({1})};
2303   std::vector<IValue> args_tensor_large{at::randn({2, 2})};
2304 
2305   testStaticRuntime(append_tensor_script, args_tensor);
2306   testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large);
2307 }
2308 
2309 TEST(StaticRuntime, QuantizedLinear) {
2310   const std::string quantize_script = R"IR(
2311     graph(%input: Tensor, %weights: Tensor):
2312         %scale: float = prim::Constant[value=1.]()
2313         %zero_point: int = prim::Constant[value=1]()
2314         %bias: None = prim::Constant()
2315         %packed_params = quantized::linear_prepack(%weights, %bias)
2316         %1254 = quantized::linear(%input, %packed_params, %scale, %zero_point)
2317         %1249: Tensor = aten::dequantize(%1254)
2318         return (%1249)
2319   )IR";
2320   at::Tensor weight =
2321       at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);
2322   at::Tensor input =
2323       at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);
2324 
2325   at::Tensor weight_2 =
2326       at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8);
2327   at::Tensor input_2 =
2328       at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8);
2329 
2330   testStaticRuntime(quantize_script, {input, weight}, {input_2, weight_2});
2331 }
2332 
2333 TEST(StaticRuntime, QuantizedLinearDynamicFp16) {
2334   const std::string quantized_linear_dynamic_fp16_script = R"IR(
2335     graph(%input: Tensor, %weights: Tensor):
2336         %bias: None = prim::Constant()
2337         %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
2338         %output = quantized::linear_dynamic_fp16(%input, %packed_params)
2339         %ret = aten::clone(%output, %bias)
2340         return (%ret)
2341   )IR";
2342   at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
2343   at::Tensor input = torch::randn({3, 2}, torch::kFloat);
2344 
2345   at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
2346   at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
2347 
2348   testStaticRuntime(
2349       quantized_linear_dynamic_fp16_script,
2350       {input, weight},
2351       {input_2, weight_2});
2352 }
2353 
2354 TEST(StaticRuntime, QuantizedLinearReluDynamicFp16) {
2355   const std::string quantized_linear_relu_dynamic_fp16_script = R"IR(
2356     graph(%input: Tensor, %weights: Tensor):
2357         %bias: None = prim::Constant()
2358         %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
2359         %output = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
2360         %ret = aten::clone(%output, %bias)
2361         return (%ret)
2362   )IR";
2363   at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
2364   at::Tensor input = torch::randn({3, 2}, torch::kFloat);
2365 
2366   at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
2367   at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
2368 
2369   testStaticRuntime(
2370       quantized_linear_relu_dynamic_fp16_script,
2371       {input, weight},
2372       {input_2, weight_2});
2373 }
2374 
2375 TEST(StaticRuntime, VarStack) {
2376   const auto var_stack_script = R"JIT(
2377     def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
2378         return torch.stack([inp1, inp2], dim).clone()
2379   )JIT";
2380 
2381   // 2D tensors - stack dim = 0
2382   std::vector<IValue> args1 = {at::randn({6, 6}), at::randn({6, 6}), 0};
2383   testStaticRuntime(var_stack_script, args1);
2384 
2385   // 3D tensors - stack dim = 1
2386   std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 1};
2387   testStaticRuntime(var_stack_script, args2);
2388 
2389   // 3D tensors - stack dim = 2
2390   std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 2};
2391   testStaticRuntime(var_stack_script, args3);
2392 
2393   // Negative dim
2394   std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), -1};
2395   testStaticRuntime(var_stack_script, args4);
2396 
2397   // Non-serial path
2398   std::vector<IValue> args5 = {at::randn({1, 2, 3}), at::randn({1, 2, 3}), 3};
2399   testStaticRuntime(var_stack_script, args5);
2400 
2401   // Fast path
2402   std::vector<IValue> args6 = {at::randn({1}), at::randn({1}), 0};
2403   testStaticRuntime(var_stack_script, args6);
2404 
2405   testStaticRuntime(var_stack_script, args1, args2);
2406 }
2407 
2408 TEST(StaticRuntime, FmodTensor) {
2409   const auto fmod_tensor = R"JIT(
2410     def forward(self, a: Tensor, b: Tensor):
2411         return torch.fmod(a, b).clone()
2412   )JIT";
2413 
2414   // fmod tensor version
2415   auto a = at::randn({2, 3});
2416   auto b = at::randn({2, 3});
2417   std::vector<IValue> args0{a, b};
2418   testStaticRuntime(fmod_tensor, args0);
2419 
2420   // check for dynamic shapes
2421   auto c = at::randn({4, 3, 2});
2422   auto d = at::randn({4, 3, 2});
2423   std::vector<IValue> args1{c, d};
2424   testStaticRuntime(fmod_tensor, args0, args1);
2425 }
2426 
2427 TEST(StaticRuntime, FmodScalar) {
2428   const auto fmod_scalar = R"JIT(
2429     def forward(self, a: Tensor, b: int):
2430         return torch.fmod(a, b).clone()
2431   )JIT";
2432 
2433   auto a = at::randn({2, 3});
2434 
2435   // fmod scalar version
2436   std::vector<IValue> args2{a, 3};
2437   testStaticRuntime(fmod_scalar, args2);
2438 
2439   // check for dynamic shapes
2440   auto c = at::randn({4, 3, 2});
2441   std::vector<IValue> args3{c, 4};
2442   testStaticRuntime(fmod_scalar, args2, args3);
2443 
2444   // test int32 version
2445   a = at::randint(-100, 100, {2, 3}, at::kInt);
2446   c = at::randint(-100, 100, {4, 3, 2}, at::kInt);
2447   testStaticRuntime(fmod_scalar, {a, 3});
2448   testStaticRuntime(fmod_scalar, {a, 3}, {c, 4});
2449 }
2450 
2451 TEST(StaticRuntime, QEmbeddingBagBytePrepack) {
2452   const std::string embedding_bag_byte_prepack_script = R"IR(
2453     graph(%input: Tensor):
2454         %none : None = prim::Constant()
2455         %output: Tensor = quantized::embedding_bag_byte_prepack(%input)
2456         %res: Tensor = aten::clone(%output, %none)
2457         return (%res)
2458   )IR";
2459 
2460   auto a = torch::randn({8, 16}, at::ScalarType::Float);
2461   auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);
2462 
2463   testStaticRuntime(embedding_bag_byte_prepack_script, {a});
2464   testStaticRuntime(embedding_bag_byte_prepack_script, {a}, {b});
2465 }
2466 
2467 TEST(StaticRuntime, QEmbeddingBagByteUnpack) {
2468   const auto src = R"IR(
2469     graph(%input: Tensor):
2470         %none : None = prim::Constant()
2471         %weight: Tensor = quantized::embedding_bag_byte_prepack(%input)
2472         %output: Tensor = quantized::embedding_bag_byte_unpack(%weight)
2473         %res: Tensor = aten::clone(%output, %none)
2474         return (%res)
2475   )IR";
2476 
2477   auto a = torch::randn({8, 16}, at::ScalarType::Float);
2478   auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);
2479 
2480   testStaticRuntime(src, {a});
2481   testStaticRuntime(src, {a}, {b});
2482 }
2483 
2484 TEST(StaticRuntime, LinalgNorm_ScalarOrd) {
2485   const auto linalg_norm_ord_scalar = R"JIT(
2486     def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int):
2487         return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
2488   )JIT";
2489 
2490   auto a = at::randn({2, 3});
2491   auto dim = std::vector<int64_t>({1});
2492   auto dtype = at::ScalarType::Float;
2493 
2494   std::vector<IValue> args0{a, 4, dim, true, dtype};
2495   testStaticRuntime(linalg_norm_ord_scalar, args0);
2496 
2497   auto b = at::randn({3, 2, 6});
2498   std::vector<IValue> args1{b, 4, dim, true, dtype};
2499   testStaticRuntime(linalg_norm_ord_scalar, args0, args1);
2500 }
2501 
2502 TEST(StaticRuntime, LinalgNorm_StringOrd) {
2503   const auto linalg_norm_ord_str = R"JIT(
2504     def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
2505         return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
2506   )JIT";
2507 
2508   auto a = at::randn({2, 3});
2509   auto dim = std::vector<int64_t>({0, 1});
2510   auto dtype = at::ScalarType::Float;
2511 
2512   std::vector<IValue> args0{a, "fro", dim, true, dtype};
2513   testStaticRuntime(linalg_norm_ord_str, args0);
2514 
2515   auto b = at::randn({3, 2, 17});
2516   std::vector<IValue> args1{b, "fro", dim, true, dtype};
2517   testStaticRuntime(linalg_norm_ord_str, args0, args1);
2518 }
2519 
2520 TEST(StaticRuntime, Index_Put) {
2521   const auto index_put_str = R"JIT(
2522     def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
2523         return torch.index_put(a, indices, values, accumulate).clone()
2524   )JIT";
2525 
2526   auto a = at::randn({2});
2527   auto indices_a = std::make_tuple(torch::tensor({0}, at::kLong));
2528   auto values_a = at::randn({1});
2529 
2530   std::vector<IValue> args0{a, indices_a, values_a, false};
2531   testStaticRuntime(index_put_str, args0);
2532 
2533   const auto index_put_non_optional_str = R"JIT(
2534     def forward(self, a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool):
2535         return torch.index_put(a, indices, values, accumulate).clone()
2536   )JIT";
2537 
2538   auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};
2539   std::vector<IValue> args1{a, indices_b, values_a, false};
2540   testStaticRuntime(index_put_non_optional_str, args1);
2541 
2542   const auto index_put_list_construct = R"JIT(
2543     def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool):
2544         indices: List[Optional[Tensor]] = [indices]
2545         return torch.index_put(a, indices, values, accumulate).clone()
2546   )JIT";
2547 
2548   std::vector<IValue> args2{a, torch::tensor({0}, at::kLong), values_a, false};
2549   testStaticRuntime(index_put_list_construct, args2);
2550 }
2551 
2552 TEST(StaticRuntime, Item) {
2553   const auto item_str = R"JIT(
2554     def forward(self, a: Tensor):
2555         return torch.item(a)
2556   )JIT";
2557 
2558   auto a = at::randn({1});
2559 
2560   std::vector<IValue> args0{a};
2561   testStaticRuntime(item_str, args0);
2562 }
2563 
2564 TEST(StaticRuntime, Tensor_Split) {
2565   const auto tensor_split_str1 = R"JIT(
2566     def forward(self, a: Tensor, sections: int, dim: int):
2567         return torch.tensor_split(a, sections, dim)
2568   )JIT";
2569   std::vector<IValue> args1{at::randn({8}), 3, 0};
2570 
2571   const auto tensor_split_str2 = R"JIT(
2572     def forward(self, a: Tensor, sections: Tensor, dim: int):
2573         return torch.tensor_split(a, sections, dim)
2574   )JIT";
2575   std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};
2576 
2577   const auto tensor_split_str3 = R"JIT(
2578     def forward(self, a: Tensor, indices: List[int], dim: int):
2579         return torch.tensor_split(a, indices, dim)
2580   )JIT";
2581   std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};
2582 
2583   testStaticRuntime(tensor_split_str1, args1);
2584   testStaticRuntime(tensor_split_str2, args2);
2585   testStaticRuntime(tensor_split_str3, args3);
2586 }
2587 
2588 TEST(StaticRuntime, JIT_Aten_Cpu) {
2589   const std::string script = R"IR(
2590     graph(%a: Tensor):
2591         %1 : int = prim::Constant[value=0]()
2592         %aa: Tensor = aten::add(%a, %a, %1)
2593         %ret: Tensor = aten::cpu(%aa)
2594         return (%ret)
2595   )IR";
2596 
2597   auto graph = std::make_shared<Graph>();
2598   std::unordered_map<std::string, Value*> vmap;
2599   vmap.reserve(0);
2600   parseIR(script, graph.get(), vmap);
2601   torch::jit::StaticModule smodule(graph);
2602 
2603   auto a = at::randn({2, 4});
2604   std::vector<IValue> args0{a};
2605 
2606   testStaticRuntime(script, args0);
2607 }
2608 
2609 TEST(StaticRuntime, JIT_Aten_Numel) {
2610   const std::string script = R"IR(
2611     graph(%a: Tensor):
2612         %1 : int = prim::Constant[value=0]()
2613         %aa: Tensor = aten::add(%a, %a, %1)
2614         %ret: int = aten::numel(%aa)
2615         return (%ret)
2616   )IR";
2617 
2618   auto graph = std::make_shared<Graph>();
2619   std::unordered_map<std::string, Value*> vmap;
2620   vmap.reserve(0);
2621   parseIR(script, graph.get(), vmap);
2622   torch::jit::StaticModule smodule(graph);
2623 
2624   auto a = at::randn({2, 4});
2625   std::vector<IValue> args0{a};
2626 
2627   testStaticRuntime(script, args0);
2628 }
2629 
2630 TEST(StaticRuntime, JIT_Aten_List) {
2631   const auto script_str = R"IR(
2632     graph(%a: str):
2633         %ret: str[] = aten::list(%a)
2634         return (%ret)
2635   )IR";
2636   std::string a = "abcd";
2637   std::vector<IValue> args0{a};
2638   testStaticRuntime(script_str, args0);
2639 
2640   // Update the result of aten::list to ensure that a deep copy
2641   // took place
2642   const auto script_list = R"IR(
2643     graph(%a : int[]):
2644         %idx : int = prim::Constant[value=0]()
2645         %value : int = prim::Constant[value=42]()
2646         %res : int[] = aten::list(%a)
2647         %updated : int[] = aten::_set_item(%res, %idx, %value)
2648         return (%res, %a)
2649   )IR";
2650 
2651   std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
2652   testStaticRuntime(script_list, args1);
2653 }
2654 
2655 TEST(StaticRuntime, JIT_Aten_Range_Length) {
2656   const std::string script = R"IR(
2657     graph(%lo: int, %hi: int, %step: int):
2658         %1 : int = prim::Constant[value=0]()
2659         %ret: int = aten::__range_length(%lo, %hi, %step)
2660         return (%ret)
2661   )IR";
2662 
2663   auto graph = std::make_shared<Graph>();
2664   std::unordered_map<std::string, Value*> vmap;
2665   vmap.reserve(0);
2666   parseIR(script, graph.get(), vmap);
2667   torch::jit::StaticModule smodule(graph);
2668 
2669   std::vector<IValue> args0{0, 10, 2};
2670 
2671   testStaticRuntime(script, args0);
2672 }
2673 
2674 TEST(StaticRuntime, Cat) {
2675   const std::string cat_script = R"IR(
2676     graph(%a: Tensor, %b: Tensor, %dim: int):
2677         %ten_list: Tensor[] = prim::ListConstruct(%a, %b)
2678         %1 : int = prim::Constant[value=0]()
2679         %2 : int = prim::Constant[value=1]()
2680         %3 : int = prim::Constant[value=1]()
2681         %ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3)
2682         %ret: Tensor = aten::cat(%ten_list2, %dim)
2683         return (%ret)
2684   )IR";
2685 
2686   auto graph = std::make_shared<Graph>();
2687   std::unordered_map<std::string, Value*> vmap;
2688   parseIR(cat_script, graph.get(), vmap);
2689   torch::jit::StaticModule smodule(graph);
2690   ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat"));
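  // The list is sliced before the cat, so the variadic-cat rewrite
  // presumably cannot fire; assert that a plain aten::cat survives.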
2691 
2692   auto a = at::randn({2, 4});
2693   auto b = at::randn({3, 4});
2694   std::vector<IValue> args0{a, b, 0};
2695 
2696   testStaticRuntime(cat_script, args0);
2697 
2698   auto c = at::randn({3, 4});
2699   auto d = at::randn({3, 5});
2700   std::vector<IValue> args1{c, d, 1};
2701   testStaticRuntime(cat_script, args0, args1);
2702 
2703   std::vector<IValue> args_dim_negative{c, d, -1};
2704   testStaticRuntime(cat_script, args_dim_negative);
2705 }
2706 
2707 TEST(StaticRuntime, Cumsum) {
2708   const auto cumsum_script = R"JIT(
2709     def forward(self, a: Tensor, dim: int):
2710         return torch.cumsum(a, dim).clone()
2711   )JIT";
2712 
2713   auto a = at::randn({2, 3});
2714   std::vector<IValue> args0{a, 0};
2715   testStaticRuntime(cumsum_script, args0);
2716 
2717   auto b = at::randn({3, 6});
2718   std::vector<IValue> args1{b, 1};
2719   testStaticRuntime(cumsum_script, args0, args1);
2720 }
2721 
2722 TEST(StaticRuntime, CumsumDtype) {
2723   const auto cumsum_script_dtype = R"JIT(
2724     def forward(self, a: Tensor, dim: int, dtype: int):
2725         return torch.cumsum(a, dim, dtype=dtype).clone()
2726   )JIT";
2727 
2728   auto a = at::randn({1, 2});
2729   auto dtype = at::ScalarType::Float;
2730   std::vector<IValue> args0{a, 0, dtype};
2731   testStaticRuntime(cumsum_script_dtype, args0);
2732 
2733   auto b = at::randn({3, 6});
2734   std::vector<IValue> args1{b, 1, dtype};
2735   testStaticRuntime(cumsum_script_dtype, args0, args1);
2736 }
2737 
2738 TEST(StaticRuntime, Nonzero) {
2739   const auto nonzero_tensor = R"JIT(
2740     def forward(self, input: Tensor):
2741         a = torch.nonzero(input).clone()
2742         return (a)
2743   )JIT";
2744 
2745   auto a = at::randint(0, 2, {2, 3});
2746   testStaticRuntime(nonzero_tensor, {a});
2747 
2748   auto b = at::randint(0, 2, {4, 3, 2});
2749   testStaticRuntime(nonzero_tensor, {a}, {b});
2750 }
2751 
2752 TEST(StaticRuntime, SignedLog1p) {
2753   const std::string signed_log1p_script = R"IR(
2754     graph(%input):
2755         %0 : Tensor = aten::sign(%input)
2756         %1 : Tensor = aten::abs(%input)
2757         %2 : Tensor = aten::log1p(%1)
2758         %3 : Tensor = aten::mul(%0, %2)
2759         %none : NoneType = prim::Constant()
2760         %res : Tensor = aten::clone(%3, %none)
2761         return (%res)
2762   )IR";
2763 
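  // sign(x) * log1p(|x|) is the pattern targeted by the signed_log1p
  // fusion; allclose comparison (the trailing `true` below) presumably
  // tolerates small numerical differences from the fused kernel.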
2764   std::vector<IValue> args1 = {at::randn({2, 2})};
2765   testStaticRuntime(signed_log1p_script, args1, {}, true);
2766 
2767   std::vector<IValue> args2 = {at::randn({3, 3, 3})};
2768   testStaticRuntime(signed_log1p_script, args1, args2, true);
2769 }
2770 
2771 TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithImmutableInputDict) {
2772   const auto getitem_immutable_input_dict_script = R"JIT(
2773     def forward(self, input: Dict[int, Tensor]):
2774         a = input[0]
2775         b = input[1]
2776         c = a + b
2777         return c.clone()
2778   )JIT";
2779 
2780   script::Module module("module");
2781   module.define(getitem_immutable_input_dict_script);
2782   torch::jit::StaticModule smodule(module);
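  // With a read-only input dict, both aten::__getitem__ calls should be
  // replaced by a single static_runtime::dict_unpack.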
2783   EXPECT_FALSE(hasNodeWithKind(smodule, "aten::__getitem__"));
2784   EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));
2785 
2786   auto a = at::randn({2, 4});
2787   auto b = at::randn({2, 4});
2788   c10::Dict<c10::IValue, c10::IValue> dict(
2789       c10::IntType::get(), c10::TensorType::get());
2790   dict.insert(0, a);
2791   dict.insert(1, b);
2792   testStaticRuntime(getitem_immutable_input_dict_script, {dict});
2793 
2794   c10::Dict<c10::IValue, c10::IValue> dict0(
2795       c10::IntType::get(), c10::TensorType::get());
2796   auto a0 = at::randn({3, 4});
2797   auto b0 = at::randn({3, 4});
2798   dict0.insert(0, a0);
2799   dict0.insert(1, b0);
2800   testStaticRuntime(getitem_immutable_input_dict_script, {dict0});
2801 }
2802 
2803 TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithMutableInputDict) {
2804   const auto getitem_mutable_input_dict_script = R"JIT(
2805     def forward(self, input: Dict[int, Tensor]):
2806         a = input[0]
2807         input[1] = a
2808         b = input[1]
2809         c = a + b
2810         return c.clone()
2811   )JIT";
2812 
2813   script::Module module("module");
2814   module.define(getitem_mutable_input_dict_script);
2815   torch::jit::StaticModule smodule(module);
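  // The graph writes to the input dict (`input[1] = a`), so the lookups
  // cannot be replaced and the aten::__getitem__ nodes must remain.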
2816   EXPECT_TRUE(hasNodeWithKind(smodule, "aten::__getitem__"));
2817   EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));
2818 }
2819 
2820 TEST(StaticRuntime, VarTupleUnpack) {
2821   const auto var_tuple_unpack_script = R"JIT(
2822     def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
2823         a, b = input_0
2824         c, d = input_1
2825         res = a * c + b * d
2826         return res.clone()
2827   )JIT";
2828 
2829   script::Module module("module");
2830   module.define(var_tuple_unpack_script);
2831   torch::jit::StaticModule smodule(module);
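  // The two adjacent prim::TupleUnpack nodes should be fused into a single
  // static_runtime::VarTupleUnpack; the _NotApplied test below checks the
  // adjacency requirement.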
2832   EXPECT_FALSE(hasNodeWithKind(smodule, "prim::TupleUnpack"));
2833   EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack"));
2834 
2835   auto a = at::randn({2, 2});
2836   auto b = at::randn({3, 3, 3});
2837   std::vector<IValue> args1{
2838       c10::ivalue::Tuple::create(a, a), c10::ivalue::Tuple::create(1, 2)};
2839   std::vector<IValue> args2{
2840       c10::ivalue::Tuple::create(b, b), c10::ivalue::Tuple::create(1, 2)};
2841 
2842   testStaticRuntime(var_tuple_unpack_script, args1);
2843   testStaticRuntime(var_tuple_unpack_script, args1, args2);
2844 }
2845 
2846 TEST(StaticRuntime, VarTupleUnpack_NotApplied) {
2847   const auto var_tuple_unpack_not_applied_script = R"JIT(
2848     def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
2849         a, b = input_0
2850         x = a + b
2851         c, d = input_1
2852         res = a * c + b * d + x
2853         return res.clone()
2854   )JIT";
2855 
2856   script::Module module("module");
2857   // In this script, the optimization is not applied since there is a
2858   // computation between the TupleUnpack nodes.
2859   module.define(var_tuple_unpack_not_applied_script);
2860   torch::jit::StaticModule smodule(module);
2861   EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack"));
2862   EXPECT_TRUE(hasNodeWithKind(smodule, "prim::TupleUnpack"));
2863 }
2864 
2865 TEST(StaticRuntime, RemainderTensor) {
2866   const auto remainder_tensor = R"JIT(
2867     def forward(self, x, y):
2868         return torch.remainder(x, y).clone()
2869   )JIT";
2870 
2871   std::vector<IValue> args1 = {
2872       at::randint(0, 10, {2, 2}), at::randint(1, 10, {2, 2})};
2873   std::vector<IValue> args2 = {
2874       at::randint(0, 10, {3, 6}), at::randint(1, 10, {3, 6})};
2875 
2876   // Use allclose and equalnan since outputs may be NaN.
2877   testStaticRuntime(
2878       remainder_tensor,
2879       args1,
2880       /*args2*/ {},
2881       /*use_allclose*/ true,
2882       /*use_equalnan*/ true);
2883   testStaticRuntime(
2884       remainder_tensor,
2885       args1,
2886       args2,
2887       /*use_allclose*/ true,
2888       /*use_equalnan*/ true);
2889 }
2890 
2891 TEST(StaticRuntime, RemainderScalar) {
2892   const auto remainder_scalar = R"JIT(
2893     def forward(self, x, y: int):
2894         return torch.remainder(x, y).clone()
2895   )JIT";
2896 
2897   std::vector<IValue> args1 = {at::randint(0, 10, {2, 2}), 4};
2898   std::vector<IValue> args2 = {at::randint(0, 10, {3, 6}), 4};
2899 
2900   // Use allclose and equalnan since outputs may be NaN.
2901   testStaticRuntime(
2902       remainder_scalar,
2903       args1,
2904       /*args2*/ {},
2905       /*use_allclose*/ true,
2906       /*use_equalnan*/ true);
2907   testStaticRuntime(
2908       remainder_scalar,
2909       args1,
2910       args2,
2911       /*use_allclose*/ true,
2912       /*use_equalnan*/ true);
2913 }
2914 
2915 TEST(StaticRuntime, Where) {
2916   const auto where_script = R"JIT(
2917     def forward(self, x, y):
2918         return torch.where(x > 0, x, y).clone()
2919   )JIT";
2920 
2921   std::vector<IValue> args1 = {at::randn({2, 2}), at::randn({2, 2})};
2922   std::vector<IValue> args2 = {at::randn({8, 10}), at::randn({8, 10})};
2923 
2924   testStaticRuntime(where_script, args1);
2925   testStaticRuntime(where_script, args1, args2);
2926 }
2927 
2928 TEST(StaticRuntime, WhereBroadcast) {
2929   const auto where_script = R"JIT(
2930     def forward(self, cond_1d, x, y):
2931         shape = [-1] + [1] * (x.dim() - 1)
2932         cond = cond_1d.view(shape)
2933         return torch.where(cond, x, y).clone()
2934   )JIT";
2935 
2936   std::vector<IValue> args1 = {
2937       at::tensor({0, 1}).to(at::kBool), at::randn({2, 2}), at::randn({2, 2})};
2938   std::vector<IValue> args2 = {
2939       at::tensor({1, 0, 0}).to(at::kBool),
2940       at::randn({3, 6}),
2941       at::randn({3, 6})};
2942 
2943   testStaticRuntime(where_script, args1);
2944   testStaticRuntime(where_script, args1, args2);
2945 }
2946 
2947 TEST(StaticRuntime, View) {
2948   // Note that clone is not technically necessary here since this is not
2949   // an out variant, but it suppresses warnings about only having one op
2950   // in testStaticRuntime.
2951   const auto src = R"IR(
2952     graph(%input : Tensor, %shape : int[]):
2953         %none : NoneType = prim::Constant()
2954         %view : Tensor = aten::view(%input, %shape)
2955         %res : Tensor = aten::clone(%view, %none)
2956         return (%res)
2957   )IR";
2958 
2959   std::vector<IValue> args1{at::randn({2, 2}), c10::List<int64_t>({4})};
2960   std::vector<IValue> args2{at::randn({2, 2, 2}), c10::List<int64_t>({4, 2})};
2961 
2962   testStaticRuntime(src, args1);
2963   testStaticRuntime(src, args1, args2);
2964 }
2965 
2966 TEST(StaticRuntime, Size) {
2967   const auto src_with_dim = R"JIT(
2968       def forward(self, x, dim: int):
2969           return x.size(dim)
2970   )JIT";
2971 
2972   const auto src_no_dim = R"JIT(
2973       def forward(self, x):
2974           return x.size()
2975   )JIT";
2976 
2977   std::vector<IValue> args1{at::randn({1}), 0};
2978   std::vector<IValue> args2{at::randn({1}), -1};
2979   std::vector<IValue> args3{at::randn({2, 4}), 1};
2980   std::vector<IValue> args_no_dim{at::randn({2, 4})};
2981 
2982   testStaticRuntime(src_with_dim, args1);
2983   testStaticRuntime(src_with_dim, args2);
2984   testStaticRuntime(src_with_dim, args1, args3);
2985   testStaticRuntime(src_no_dim, args_no_dim);
2986 }
2987 
2988 TEST(StaticRuntime, Squeeze) {
2989   // Note: this is a native op, not an out variant, but we clone anyway
2990   // to silence warnings in testStaticRuntime
2991   const auto src = R"JIT(
2992     def forward(self, inp, dim: int):
2993         return inp.squeeze(dim).clone()
2994   )JIT";
2995 
2996   const auto a = at::randn({2, 2});
2997   const auto b = at::randn({3, 2, 3});
2998 
2999   testStaticRuntime(src, {a, 0});
3000   testStaticRuntime(src, {a, 1});
3001   testStaticRuntime(src, {a, -1}, {b, 2});
3002 }
3003 
3004 TEST(StaticRuntime, NumToTensorScalar) {
3005   const auto num_to_tensor_ir = R"IR(
3006     graph(%1 : int):
3007       %2 : NoneType = prim::Constant()
3008       %3 : Tensor = prim::NumToTensor(%1)
3009       %4 : Tensor = aten::clone(%3, %2)
3010       return (%4)
3011   )IR";
3012 
3013   IValue arg{5};
3014   std::vector<IValue> args = {arg};
3015   testStaticRuntime(num_to_tensor_ir, args);
3016 }
3017 
3018 TEST(StaticRuntime, NumToTensorFalse) {
3019   const auto num_to_tensor_ir = R"IR(
3020     graph(%1 : bool):
3021       %2 : NoneType = prim::Constant()
3022       %3 : Tensor = prim::NumToTensor(%1)
3023       %4 : Tensor = aten::clone(%3, %2)
3024       return (%4)
3025   )IR";
3026 
3027   IValue arg{false};
3028   std::vector<IValue> args = {arg};
3029   testStaticRuntime(num_to_tensor_ir, args);
3030 }
3031 
3032 TEST(StaticRuntime, NumToTensorTrue) {
3033   const auto num_to_tensor_ir = R"IR(
3034     graph(%1 : bool):
3035       %2 : NoneType = prim::Constant()
3036       %3 : Tensor = prim::NumToTensor(%1)
3037       %4 : Tensor = aten::clone(%3, %2)
3038       return (%4)
3039   )IR";
3040 
3041   IValue arg{true};
3042   std::vector<IValue> args = {arg};
3043   testStaticRuntime(num_to_tensor_ir, args);
3044 }
3045 
3046 TEST(StaticRuntime, Split) {
3047   const auto src = R"JIT(
3048     def forward(self, inp, split_size: int, dim: int):
3049         return inp.split(split_size, dim)
3050   )JIT";
3051 
3052   const auto a = at::randn({2, 2});
3053   const auto b = at::randn({2, 2, 2});
3054 
3055   testStaticRuntime(src, {a, 1, 0});
3056   testStaticRuntime(src, {a, 1, 1});
3057   testStaticRuntime(src, {a, 2, -1}, {b, 2, 2});
3058 }
3059 
TEST(StaticRuntime,SplitWithSizes)3060 TEST(StaticRuntime, SplitWithSizes) {
3061   const auto src = R"JIT(
3062     def forward(self, inp, split_sizes: List[int], dim: int):
3063         return inp.split(split_sizes, dim)
3064   )JIT";
3065 
3066   const auto a = at::randn({2, 2});
3067   const auto b = at::randn({2, 2, 2});
3068   const auto split_sizes = c10::List<int64_t>{1, 1};
3069 
3070   testStaticRuntime(src, {a, split_sizes, 0});
3071   testStaticRuntime(src, {a, split_sizes, 1});
3072   testStaticRuntime(src, {a, split_sizes, -1}, {b, split_sizes, 2});
3073 }
3074 
namespace {

void maybe_throw(bool should_throw) {
  if (should_throw) {
    throw std::runtime_error("test exception");
  }
}

TORCH_LIBRARY(static_runtime_tests, m) {
  // Registered with a conservative alias analysis kind so this op
  // doesn't get deleted by dead code elimination
  m.def(torch::schema(
      "static_runtime_tests::maybe_throw(bool throw) -> ()",
      at::AliasAnalysisKind::CONSERVATIVE));
  m.impl("maybe_throw", maybe_throw);
}

} // namespace

TEST(StaticRuntime, ModelCrashOnFirstRun) {
  const auto src = R"JIT(
    graph(%0: Tensor, %throw: bool):
        %1: Tensor = aten::mul(%0, %0)
        static_runtime_tests::maybe_throw(%throw)
        %2: Tensor = aten::mul(%1, %1)
        %3: Tensor = aten::mul(%2, %2)
        return (%3)
  )JIT";

  auto graph = getGraphFromIR(src);
  auto static_module = StaticModule(graph);
  auto& runtime = static_module.runtime();

  std::vector<IValue> args_crash{at::randn({1}), true};
  std::vector<IValue> args_no_crash{at::randn({1}), false};
  EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);

  // The run didn't finish, so the memory planner was never allocated
  EXPECT_EQ(runtime.get_memory_planner(), nullptr);
  runtime.check_for_memory_leak();

  // We guarantee that the runtime is still usable after the crash.
  // Run again to verify this.
  compareResultsWithJIT(runtime, graph, args_no_crash);
  EXPECT_NE(runtime.get_memory_planner(), nullptr);
}

TEST(StaticRuntime, ModelCrashOnSecondRun) {
  const auto src = R"JIT(
    graph(%0: Tensor, %throw: bool):
        %1: Tensor = aten::mul(%0, %0)
        static_runtime_tests::maybe_throw(%throw)
        %2: Tensor = aten::mul(%1, %1)
        %3: Tensor = aten::mul(%2, %2)
        return (%3)
  )JIT";

  auto graph = getGraphFromIR(src);
  auto static_module = StaticModule(graph);
  auto& runtime = static_module.runtime();

  std::vector<IValue> args_crash{at::randn({1}), true};
  std::vector<IValue> args_no_crash{at::randn({1}), false};
  runtime(args_no_crash, {});
  EXPECT_NE(runtime.get_memory_planner(), nullptr);
  runtime.check_for_memory_leak();

  EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);
  runtime.check_for_memory_leak();

  // We guarantee that the runtime is still usable after the crash.
  // Run again to verify this.
  compareResultsWithJIT(runtime, graph, args_no_crash);
}

TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) {
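  // static_runtime::select_tensor emits a borrowed IValue; throwing before
  // the run finishes exercises the cleanup path for outstanding borrows.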
  const auto src = R"JIT(
    graph(%0: Tensor):
        %1: Tensor = aten::mul(%0, %0)
        %2: Tensor = aten::mul(%1, %1)
        %3: bool = prim::Constant[value=1]()
        %4: Tensor = static_runtime::select_tensor(%1, %2, %3)
        static_runtime_tests::maybe_throw(%3)
        return (%4)
  )JIT";
  auto graph = getGraphFromIR(src);
  auto static_module = StaticModule(graph);
  auto& runtime = static_module.runtime();

  std::vector<IValue> args{at::randn({1})};
  EXPECT_THROW(runtime(args), std::runtime_error);
}

TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) {
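  // Inputs passed by rvalue may be borrowed by the runtime rather than
  // copied; an exception mid-run must still release them cleanly.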
  const auto src = R"JIT(
    graph(%0: Tensor, %1: Tensor):
        %2: bool = prim::Constant[value=1]()
        %3: Tensor = static_runtime::select_tensor(%0, %1, %2)
        static_runtime_tests::maybe_throw(%2)
        return (%3)
  )JIT";
  auto graph = getGraphFromIR(src);
  auto static_module = StaticModule(graph);
  auto& runtime = static_module.runtime();

  std::vector<IValue> args{at::randn({1}), at::randn({1})};
  EXPECT_THROW(runtime(std::move(args)), std::runtime_error);
}

TEST(StaticRuntime, ReplaceWithMaybeCopy) {
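  // ReplaceWithMaybeCopy swaps aten::to for static_runtime::to_maybe_copy_out,
  // which only materializes a copy when the conversion actually needs one.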
  const std::string to = R"IR(
    graph(%0 : Tensor):
      %1: int = prim::Constant[value=4]()
      %2: bool = prim::Constant[value=0]()
      %3: None = prim::Constant()
      %res : Tensor = aten::to(%0, %1, %2, %2, %3)
      return (%res)
  )IR";

  at::Tensor a = at::tensor({1.1, 2.2, 3.3, 4.0}, at::ScalarType::Float);
  std::vector<IValue> args{a};
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(to, g.get());

  // JIT interpreter.
  Stack stack(args);
  torch::jit::GraphExecutor graph_exec(g, "");
  graph_exec.run(stack);
  ASSERT_EQ(stack.size(), 1);
  auto expected = stack[0].toTensor();

  // Static runtime.
  torch::jit::StaticModule smodule(g);
  auto actual = smodule(args, {}).toTensor();
  smodule.runtime().check_for_memory_leak();

  EXPECT_TRUE(expected.equal(actual));

  // Make a fresh graph to ensure the pass works in isolation
  auto new_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(to, new_graph.get());
  ReplaceWithMaybeCopy(new_graph);
  EXPECT_FALSE(hasNodeWithKind(new_graph, "aten::to"));
  EXPECT_TRUE(
      hasNodeWithKind(new_graph, "static_runtime::to_maybe_copy_out"));
}

TEST(StaticRuntime, Int) {
  const auto src = R"JIT(
    def forward(self, x):
        return int(x) + int(x)
  )JIT";
  std::vector<IValue> args{at::tensor({3.14})};
  testStaticRuntime(src, args);
}

TEST(StaticRuntime, ReturnConstant) {
  const auto src = R"JIT(
    def forward(self):
        return 1
  )JIT";

  testStaticRuntime(src, {});
}

TEST(StaticRuntime, SimpleIf) {
  const auto src = R"JIT(
    def forward(self, cond: bool, x):
        if cond:
            return torch.mul(x, 42).clone()
        else:
            return x.clone()
  )JIT";

  std::vector<IValue> args_false{false, at::randn({1})};
  std::vector<IValue> args_true{true, at::randn({1})};
  std::vector<IValue> args_big_tensor{true, at::randn({3, 3, 3})};

  testStaticRuntime(src, args_false);
  testStaticRuntime(src, args_true);
  testStaticRuntime(src, args_true, args_big_tensor);
}

TEST(StaticRuntime, NestedIf) {
  const auto src = R"JIT(
    def forward(self, cond1: bool, cond2: bool, x):
        y = x * 42
        if cond1:
            y = y * y
            if cond2:
                y += x
        else:
            if cond2:
                return x.clone()

        return y.clone()
  )JIT";

  for (auto cond1 : {true, false}) {
    for (auto cond2 : {true, false}) {
      std::vector<IValue> args1{cond1, cond2, at::randn({1})};
      std::vector<IValue> args2{cond1, cond2, at::randn({3, 3, 3})};
      testStaticRuntime(src, args1, args2);
    }
  }
}

TEST(StaticRuntime, DeeplyNestedIf) {
  const auto src = R"JIT(
    def forward(self, cond1: bool, cond2: bool, cond3: bool, x):
        y = x * 42
        if cond1:
            y = y * y
            if cond2:
                y += x

            if cond2 and cond3:
                y += 1

            if cond2:
                if cond3:
                    y += 2
                else:
                    y = y * y
                    y += 4
        else:
            if cond2:
                return x.clone()
            if cond3 or cond2:
                y += 42

        return y.clone()
  )JIT";

  for (auto cond1 : {true, false}) {
    for (auto cond2 : {true, false}) {
      for (auto cond3 : {true, false}) {
        std::vector<IValue> args1{cond1, cond2, cond3, at::randn({1})};
        std::vector<IValue> args2{cond1, cond2, cond3, at::randn({3, 3, 3})};
        testStaticRuntime(src, args1, args2);
      }
    }
  }
}

TEST(StaticRuntime, BasicForLoop) {
  const auto src = R"JIT(
    def forward(self, x, loop_max: int):
        y = x.clone()
        for i in range(loop_max):
            y += 1
        return y
  )JIT";

  std::vector<IValue> args1{at::randn({1}), 10};
  std::vector<IValue> args2{at::randn({3, 3, 3}), 10};

  testStaticRuntime(src, args1, args2);
}

TEST(StaticRuntime, BasicWhileLoop) {
  const auto src = R"JIT(
    def forward(self, x, loop_max: int):
        y = x.clone()
        loop_count = 0
        while loop_count < loop_max:
            y += 1
            loop_count += 1
        return y
  )JIT";

  std::vector<IValue> args1{at::randn({1}), 10};
  std::vector<IValue> args2{at::randn({3, 3, 3}), 10};

  testStaticRuntime(src, args1, args2);
}

TEST(StaticRuntime, NestedLoops) {
  const auto src = R"JIT(
    def forward(self, x, loop_max: int):
        y = x.clone()
        even: List[int] = []
        odd: List[int] = []

        for i in range(loop_max):
            if i % 2:
                odd.append(i)
            else:
                even.append(i)

            for j in range(i):
                y += 1

        return y, even, odd
  )JIT";

  std::vector<IValue> args1{at::randn({1}), 10};
  std::vector<IValue> args2{at::randn({3, 3, 3}), 10};

  testStaticRuntime(src, args1, args2);
}

TEST(StaticRuntime, TupleIndex) {
  const auto src = R"JIT(
    def forward(self, idx: int, tup: Tuple[int, int]):
        a = tup[idx]
        return a * a
  )JIT";
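  // idx = -1 below checks negative-index wrapping; an out-of-range index is
  // expected to throw.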
  const auto tuple = c10::ivalue::Tuple::create({1, 2});
  testStaticRuntime(src, {1, tuple}, {-1, tuple});

  torch::jit::Module mod("module");
  mod.define(src);
  StaticModule smod(mod);
  EXPECT_THROW(smod({100, tuple}), std::out_of_range);
}

TEST(StaticRuntime, RaiseException) {
  const auto src = R"IR(
    graph(%str: str):
        %none: NoneType = prim::Constant()
        prim::RaiseException(%str, %none)
        return (%none)
  )IR";
  auto graph = getGraphFromIR(src);
  StaticModule smod(graph);
  const auto msg = "exception message";
  EXPECT_THROW(
      {
        try {
          smod({msg});
        } catch (const std::runtime_error& e) {
          EXPECT_STREQ(msg, e.what());
          throw;
        }
      },
      std::runtime_error);
}

TEST(StaticRuntime, Uninitialized) {
  const auto src = R"IR(
    graph():
      %0: int = prim::Uninitialized()
      return (%0)
  )IR";
  auto graph = getGraphFromIR(src);
  StaticModule smod(graph);
  const auto ret = smod({});
  // Two uninitialized IValues never compare equal, so just check that the
  // returned type is Any
  EXPECT_EQ(ret.type()->kind(), c10::TypeKind::AnyType);
}

TEST(StaticRuntime, Format) {
  const auto src = R"JIT(
    def forward(self, arg1: int, arg2: Tensor, arg3: str):
        a = "arg1: {}, arg2: {}, arg3: {}".format(arg1, arg2, arg3)
        return a[::]
  )JIT";
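  // The trailing [::] slice presumably adds a second op to the graph for the
  // same reason clone() is used elsewhere: to silence single-op warnings.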
  testStaticRuntime(src, {1, at::randn({3}), "str"});
}

TEST(StaticRuntime, Device) {
  const auto src = R"JIT(
    def forward(self, x):
        return x.device, x.device
  )JIT";
  testStaticRuntime(src, {at::tensor({1})});
}

TEST(StaticRuntime, Dtype) {
  const auto src = R"JIT(
    def forward(self, x, y):
        return x.dtype, y.dtype
  )JIT";
  testStaticRuntime(
      src, {at::tensor({1}, at::kLong), at::tensor({1}, at::kFloat)});
}

TEST(StaticRuntime, Dim) {
  const auto src = R"JIT(
    def forward(self, x, y):
        return x.dim(), y.dim()
  )JIT";
  testStaticRuntime(src, {at::randn({2, 2}), at::randn({1})});
}

TEST(StaticRuntime, Not) {
  const auto src = R"JIT(
    def forward(self, x: bool, y: bool):
        return not x, not y
  )JIT";
  testStaticRuntime(src, {true, false});
}

TEST(StaticRuntime, Bool) {
  const auto src = R"JIT(
      def forward(self, x: Tensor, y: int, z: float):
          return bool(x), bool(y), bool(z)
  )JIT";
  testStaticRuntime(src, {at::randn({1}), 0, 1.151}, {at::zeros({1}), 1, 0.0});
}

TEST(StaticRuntime, IsCuda) {
  const auto src = R"JIT(
      def forward(self, x: Tensor, y: Tensor):
          return x.is_cuda, y.is_cuda
  )JIT";
  testStaticRuntime(src, {at::randn({1}), at::randn({1})});
}

TEST(StaticRuntime, ToList) {
  const auto src = R"JIT(
      graph(%x: Tensor):
          %type: int = prim::Constant[value=1]()
          %dim: int = aten::dim(%x)
          %ret: float[] = prim::tolist(%x, %dim, %type)
          return (%ret)
  )JIT";
  testStaticRuntime(src, {at::randn({2, 2})});
}

TEST(StaticRuntime, IfThenElse) {
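  // prim::IfThenElse selects one of two already-computed values; static
  // runtime uses it to replace trivial prim::If nodes.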
  const auto src = R"IR(
    graph(%cond: bool, %a: Tensor, %b: Tensor):
        %none: NoneType = prim::Constant()
        %c: Tensor = prim::IfThenElse(%cond, %a, %b)
        %d: Tensor = aten::clone(%c, %none)
        return (%d)
  )IR";

  std::vector<IValue> args1{true, at::randn({1}), at::randn({1})};
  std::vector<IValue> args2{false, at::randn({1}), at::randn({1})};

  testStaticRuntime(src, args1);
  testStaticRuntime(src, args2);
}

TEST(StaticRuntime, EmptyIfBlock) {
  const auto src = R"JIT(
      def forward(self, cond: bool, a: Tensor, b: Tensor):
          l = []
          if cond:
              l.append((a + b).clone())
          return l
  )JIT";

  testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})});
  testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})});
}

TEST(StaticRuntime, EmptyNestedIfBlock) {
  const auto src = R"JIT(
      def forward(self, cond: bool, a: Tensor, b: Tensor):
          l = []
          if cond:
              if cond:
                  l.append((a + b).clone())
          return l
  )JIT";

  testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})});
  testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})});
}

TEST(StaticRuntime, StackEmpty) {
  const auto src = R"JIT(
    def forward(self):
        x = torch.stack([])
        return x
  )JIT";

  torch::jit::Module mod("mod");
  mod.define(src);

  torch::jit::StaticModule smod(mod);
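  // Stacking an empty tensor list is invalid; the failure should surface as
  // a c10::Error rather than crashing the runtime.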
  EXPECT_THROW(smod({}), c10::Error);
}

TEST(StaticRuntime, ConcatEmpty) {
  const auto src = R"JIT(
    def forward(self):
        x = torch.concat([])
        return x
  )JIT";

  torch::jit::Module mod("mod");
  mod.define(src);

  torch::jit::StaticModule smod(mod);
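  // Likewise, concatenating an empty tensor list should raise a c10::Error.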
  EXPECT_THROW(smod({}), c10::Error);
}

TEST(StaticRuntime, IntImplicit) {
  const auto src = R"IR(
    graph(%a: Tensor):
        %y: int = aten::IntImplicit(%a)
        return (%y)
  )IR";
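  // aten::IntImplicit converts a 0-dim integer tensor to an int scalar.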
  testStaticRuntime(src, {at::tensor({1}, at::kInt).squeeze()});
}

TEST(StaticRuntime, IntImplicit_ThrowOnBadInputs) {
  const auto src = R"IR(
    graph(%a: Tensor):
        %y: int = aten::IntImplicit(%a)
        return (%y)
  )IR";
  auto graph = getGraphFromIR(src);
  torch::jit::StaticModule smod(graph);
  // Not a 0-D tensor
  EXPECT_THROW(smod({at::tensor({1, 2}, at::kInt)}), std::runtime_error);
  // Wrong dtype
  EXPECT_THROW(
      smod({at::tensor({1}, at::kFloat).squeeze()}), std::runtime_error);
}

TEST(StaticRuntime, Select) {
  const auto src = R"IR(
    graph(%a: Tensor, %dim: int, %index: int):
        %none: NoneType = prim::Constant()
        %b: Tensor = aten::select(%a, %dim, %index)
        %c: Tensor = aten::clone(%b, %none)
        return (%c)
  )IR";
  testStaticRuntime(src, {at::randn({2, 2}), 0, 1});
}

TEST(StaticRuntime, ReshapeAs) {
  const auto src = R"JIT(
    def forward(self, a, b):
        return a.reshape_as(b).clone()
  )JIT";
  testStaticRuntime(src, {at::randn({2, 2}), at::randn({4})});
}

TEST(StaticRuntime, MoveCtor) {
  auto mod = getDeepAndWideSciptModel();
  std::vector<IValue> args{
      at::randn({1, 1, 32}), at::randn({1, 1, 32}), at::randn({1, 50})};

  torch::jit::StaticModule smod(mod);

  torch::jit::StaticRuntime runtime(smod);
  auto expected = runtime(args);

  torch::jit::StaticRuntime new_runtime(std::move(runtime));
  auto actual = new_runtime(args);
  compareResults(expected, actual);
}

TEST(StaticRuntime, SingleBlockIfReturnList) {
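  // A list created inside an if sub-block must be handled correctly when it
  // escapes to the outer block.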
  const auto src = R"JIT(
    def forward(self, a, b, cond: bool):
        lst = []
        if cond:
            lst.append(a + b)
        return lst
  )JIT";
  std::vector<IValue> args1{at::randn({1}), at::randn({1}), true};
  std::vector<IValue> args2{at::randn({42, 42}), at::randn({42, 42}), false};
  testStaticRuntime(src, args1, args2);
}

TEST(StaticRuntime, NestedBlockIfReturnList) {
  const auto src = R"JIT(
    def forward(self, a, b, cond1: bool, cond2: bool):
        if cond1:
            lst = []
            if cond2:
                lst.append(a + b)
            lst.append(a * b)
            return lst
        return []
  )JIT";
  std::vector<IValue> args1{at::randn({1}), at::randn({1}), true, true};
  std::vector<IValue> args2{
      at::randn({42, 42}), at::randn({42, 42}), true, false};
  testStaticRuntime(src, args1, args2);
}

TEST(StaticRuntime, ClampNaNToNum) {
  const auto src1 = R"JIT(
    def forward(self, a):
        return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone()
  )JIT";

  const auto src2 = R"JIT(
    def forward(self, a, nan: float):
        return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone()
  )JIT";

  const auto src3 = R"JIT(
    def forward(self, a):
        return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone()
  )JIT";

  auto a = at::tensor({
      std::numeric_limits<float>::quiet_NaN(),
      std::numeric_limits<float>::infinity(),
      -std::numeric_limits<float>::infinity(),
      0.0f,
      3.0f
    });
  auto b = a.repeat({10, 5});

  // We have to use allclose/equalnan even though all NaNs get replaced:
  // testStaticRuntime also checks the inputs at the end to make sure they
  // weren't modified
  testStaticRuntime(src1, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(src1, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);

  testStaticRuntime(src2, {a, 42.0}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(src2, {a, 2.0}, {b, 1.0}, /*use_allclose=*/true, /*use_equalnan=*/true);

  testStaticRuntime(src3, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(src3, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);

  // Non-NNC path
  testStaticRuntime(src1, {a.to(at::kDouble)}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true);
}

TEST(StaticRuntime, IfReturningTuple) {
  const auto src = R"JIT(
    def forward(self, x, y, cond: bool, idx: int):
        if cond:
            tup = (x, y)
        else:
            tup = (x, x)
        return tup[idx]
  )JIT";

  std::vector<IValue> args{at::randn({3}), at::randn({3}), true, 0};
  testStaticRuntime(src, args);
}