#include <ATen/core/dispatch/OperatorOptions.h>
#include <c10/core/ScalarType.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <stdexcept>

#include "deep_wide_pt.h"
#include "test_utils.h"

using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;
using c10::IValue;

/*
When adding a test for an operator implemented in static runtime, there are
several things that you need to pay attention to:

1) If the op is an out variant, then in the test script of the op,
instead of:
    def forward(self, input):
        return myop(input)

do:
    def forward(self, input):
        return myop(input).clone()

This makes sure that the output of myop is managed by the memory planner and
exercises the code path in the op impl that otherwise doesn't get exercised.
The output of the model is not managed by the memory planner, because it
needs to be returned to the client.

2) The memory planner rounds up the size of each Tensor's storage to
multiples of 64 bytes (the alignment requirement on AVX512). Make sure the
sizes of the input tensors in args2 are big enough to trigger resizing.

3) For view ops such as aten::reshape or aten::to, if you want the op to be
replaced by its copy variant via the ReplaceWithCopy pass in passes.h, you
also need to make sure its output is not returned as the model output. The
reason is that ReplaceWithCopy only replaces an op whose output is not an
alias of the model output.
*/
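
// For illustration, a minimal sketch that follows the conventions above
// (`testStaticRuntime` is the helper from test_utils.h; the test name and op
// are placeholders). The trailing `.clone()` keeps the op's output managed by
// the memory planner (rule 1), and args2 uses larger inputs so the planner
// has to resize the managed storage (rule 2):
//
//   TEST(StaticRuntime, ExampleOutVariantOp) {
//     const auto src = R"JIT(
//       def forward(self, input):
//           return torch.relu(input).clone()
//     )JIT";
//     std::vector<IValue> args{at::randn({2, 3})};
//     std::vector<IValue> args2{at::randn({8, 16})};
//     testStaticRuntime(src, args);
//     testStaticRuntime(src, args, args2);
//   }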

C10_DECLARE_bool(static_runtime_enable_fast_math);

TEST(StaticRuntime, UnaryOps) {
  const auto aten_sum = R"JIT(
    def forward(self, input):
        return torch.sum(input).clone()
  )JIT";

  const auto aten_sum_0 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0).clone()
  )JIT";

  const auto aten_sum_1 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1).clone()
  )JIT";

  const auto aten_sum_0_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0, True).clone()
  )JIT";

  const auto aten_sum_1_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1, True).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 3, 6});

  std::vector<IValue> args{a}, args2{b};

  // sum
  testStaticRuntime(aten_sum, args);
  testStaticRuntime(aten_sum_0, args);
  testStaticRuntime(aten_sum_1, args);
  testStaticRuntime(aten_sum_0_true, args);
  testStaticRuntime(aten_sum_1_true, args);

  testStaticRuntime(aten_sum, args, args2, false, false, false);
  testStaticRuntime(aten_sum_0, args, args2);
  testStaticRuntime(aten_sum_1, args, args2);
  testStaticRuntime(aten_sum_0_true, args, args2);
  testStaticRuntime(aten_sum_1_true, args, args2);
}

TEST(StaticRuntime, Max) {
  auto src_max_reduce = R"JIT(
    def forward(self, input):
        return torch.max(input).clone()
  )JIT";

  auto src_max_dim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim)
        return values.clone(), indices.clone()
  )JIT";

  auto src_max_dim_keepdim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim, keepdim=True)
        return values.clone(), indices.clone()
  )JIT";

  auto src_max_pointwise = R"JIT(
    def forward(self, input, other):
        return torch.max(input, other).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto input_other = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 9, 10});
  auto large_input_other = at::randn({8, 9, 10});

  testStaticRuntime(src_max_reduce, {input});
  testStaticRuntime(src_max_dim, {input, 1});
  testStaticRuntime(src_max_dim, {input, 1}, {large_input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0}, {large_input, 2});
  testStaticRuntime(src_max_pointwise, {input, input_other});
  testStaticRuntime(
      src_max_pointwise, {input, input_other}, {large_input, large_input_other});
}

TEST(StaticRuntime, Mean) {
  const auto src_default = R"JIT(
    def forward(self, input):
        return torch.mean(input).clone()
  )JIT";
  const auto src_dtype = R"JIT(
    def forward(self, input, dtype: int):
        return torch.mean(input, dtype=dtype).clone()
  )JIT";
  const auto src_dim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim).clone()
  )JIT";
  const auto src_dim_keepdim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim, keepdim=True).clone()
  )JIT";
  const auto src_dim_dtype = R"JIT(
    def forward(self, input, dim: List[int], dtype: int):
        return torch.mean(input, dim, dtype=dtype).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 7, 6, 8});

  std::vector<IValue> args_default = {input};
  std::vector<IValue> args_dtype = {input, torch::kFloat};
  std::vector<IValue> args_dim = {input, c10::List<int64_t>{0, 2}};
  std::vector<IValue> args_dim_keepdim = {input, c10::List<int64_t>{1, 2}};
  std::vector<IValue> args_dim_dtype = {
      input, c10::List<int64_t>{0, 1}, torch::kBFloat16};

  testStaticRuntime(src_default, args_default);
  testStaticRuntime(src_dtype, args_dtype);
  testStaticRuntime(src_dim, args_dim);
  testStaticRuntime(src_dim_keepdim, args_dim_keepdim);
  testStaticRuntime(src_dim_dtype, args_dim_dtype);

  std::vector<IValue> large_args_dim = {large_input, c10::List<int64_t>{0, 3}};
  std::vector<IValue> large_args_dim_keepdim = {
      large_input, c10::List<int64_t>{1, 2}};
  std::vector<IValue> large_args_dim_dtype = {
      large_input, c10::List<int64_t>{1, 3}, torch::kBFloat16};

  testStaticRuntime(src_dim, args_dim, large_args_dim);
  testStaticRuntime(src_dim_keepdim, args_dim_keepdim, large_args_dim_keepdim);
  testStaticRuntime(src_dim_dtype, args_dim_dtype, large_args_dim_dtype);
}

TEST(StaticRuntime, Sigmoid) {
  const auto sigmoid_script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        return (b)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  std::vector<IValue> args{a}, args2{b};

  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, args2, /*use_allclose=*/true);

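  // Re-run with fast math disabled so the non-fast-math code path for
  // sigmoid is exercised as well.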
  FLAGS_static_runtime_enable_fast_math = false;
  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, args2, /*use_allclose=*/true);
  FLAGS_static_runtime_enable_fast_math = true;
}

TEST(StaticRuntime, Clone) {
  /*
  Clone is called twice to trigger the memory planner for the output of the
  first clone. The output of the last op (the second clone) is not managed by
  the memory planner, since it needs to be returned to the client and cannot
  be reused by the planner.
  */
  const auto clone_script_0 = R"JIT(
    def forward(self, input):
        a = torch.clone(input).clone()
        return (a * a)
  )JIT";

  // Case: clone with different memory formats
  const auto clone_script_1 = R"JIT(
    def forward(self, input: Tensor, memory_format: int):
        a = torch.clone(input, memory_format=memory_format).clone()
        return (a * a)
  )JIT";

  /*
  Case: input stride set to 0 (due to the expand op);
  calls native clone instead of the out variant
  */
  const auto clone_script_2 = R"JIT(
    def forward(self, input: Tensor, other: Tensor):
        a = input.expand_as(other)
        return a.clone().clone()
  )JIT";

  /*
  Case: sliced input tensor, to exercise non-contiguous tensor storage
  */
  const auto clone_script_3 = R"JIT(
    def forward(self, input: Tensor):
        a = input[:, 0:10:2]
        return a.clone().clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 2}).as_strided({3, 2}, {1, 3});
  auto b_larger = at::randn({30, 20}).as_strided({30, 20}, {1, 3});
  auto c = at::randn({1, 20, 13, 8});
  auto d = at::randn({1, 0, 3, 4});
  auto e = at::randn({2, 1});
  auto f = at::randn({2, 10});
  auto g = at::randn({3, 20});
  std::vector<IValue> args_0{b, c10::MemoryFormat::Contiguous};
  std::vector<IValue> args_1{b_larger, c10::MemoryFormat::Preserve};
  std::vector<IValue> args_2{c, c10::MemoryFormat::ChannelsLast};
  std::vector<IValue> args_3{d, c10::MemoryFormat::ChannelsLast};
  std::vector<IValue> args_4{e, a};
  std::vector<IValue> args_5{e, f};

  testStaticRuntime(clone_script_0, {a});
  testStaticRuntime(clone_script_0, {a}, {b_larger});

  testStaticRuntime(clone_script_1, args_0);
  testStaticRuntime(clone_script_1, args_1);
  testStaticRuntime(clone_script_1, args_2);
  testStaticRuntime(clone_script_1, args_3);
  testStaticRuntime(clone_script_1, args_0, args_1);
  testStaticRuntime(clone_script_1, args_3, args_2);

  testStaticRuntime(clone_script_2, args_4);
  testStaticRuntime(clone_script_2, args_4, args_5);

  testStaticRuntime(clone_script_3, {f});
  testStaticRuntime(clone_script_3, {f}, {g});
}

TEST(StaticRuntime, Clamp) {
  const auto clamp_script_1 = R"JIT(
    def forward(self, inp: Tensor, min: int, max: int):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";

  const auto clamp_script_2 = R"JIT(
    def forward(self, inp: Tensor, min: Tensor, max: Tensor):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto max_t = at::full_like(a, 1);
  auto min_t = at::full_like(a, -1);

  auto b = at::randn({4, 3, 2});
  auto max_t1 = at::full_like(b, 1);
  auto min_t1 = at::full_like(b, -1);

  testStaticRuntime(clamp_script_1, {a, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t});

  testStaticRuntime(clamp_script_1, {a, -1, 1}, {b, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, min_t1, max_t1});
}

TEST(StaticRuntime, ClampMinOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float):
        a = torch.clamp(inp, min, None).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});
  testStaticRuntime(src, {a, 0.5});
  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}

TEST(StaticRuntime, ClampMaxOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, max: float):
        a = torch.clamp(inp, None, max).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});
  testStaticRuntime(src, {a, 0.5});
  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}

TEST(StaticRuntime, ClampIntTensor) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float, max: float):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  auto a = at::randint(0, 20, {2, 3}, at::kFloat);
  auto b = at::randint(0, 20, {4, 3, 2}, at::kFloat);
  auto min = 5.0f;
  auto max = 5.0f;
  testStaticRuntime(src, {a, min, max});
  testStaticRuntime(src, {a, min, max}, {b, min, max});
}

TEST(StaticRuntime, LenWithList) {
  const auto src = R"IR(
    graph(%input : int[]):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {c10::List<int64_t>(4)});
}

TEST(StaticRuntime, LenWithTensor) {
  const auto src = R"IR(
    graph(%input : Tensor):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {at::randn({2, 2, 2})});
}

TEST(StaticRuntime, LenWithStr) {
  const auto src = R"IR(
    graph(%input : str):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  testStaticRuntime(src, {"static_runtime"});
}

TEST(StaticRuntime, LenWithDict_str) {
  const auto script = R"JIT(
    def forward(self, input: Dict[str, str]):
        return len(input)
  )JIT";

  c10::Dict<std::string, std::string> dict;
  dict.insert("abc", "123");
  dict.insert("def", "456");
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_int) {
  const auto script = R"JIT(
    def forward(self, input: Dict[int, int]):
        return len(input)
  )JIT";

  c10::Dict<int64_t, int64_t> dict;
  dict.insert(0, 1);
  dict.insert(2, 3);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_bool) {
  const auto script = R"JIT(
    def forward(self, input: Dict[bool, bool]):
        return len(input)
  )JIT";

  c10::Dict<bool, bool> dict;
  dict.insert(true, false);
  dict.insert(false, true);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_float) {
  const auto script = R"JIT(
    def forward(self, input: Dict[float, float]):
        return len(input)
  )JIT";

  c10::Dict<double, double> dict;
  dict.insert(0.1, 0.9);
  dict.insert(0.8, 0.18);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_complex) {
  const auto script = R"JIT(
    def forward(self, input: Dict[complex, complex]):
        return len(input)
  )JIT";

  c10::Dict<c10::complex<double>, c10::complex<double>> dict;
  dict.insert(0.1, 0.4);
  dict.insert(0.9, 0.45);
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, LenWithDict_Tensor) {
  const auto script = R"JIT(
    def forward(self, input: Dict[Tensor, Tensor]):
        return len(input)
  )JIT";

  c10::Dict<at::Tensor, at::Tensor> dict;
  dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  testStaticRuntime(script, {dict});
}

TEST(StaticRuntime, Logit) {
  // no nnc
  const auto logit_script_1 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp).clone()
        return (a)
  )JIT";

  // with nnc
  const auto logit_script_2 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp, 1e-6).clone()
        return (a)
  )JIT";

  // no nnc
  const auto logit_script_3 = R"JIT(
    def forward(self, inp: Tensor, eps: float):
        a = torch.logit(inp, eps).clone()
        return (a)
  )JIT";
  auto a = at::ones({2, 3});
  double b = 1e-6;
  std::vector<IValue> args_1{a};
  std::vector<IValue> args_2({a, b});

  auto c = at::ones({4, 3, 2});

  // logit
  testStaticRuntime(logit_script_1, args_1);
  testStaticRuntime(logit_script_2, args_1);
  testStaticRuntime(logit_script_3, args_2);

  testStaticRuntime(logit_script_1, args_1, {c});
  testStaticRuntime(logit_script_2, args_1, {c});
  testStaticRuntime(logit_script_3, args_2, {c, b});
}

TEST(StaticRuntime, EmbeddingBag) {
  const std::string embedding_bag_default = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_mean = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_max = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 2)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_sum_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 0, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_mean_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_max_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 2, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2, 4});
  std::vector<IValue> args{weight, input, offset};
  testStaticRuntime(embedding_bag_default, args);
  testStaticRuntime(embedding_bag_mean, args);
  testStaticRuntime(embedding_bag_max, args);
  testStaticRuntime(embedding_bag_sum_last_offset, args);
  testStaticRuntime(embedding_bag_mean_last_offset, args);
  testStaticRuntime(embedding_bag_max_last_offset, args);

  at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};
  testStaticRuntime(embedding_bag_default, args, args2);
  testStaticRuntime(embedding_bag_mean, args, args2);
  testStaticRuntime(embedding_bag_max, args, args2);
  testStaticRuntime(embedding_bag_sum_last_offset, args, args2);
  testStaticRuntime(embedding_bag_mean_last_offset, args, args2);
  testStaticRuntime(embedding_bag_max_last_offset, args, args2);
}

TEST(StaticRuntime, EmbeddingBagWithManagedOutput) {
  const std::string embedding_bag_managed_output = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        # The outputs of embedding_bag become intermediate tensors
        # since they are not directly returned from the graph.
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return x + x
  )JIT";

  at::Tensor weight = torch::randn({3, 8}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2});
  std::vector<IValue> args{weight, input, offset};

  at::Tensor weight2 = torch::randn({6, 8}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 3, 4});
  at::Tensor offset2 = torch::tensor({0, 2, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};

  testStaticRuntime(embedding_bag_managed_output, args);
  testStaticRuntime(embedding_bag_managed_output, args, args2);
}

TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) {
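  // RemoveUnnecessaryOutputs should replace aten::embedding_bag with
  // static_runtime::embedding_bag when only the first output is consumed.
  // The FileCheck runs below verify that the replacement happened, and the
  // last case (which consumes all four outputs) verifies that it did not.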
  const std::string embedding_bag_default_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  auto graph = getGraphFromIR(embedding_bag_default_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_mean_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=1]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  graph = getGraphFromIR(embedding_bag_mean_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_max_last_offset_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=2]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=1]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  graph = getGraphFromIR(embedding_bag_max_last_offset_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check("static_runtime::embedding_bag")
      ->run(*graph);

  const std::string embedding_bag_normal_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res0 : Tensor = aten::clone(%y0, %none)
        %res1 : Tensor = aten::clone(%y1, %none)
        %res2 : Tensor = aten::clone(%y2, %none)
        %res3 : Tensor = aten::clone(%y3, %none)
        return (%res0, %res1, %res2, %res3)
  )IR";
  graph = getGraphFromIR(embedding_bag_normal_ir);
  RemoveUnnecessaryOutputs(graph);
  torch::jit::testing::FileCheck()
      .check_not("static_runtime::embedding_bag")
      ->run(*graph);

  at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2, 4});
  std::vector<IValue> args{weight, input, offset};
  testStaticRuntime(embedding_bag_default_ir, args);
  testStaticRuntime(embedding_bag_mean_ir, args);
  testStaticRuntime(embedding_bag_max_last_offset_ir, args);

  at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};
  testStaticRuntime(embedding_bag_default_ir, args, args2);
  testStaticRuntime(embedding_bag_mean_ir, args, args2);
  testStaticRuntime(embedding_bag_max_last_offset_ir, args, args2);
}

TEST(StaticRuntime, EmbeddingBagWithMixedInt32Int64Input) {
  const std::string embedding_bag_default = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";
  auto weight = torch::randn({3, 11}, at::ScalarType::Float);
  auto input = torch::tensor({0, 1, 0, 2}, at::ScalarType::Long);
  auto offset = torch::tensor({0, 2, 4}, at::ScalarType::Int);
  std::vector<IValue> args{weight, input, offset};
  testStaticRuntime(embedding_bag_default, args);
}

TEST(StaticRuntime, LayerNorm) {
  const std::string layer_norm_with_weights = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
        return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
  )JIT";

  const std::string layer_norm_without_weights = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int]):
        return torch.layer_norm(input, normalized_shape, None, None, 1e-05, False).clone()
  )JIT";

  const std::string layer_norm_with_noncontiguous_input = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
        input = torch.transpose(input, 1, 2)
        return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
  )JIT";

  const auto a = torch::rand({1, 2, 2, 2});
  const auto b = torch::rand({3, 2, 2, 2});
  for (int normalized_size : {2, 3}) {
    std::vector<int64_t> normalized_shape(normalized_size, 2);
    const auto weight = torch::rand(normalized_shape);
    const auto bias = torch::rand(normalized_shape);

    std::vector<IValue> args{a, normalized_shape, weight, bias};
    std::vector<IValue> args1{b, normalized_shape, weight, bias};
    testStaticRuntime(layer_norm_with_weights, args);
    testStaticRuntime(layer_norm_with_weights, args, args1);
    testStaticRuntime(layer_norm_with_noncontiguous_input, args);

    args = {a, normalized_shape};
    testStaticRuntime(layer_norm_without_weights, args);
    testStaticRuntime(layer_norm_without_weights, args, {b, normalized_shape});
  }
}

TEST(StaticRuntime, Bmm) {
  const auto bmm_script = R"JIT(
    def forward(self, inp: Tensor, mat2: Tensor):
        return torch.bmm(inp, mat2).clone()
  )JIT";

  auto a = at::randn({10, 4, 5});
  auto b = at::randn({10, 5, 6});

  auto c = at::randn({12, 5, 6});
  auto d = at::randn({12, 6, 7});

  std::vector<IValue> args{a, b};
  std::vector<IValue> args1{c, d};
  testStaticRuntime(bmm_script, args);
  testStaticRuntime(bmm_script, args1);
  testStaticRuntime(bmm_script, args, args1);
}

TEST(StaticRuntime, Addmm) {
  const auto addmm_script = R"JIT(
    def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float):
        return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone()
  )JIT";
  auto inp1 = at::randn({5});
  auto mat1 = at::randn({3, 4});
  auto mat2 = at::randn({4, 5});

  auto inp2 = at::randn({3, 7});
  auto mat3 = at::randn({3, 6});
  auto mat4 = at::randn({6, 7});

  std::vector<IValue> args{inp1, mat1, mat2, 1.0, 2.0};
  std::vector<IValue> args1{inp2, mat3, mat4, 2.0, 1.0};
  testStaticRuntime(addmm_script, args);
  testStaticRuntime(addmm_script, args1);
  testStaticRuntime(addmm_script, args, args1);
}

TEST(StaticRuntime, Abs) {
  const auto abs_script = R"JIT(
    def forward(self, a):
        return a.abs().clone()
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 2, 3});
  std::vector<IValue> args{a};
  std::vector<IValue> args2{b};
  testStaticRuntime(abs_script, args);
  testStaticRuntime(abs_script, args, args2);
}

TEST(StaticRuntime, Binary) {
  const auto add_script = R"JIT(
    def forward(self, a, b):
        c = a + b
        return (c.clone())
  )JIT";

  const auto add_script_ints = R"JIT(
    def forward(self, a: int, b: int):
        c = a + b
        d = c + 1
        return d
  )JIT";

  const auto add_list_script = R"JIT(
    def forward(self, a: List[int], b: List[int]):
        c = a + b
        return c[::]
  )JIT";

  const auto list_construct_script = R"JIT(
    def forward(self, a, b):
        return [a, b]
  )JIT";

  const auto list_construct_script_2 = R"JIT(
    def forward(self, a, b):
        c = a + a
        return [c, c]
  )JIT";

  const auto list_construct_script_3 = R"JIT(
    def forward(self, a, b):
        c = a + a
        return [c, c.flatten()]
  )JIT";

  const auto list_unpack_script = R"JIT(
    def forward(self, a, b):
        c = [a, b]
        x, y = c
        z = x + y
        return z.clone()
  )JIT";

  const auto list_unpack_script_2 = R"JIT(
    def forward(self, a, b):
        c = [a, b]
        x, y = c
        z = (x, y)
        return z
  )JIT";

  const auto tuple_construct_script = R"JIT(
    def forward(self, a, b):
        return (a, b)
  )JIT";

  const auto tuple_construct_script_2 = R"JIT(
    def forward(self, a, b):
        return (a.flatten(), b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::ones({2, 3});

  auto c = at::randn({4, 2, 3});
  auto d = at::ones({4, 2, 3});

  std::vector<IValue> args{a, b};

  testStaticRuntime(add_script, args);
  testStaticRuntime(add_script_ints, {1, 2});
  testStaticRuntime(add_script, args, {c, d});
  testStaticRuntime(list_construct_script, args);
  testStaticRuntime(list_construct_script_2, args);
  testStaticRuntime(list_construct_script_3, args);
  testStaticRuntime(list_unpack_script, args);
  testStaticRuntime(list_unpack_script_2, args);
  testStaticRuntime(tuple_construct_script, args);
  testStaticRuntime(tuple_construct_script_2, args);

  std::vector<IValue> list_args{
      c10::List<int64_t>{1, 2, 3}, c10::List<int64_t>{4, 5, 6}};
  testStaticRuntime(add_list_script, list_args);
}

TEST(StaticRuntime, MatMul) {
  const auto aten_matmul = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.matmul(a, b).clone()
  )JIT";

  // 1-D, 1-D
  std::vector<IValue> args{at::randn({3}), at::randn({3})};
  testStaticRuntime(aten_matmul, args);
  // 2-D, 2-D
  std::vector<IValue> args1 = {at::randn({3, 2}), at::randn({2, 3})};
  testStaticRuntime(aten_matmul, args1);
  // 1-D, 2-D
  std::vector<IValue> args2 = {at::randn({3}), at::randn({3, 5})};
  testStaticRuntime(aten_matmul, args2);
  // 2-D, 1-D
  std::vector<IValue> args3 = {at::randn({3, 5}), at::randn({5})};
  testStaticRuntime(aten_matmul, args3);
  // > 2-D, > 2-D
  std::vector<IValue> args4 = {at::randn({3, 1, 4, 5}), at::randn({2, 5, 6})};
  testStaticRuntime(aten_matmul, args4);

  testStaticRuntime(aten_matmul, args3, args4);
}

TEST(StaticRuntime, Sign) {
  const auto sign_tensor = R"JIT(
    def forward(self, input: Tensor):
        return torch.sign(input).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  std::vector<IValue> args{a};
  testStaticRuntime(sign_tensor, args);
  testStaticRuntime(sign_tensor, args, {b});
}

TEST(StaticRuntime, Div) {
  const auto div_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.div(a, b).clone()
  )JIT";

  const auto div_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.div(a, b).clone()
  )JIT";

  const auto div_tensor_mode = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: str):
        return torch.div(a, b, rounding_mode=c).clone()
  )JIT";

  const auto div_scalar_mode = R"JIT(
    def forward(self, a: Tensor, b: float, c: str):
        return torch.div(a, b, rounding_mode=c).clone()
  )JIT";

  const auto div_strided = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        a_strided = torch.transpose(a, 0, 1)
        b_strided = torch.transpose(b, 0, 1)
        return torch.div(a_strided, b_strided).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto bs = at::randn({3, 2}).transpose(0, 1);
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});
  auto ds = at::randn({3, 4, 2}).transpose(0, 1);

  std::vector<IValue> args0{a, b};
  testStaticRuntime(div_tensor, args0);
  testStaticRuntime(div_tensor, args0, {c, d});

  testStaticRuntime(div_strided, args0);
  testStaticRuntime(div_strided, args0, {c, d});

  testStaticRuntime(div_tensor, {a, bs});
  testStaticRuntime(div_tensor, {a, bs}, {c, ds});

  std::vector<IValue> args1{a, 3};
  testStaticRuntime(div_scalar, args1);
  testStaticRuntime(div_scalar, args1, {c, 4});

  std::vector<IValue> args2{a, b, "floor"};
  testStaticRuntime(div_tensor_mode, args2);
  testStaticRuntime(div_tensor_mode, args2, {c, d, "floor"});

  std::vector<IValue> args3{a, 2.3, "trunc"};
  testStaticRuntime(div_scalar_mode, args3);
  testStaticRuntime(div_scalar_mode, args3, {c, 1.5, "trunc"});
}

TEST(StaticRuntime, Mul) {
  const auto mul_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.mul(a, b).clone()
  )JIT";

  const auto mul_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.mul(a, b).clone()
  )JIT";

  const auto mul_list = R"JIT(
    def forward(self, a: List[int], n: int):
        b = a * n
        return b[::]
  )JIT";

  auto a = at::randn({3, 3});
  auto b = at::randn({3, 3});
  auto c = at::randn({3, 3, 3});
  auto d = at::randn({3, 3, 3});

  std::vector<IValue> tensor_args1{a, b};
  std::vector<IValue> tensor_args2{c, d};

  testStaticRuntime(mul_tensor, tensor_args1);
  testStaticRuntime(mul_tensor, tensor_args1, tensor_args2);

  std::vector<IValue> scalar_args1{a, 42};
  std::vector<IValue> scalar_args2{c, 42};

  testStaticRuntime(mul_scalar, scalar_args1);
  testStaticRuntime(mul_scalar, scalar_args1, scalar_args2);

  std::vector<IValue> list_args{c10::List<int64_t>{1, 2}, 3};
  testStaticRuntime(mul_list, list_args);
}

TEST(StaticRuntime, Log) {
  const auto log_tensor = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.log(inp).clone()
        return (a)
  )JIT";

  // Ensure that the input values are valid.
  auto a = at::abs(at::randn({2, 3}));
  auto b = at::abs(at::randn({4, 3, 2}));

  std::vector<IValue> args{a};
  testStaticRuntime(log_tensor, args);
  testStaticRuntime(log_tensor, args, {b});
}

TEST(StaticRuntime, Sub) {
  const auto sub_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.sub(a, b).clone()
  )JIT";

  const auto sub_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.sub(a, b).clone()
  )JIT";

  const auto sub_tensor_alpha = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: float):
        return torch.sub(a, b, alpha=c).clone()
  )JIT";

  const auto sub_scalar_alpha = R"JIT(
    def forward(self, a: Tensor, b: float, c: int):
        return torch.sub(a, b, alpha=c).clone()
  )JIT";

  const auto sub_two_scalars = R"JIT(
    def forward(self, a: int, b: int):
        return (a - b - b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});

  std::vector<IValue> args0{a, b};
  testStaticRuntime(sub_tensor, args0);
  testStaticRuntime(sub_tensor, args0, {c, d});

  std::vector<IValue> args1{a, 3};
  testStaticRuntime(sub_scalar, args1);
  testStaticRuntime(sub_scalar, args1, {c, 4});

  std::vector<IValue> args2{a, b, 2.3};
  testStaticRuntime(sub_tensor_alpha, args2);
  testStaticRuntime(sub_tensor_alpha, {c, d, 3.1});

  std::vector<IValue> args3{a, 2.3, 4};
  testStaticRuntime(sub_scalar_alpha, args3);
  testStaticRuntime(sub_scalar_alpha, {c, 1.3, 2});

  std::vector<IValue> args4{1, 2};
  testStaticRuntime(sub_two_scalars, args4);
}

TEST(StaticRuntime, NanToNum) {
  const auto nan_to_num_script = R"JIT(
    def forward(self, a: Tensor, nan: float, posinf: float, neginf: float):
        return torch.nan_to_num(a, nan, posinf, neginf).clone()
  )JIT";

  const auto inf = std::numeric_limits<double>::infinity();
  const auto nan = std::numeric_limits<double>::quiet_NaN();

  auto a = torch::tensor({{1.0, nan}, {-inf, inf}});
  auto b = at::randn({3, 6});
  float* b_data = b.data_ptr<float>();
  b_data[0] = nan;
  b_data[4] = -inf;
  b_data[11] = inf;
  b_data[13] = nan;

  std::vector<IValue> args1{a, 1.0, 2.0, -2.0};
  std::vector<IValue> args2{b, 1.0, 2.0, -2.0};

  testStaticRuntime(
      nan_to_num_script,
      args1,
      /*args2*/ {},
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
  testStaticRuntime(
      nan_to_num_script,
      args1,
      args2,
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
}

TEST(StaticRuntime, Stack) {
  const auto stack_dim = R"JIT(
    def forward(self, a: Tensor, b: Tensor, dim: int):
        inputs = [a]
        inputs.append(b) # mutation to avoid using VarStack
        return torch.stack(inputs, dim=dim).clone()
  )JIT";

  const auto stack_three = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        inputs = [a, b]
        inputs.append(c) # mutation to avoid using VarStack
        return torch.stack(inputs).clone()
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({2, 2});
  auto c = at::randn({2, 2});

  auto d = at::randn({3, 3, 3});
  auto e = at::randn({3, 3, 3});
  auto f = at::randn({3, 3, 3});

  std::vector<IValue> args1_dim{a, b, 0};
  std::vector<IValue> args2_dim{d, e, 1};
  std::vector<IValue> args_dim_negative{d, e, -1};

  std::vector<IValue> args1_three_tensors{a, b, c};
  std::vector<IValue> args2_three_tensors{d, e, f};

  testStaticRuntime(stack_dim, args1_dim);
  testStaticRuntime(stack_dim, args1_dim, args2_dim);

  testStaticRuntime(stack_dim, args_dim_negative);

  testStaticRuntime(stack_three, args1_three_tensors);
  testStaticRuntime(stack_three, args1_three_tensors, args2_three_tensors);
}

TEST(StaticRuntime, ReLU) {
  const auto relu_script = R"JIT(
    def forward(self, a: Tensor):
        return torch.relu(a).clone()
  )JIT";
  auto a = at::randint(-10, 10, {2, 4});
  auto b = at::randint(-10, 10, {3, 6});

  std::vector<IValue> args1{a};
  std::vector<IValue> args2{b};

  testStaticRuntime(relu_script, args1);
  testStaticRuntime(relu_script, args1, args2);
}

TEST(StaticRuntime, Tanh) {
  const auto tanh_script = R"JIT(
    def forward(self, a):
        return torch.tanh(a).clone()
  )JIT";
  auto a = at::randn({2, 2});
  auto b = at::randn({3, 3, 3});

  std::vector<IValue> args1{a};
  std::vector<IValue> args2{b};

  testStaticRuntime(tanh_script, args1, /*args2*/ {}, /*use_allclose*/ true);
  testStaticRuntime(tanh_script, args1, args2, /*use_allclose*/ true);
}

TEST(StaticRuntime, Norm) {
  const auto norm_2arg = R"JIT(
    def forward(self, a: Tensor, p: int):
        return torch.norm(a, p).clone()
  )JIT";

  const auto norm_3arg = R"JIT(
    def forward(self, a: Tensor, p: int, dtype: int):
        return torch.norm(a, p, dtype=dtype).clone()
  )JIT";

  const auto norm_4arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool):
        return torch.norm(a, p, dim, keepdim).clone()
  )JIT";

  const auto norm_5arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool, dtype: int):
        return torch.norm(a, p, dim, keepdim, dtype=dtype).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 5});
  auto dim = std::vector<int64_t>({1});
  auto dtype = at::ScalarType::Float;

  std::vector<IValue> args2{a, 2};
  testStaticRuntime(norm_2arg, args2);
  testStaticRuntime(norm_2arg, args2, {b, 2}, false, false, false);

  std::vector<IValue> args3{a, 2, dtype};
  testStaticRuntime(norm_3arg, args3);
  testStaticRuntime(norm_3arg, args3, {b, 2, dtype}, false, false, false);

  std::vector<IValue> args4{a, 3, dim, false};
  testStaticRuntime(norm_4arg, args4);
  testStaticRuntime(norm_4arg, args4, {b, 3, dim, false});

  std::vector<IValue> args5{a, 4, dim, true, dtype};
  testStaticRuntime(norm_5arg, args5);
  testStaticRuntime(norm_5arg, args5, {b, 4, dim, true, dtype});
}

TEST(StaticRuntime, Reshape) {
  const auto reshape_script_1 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.reshape(shape)
        return b + b
  )JIT";

  const auto reshape_script_2 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        return b.reshape(shape)
  )JIT";

  const auto reshape_script_3 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape)
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return b.reshape(shape), g
  )JIT";

  // exercise reshape_copy and flatten_copy
  const auto reshape_script_4 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        k = inp + inp
        a = k + k
        b = a.reshape(shape)
        c = a.flatten().reshape(shape)
        return b + c
  )JIT";

  // exercise reshape_copy
  const auto reshape_script_5 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape).relu()
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return g
  )JIT";

  const auto reshape_inplace_script = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = b.sigmoid_()
        d = c + c
        e = a + a
        f = b + b
        return (d, e, f)
  )JIT";

  // b is non-contiguous
  const auto reshape_incontiguous_script = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        c = b.reshape(shape)
        c = c.relu()
        return (c)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = std::vector<int64_t>({3, 2});
  std::vector<IValue> args{a, b};

  auto c = at::randn({4, 5});
  auto d = std::vector<int64_t>({5, 1, 2, 2});
  std::vector<IValue> args1{c, d};

  testStaticRuntime(reshape_script_1, args);
  testStaticRuntime(reshape_script_2, args);
  testStaticRuntime(reshape_script_3, args);
  testStaticRuntime(reshape_script_4, args);
  testStaticRuntime(reshape_script_5, args);
  testStaticRuntime(reshape_inplace_script, args);
  testStaticRuntime(reshape_incontiguous_script, args);

  testStaticRuntime(reshape_script_1, args, args1);
  testStaticRuntime(reshape_script_2, args, args1);
  testStaticRuntime(reshape_script_3, args, args1);
  testStaticRuntime(reshape_script_4, args, args1);
  testStaticRuntime(reshape_script_5, args, args1);
  testStaticRuntime(reshape_inplace_script, args, args1);
  testStaticRuntime(reshape_incontiguous_script, args, args1);
}

TEST(StaticRuntime, Repeat) {
  const std::string repeat = R"JIT(
    def forward(self, a: Tensor, repeats: List[int]):
        return torch.repeat(a, repeats).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3});
  auto c = std::vector<int64_t>({1, 2});
  auto d = std::vector<int64_t>({2, 3});
  std::vector<IValue> args1{a, c};
  std::vector<IValue> args2{b, d};

  testStaticRuntime(repeat, args1);
  testStaticRuntime(repeat, args2);
  testStaticRuntime(repeat, args1, args2);
}

TEST(StaticRuntime, Flatten) {
  // exercise flatten_copy
  const auto flatten_script_1 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a * a
        c = torch.flatten(b, start_dim, end_dim)
        d = torch.relu(c)
        return d
  )JIT";

  const auto flatten_script_2 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a.transpose(0, 1)
        return torch.flatten(b, start_dim, end_dim).clone()
  )JIT";

  auto test_flatten =
      [&](std::vector<int64_t> shape, int64_t start_dim, int64_t end_dim) {
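        // Scale up the leading dimension for the second argument set so the
        // memory planner has to resize the managed storage (see note 2 in
        // the comment at the top of this file).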
        std::vector<int64_t> shape1(shape);
        if (shape1.size() > 0) {
          shape1[0] *= 6;
        }
        auto a = at::randn(shape);
        auto b = at::randn(shape1);
        std::vector<IValue> args{a, start_dim, end_dim};
        bool check_resize = shape1.size() > 0;
        testStaticRuntime(flatten_script_1, args);
        testStaticRuntime(
            flatten_script_1,
            args,
            {b, start_dim, end_dim},
            false, /* use_allclose */
            false, /* use_equalnan */
            check_resize);
        if (shape.size() > 2) {
          testStaticRuntime(flatten_script_2, args);
          testStaticRuntime(flatten_script_2, args, {b, start_dim, end_dim});
        }
      };

  test_flatten({2, 3}, 0, 1);
  test_flatten({2, 1, 3}, 1, 2);
  test_flatten({0, 1, 3, 0}, 1, 2);
  test_flatten({2, 3}, 1, 1);
  test_flatten({}, 0, 0);
}

TEST(StaticRuntime, pow) {
  const auto pow_script_ten_sca = R"JIT(
    def forward(self, input: Tensor, exponent: int):
        return torch.pow(input, exponent).clone()
  )JIT";

  const auto pow_script_ten_ten = R"JIT(
    def forward(self, input: Tensor, exponent: Tensor):
        return torch.pow(input, exponent).clone()
  )JIT";

  const auto pow_script_sca_ten = R"JIT(
    def forward(self, input: int, exponent: Tensor):
        return torch.pow(input, exponent).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});

  std::vector<IValue> args0{a, 4};
  testStaticRuntime(pow_script_ten_sca, args0);
  testStaticRuntime(pow_script_ten_sca, args0, {c, 4});

  std::vector<IValue> args1{at::abs(a), b};
  testStaticRuntime(pow_script_ten_ten, args1);
  testStaticRuntime(pow_script_ten_ten, args1, {at::abs(c), d});

  std::vector<IValue> args2{5, b};
  testStaticRuntime(pow_script_sca_ten, args2);
  testStaticRuntime(pow_script_sca_ten, args2, {3, d});
}

TEST(StaticRuntime, to) {
  const auto to_script_dtype = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy, memory_format).clone()
  )JIT";

  const auto to_script_dtype_strided = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        b = input.permute(0, 2, 3, 1)
        return torch.to(b, dtype, non_blocking, copy, memory_format).clone()
  )JIT";

  const auto to_script_prim_dtype = R"JIT(
    def forward(self, input: Tensor, dtype: Optional[int], non_blocking: bool, copy: bool):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy).clone()
  )JIT";

  const auto to_script_other = R"JIT(
    def forward(self, input: Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, other, non_blocking, copy, memory_format).clone()
  )JIT";

  // If the input is a float tensor, b could be an alias of a.
  const auto to_script_alias = R"JIT(
    def forward(self, input: Tensor):
        a = input + input
        b = a.float()
        c = b * b
        return (c)
  )JIT";

  const auto to_script_fails_managed_output_check = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return e
  )JIT";

  const auto to_script_select_tensor_output_into_tuple = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return (d, e)
  )JIT";

  const auto to_script_memory_planning_fail = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float().relu()
        return e
  )JIT";

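  // In test_to below, b is the target dtype, c is non_blocking, d is copy,
  // and e is the memory format (see the calls at the bottom of this test).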
1435 auto test_to = [&](at::ScalarType b, bool c, bool d, c10::MemoryFormat e) {
1436 auto a = at::randn({4, 3, 1, 2});
1437 auto other = at::randn({4, 3, 1, 2}).to(b);
1438 auto a2 = at::randn({3, 2, 2, 4});
1439 auto a2_other = at::randn({3, 2, 2, 4}).to(b);
1440
1441 std::vector<IValue> args0{a, b, c, d, e};
1442 std::vector<IValue> args1{a, b, c, d};
1443 std::vector<IValue> args2{a, other, c, d, e};
1444 std::vector<IValue> args2WithDifferentOtherType{
1445 a, at::randn({4, 3, 1, 2}, ScalarType::Double), c, d, e};
1446 std::vector<IValue> args3{a, std::nullopt, c, d};
1447
1448 std::vector<IValue> args0WithInt{a, ScalarType::Int, c, d, e};
1449 testStaticRuntime(
1450 to_script_dtype,
1451 args0,
1452 args0WithInt,
1453 /* default for use_allclose */ false,
1454 /* default for use_equalnan */ false,
1455 /* check_resize */ false);
1456 testStaticRuntime(to_script_dtype_strided, args0);
1457 testStaticRuntime(to_script_prim_dtype, args1);
1458 if (!d) {
1459 testStaticRuntime(to_script_prim_dtype, args3);
1460 }
1461 // Second set of args tests case where the `other` tensor's dtype
1462 // changes between iterations.
1463 testStaticRuntime(
1464 to_script_other,
1465 args2,
1466 args2WithDifferentOtherType,
1467 /* default for use_allclose */ false,
1468 /* default for use_equalnan */ false,
1469 /* check_resize */ false);
1470 testStaticRuntime(to_script_alias, {a});
1471
1472 testStaticRuntime(to_script_memory_planning_fail, {a, a});
1473 testStaticRuntime(to_script_fails_managed_output_check, {a, a});
1474 testStaticRuntime(to_script_select_tensor_output_into_tuple, {a, a});
1475
1476 // dynamic shapes
1477 testStaticRuntime(to_script_dtype, args0, {a2, b, c, d, e});
1478 testStaticRuntime(to_script_dtype_strided, args0, {a2, b, c, d, e});
1479 testStaticRuntime(to_script_prim_dtype, args1, {a2, b, c, d});
1480 if (!d) {
1481 testStaticRuntime(to_script_prim_dtype, args3, {a2, std::nullopt, c, d});
1482 }
1483 testStaticRuntime(to_script_other, args2, {a2, a2_other, c, d, e});
1484 testStaticRuntime(to_script_alias, {a}, {a2});
1485 };
1486 for (const bool non_blocking : {false, true}) {
1487 for (const bool copy : {false, true}) {
1488 // float->float, NCHW->NHWC
1489 test_to(
1490 at::ScalarType::Float,
1491 non_blocking,
1492 copy,
1493 c10::MemoryFormat::ChannelsLast);
1494 // float->half
1495 test_to(
1496 at::ScalarType::Half,
1497 non_blocking,
1498 copy,
1499 c10::MemoryFormat::Preserve);
1500 // float->float
1501 test_to(
1502 at::ScalarType::Float,
1503 non_blocking,
1504 copy,
1505 c10::MemoryFormat::Contiguous);
1506 test_to(
1507 at::ScalarType::Bool,
1508 non_blocking,
1509 copy,
1510 c10::MemoryFormat::Contiguous);
1511 // TODO: check if fbgemm is enabled properly in this case
1512 // half->float, NCHW->NHWC
1513 test_to(
1514 at::ScalarType::Half,
1515 non_blocking,
1516 copy,
1517 c10::MemoryFormat::ChannelsLast);
1518 }
1519 }
1520 }
1521
TEST(StaticRuntime,ExpandAs)1522 TEST(StaticRuntime, ExpandAs) {
1523 const auto expand_as_script = R"JIT(
1524 def forward(self, input: Tensor, other:Tensor):
1525 a = input.expand_as(other)
1526 return a.clone()
1527 )JIT";
1528
1529 auto a = at::randn({3, 1});
1530 auto b = at::randn({3, 2});
1531 auto c = at::randn({4, 1});
1532 auto d = at::randn({4, 2});
1533 std::vector<IValue> args{a, b};
1534 std::vector<IValue> args2{c, d};
1535 testStaticRuntime(expand_as_script, args);
1536 testStaticRuntime(expand_as_script, args, args2);
1537 }
1538
TEST(StaticRuntime,Full)1539 TEST(StaticRuntime, Full) {
1540 const auto full_script = R"JIT(
1541 def forward(self,
1542 size: List[int],
1543 fill_value: int,
1544 dtype: Optional[int],
1545 layout: Optional[int],
1546 device: Optional[Device],
1547 pin_memory: Optional[bool]):
1548 a = torch.full(size,
1549 fill_value,
1550 dtype=dtype,
1551 layout=layout,
1552 device=device,
1553 pin_memory=pin_memory)
1554 return (a.clone())
1555 )JIT";
1556
1557 auto cpu = at::Device(DeviceType::CPU);
1558 c10::List<int64_t> size0{2, 5};
1559 std::vector<IValue> args{
1560 size0, 4, at::ScalarType::Int, at::kStrided, cpu, false};
1561 std::vector<IValue> args1{
1562 size0, 4, at::ScalarType::Float, at::kStrided, cpu, false};
1563 c10::List<int64_t> size1{5, 6};
1564 std::vector<IValue> args2{
1565 size1, 5, at::ScalarType::Float, at::kStrided, cpu, false};
1566 testStaticRuntime(full_script, args);
1567 testStaticRuntime(
1568 full_script,
1569 args,
1570 args1,
1571 /*use_allclose=*/false,
1572 /*use_equalnan=*/false,
1573 /*check_resize=*/false);
1574 testStaticRuntime(full_script, args, args2);
1575 }
1576
TEST(StaticRuntime,FullLike)1577 TEST(StaticRuntime, FullLike) {
1578 const auto full_like_script = R"JIT(
1579 def forward(self,
1580 a: Tensor,
1581 fill_value: int,
1582 dtype: Optional[int],
1583 layout: Optional[int],
1584 device: Optional[Device],
1585 pin_memory: Optional[bool],
1586 memory_format: Optional[int]):
1587 b = torch.full_like(a,
1588 fill_value,
1589 dtype=dtype,
1590 layout=layout,
1591 device=device,
1592 pin_memory=pin_memory,
1593 memory_format=memory_format)
1594 return (b.clone())
1595 )JIT";
1596
1597 auto a = at::randn({2, 3});
1598 auto b = at::randn({3, 4, 2});
1599 auto cpu = at::Device(DeviceType::CPU);
1600 std::vector<IValue> args{
1601 a,
1602 4,
1603 at::ScalarType::Int,
1604 at::kStrided,
1605 cpu,
1606 false,
1607 c10::MemoryFormat::Contiguous};
1608 std::vector<IValue> args1{
1609 a,
1610 4,
1611 at::ScalarType::Float,
1612 at::kStrided,
1613 cpu,
1614 false,
1615 c10::MemoryFormat::Contiguous};
1616 std::vector<IValue> args2{
1617 b,
1618 4,
1619 at::ScalarType::Float,
1620 at::kStrided,
1621 cpu,
1622 false,
1623 c10::MemoryFormat::Contiguous};
1624 testStaticRuntime(full_like_script, args);
1625 testStaticRuntime(
1626 full_like_script,
1627 args,
1628 args1,
1629 /*use_allclose=*/false,
1630 /*use_equalnan=*/false,
1631 /*check_resize=*/false);
1632 testStaticRuntime(full_like_script, args, args2);
1633 }
1634
TEST(StaticRuntime,Ones)1635 TEST(StaticRuntime, Ones) {
1636 const auto script = R"JIT(
1637 def forward(self,
1638 size: List[int],
1639 dtype: Optional[int],
1640 layout: Optional[int],
1641 device: Optional[Device],
1642 pin_memory: Optional[bool]):
1643 a = torch.ones(size,
1644 dtype=dtype,
1645 layout=layout,
1646 device=device,
1647 pin_memory=pin_memory)
1648 return (a.clone())
1649 )JIT";
1650
1651 auto dtype = at::ScalarType::Int;
1652 auto cpu = at::Device(DeviceType::CPU);
1653 c10::List<int64_t> size0{2, 5};
1654 std::vector<IValue> args{size0, dtype, at::kStrided, cpu, false};
1655 c10::List<int64_t> size1{5, 6};
1656 std::vector<IValue> args2{size1, dtype, at::kStrided, cpu, false};
1657 testStaticRuntime(script, args);
1658 testStaticRuntime(script, args, args2);
1659 }
1660
TEST(StaticRuntime,OnesLike)1661 TEST(StaticRuntime, OnesLike) {
1662 const auto script = R"JIT(
1663 def forward(self,
1664 input: Tensor,
1665 dtype: Optional[int],
1666 layout: Optional[int],
1667 device: Optional[Device],
1668 pin_memory: Optional[bool],
1669 memory_format: Optional[int]):
1670 a = torch.ones_like(input,
1671 dtype=dtype,
1672 layout=layout,
1673 device=device,
1674 pin_memory=pin_memory,
1675 memory_format=memory_format)
1676 return (a.clone())
1677 )JIT";
1678
1679 auto cpu = at::Device(DeviceType::CPU);
1680 auto input0 = at::randn({2, 5});
1681 std::vector<IValue> args{
1682 input0,
1683 at::ScalarType::Int,
1684 at::kStrided,
1685 cpu,
1686 false,
1687 c10::MemoryFormat::Contiguous};
1688 std::vector<IValue> args1{
1689 input0,
1690 at::ScalarType::Float,
1691 at::kStrided,
1692 cpu,
1693 false,
1694 c10::MemoryFormat::Contiguous};
1695 auto input1 = at::randn({5, 6});
1696 std::vector<IValue> args2{
1697 input1,
1698 at::ScalarType::Float,
1699 at::kStrided,
1700 cpu,
1701 false,
1702 c10::MemoryFormat::Contiguous};
1703 testStaticRuntime(script, args);
1704 testStaticRuntime(
1705 script,
1706 args,
1707 args1,
1708 /*use_allclose=*/false,
1709 /*use_equalnan=*/false,
1710 /*check_resize=*/false);
1711 testStaticRuntime(script, args, args2);
1712 }
1713
1714 TEST(StaticRuntime, Zeros) {
1715 const auto script = R"JIT(
1716 def forward(self,
1717 size: List[int],
1718 dtype: Optional[int],
1719 layout: Optional[int],
1720 device: Optional[Device],
1721 pin_memory: Optional[bool]):
1722 a = torch.zeros(size,
1723 dtype=dtype,
1724 layout=layout,
1725 device=device,
1726 pin_memory=pin_memory)
1727 return (a.clone())
1728 )JIT";
1729
1730 auto cpu = at::Device(DeviceType::CPU);
1731 c10::List<int64_t> size0{2, 5};
1732 std::vector<IValue> args{
1733 size0, at::ScalarType::Int, at::kStrided, cpu, false};
1734 std::vector<IValue> args1{
1735 size0, at::ScalarType::Float, at::kStrided, cpu, false};
1736 c10::List<int64_t> size1{5, 6};
1737 std::vector<IValue> args2{
1738 size1, at::ScalarType::Float, at::kStrided, cpu, false};
1739 testStaticRuntime(script, args);
1740 testStaticRuntime(
1741 script,
1742 args,
1743 args1,
1744 /*use_allclose=*/false,
1745 /*use_equalnan=*/false,
1746 /*check_resize=*/false);
1747 testStaticRuntime(script, args, args2);
1748 }
1749
1750 TEST(StaticRuntime, Linear) {
1751 const auto linear_script = R"JIT(
1752 def forward(self, inp: Tensor, weights: Tensor, bias: Optional[Tensor]) -> Tensor:
1753 return torch.linear(inp, weights, bias).clone()
1754 )JIT";
1755
1756 auto input = at::randn({1, 2});
1757 auto weights = at::randn({1, 2});
1758 auto bias = at::randn({1, 1});
1759
1760 std::vector<IValue> args{input, weights, bias};
1761 std::vector<IValue> args_no_bias{input, weights, std::nullopt};
1762
1763 auto input2 = at::randn({6, 3});
1764 auto weights2 = at::randn({6, 3});
1765 auto bias2 = at::randn({6, 6});
1766
1767 std::vector<IValue> args2{input2, weights2, bias2};
1768 std::vector<IValue> args2_no_bias{input2, weights2, std::nullopt};
1769
1770 testStaticRuntime(linear_script, args);
1771 testStaticRuntime(linear_script, args_no_bias);
1772
1773 testStaticRuntime(linear_script, args, args2);
1774 testStaticRuntime(linear_script, args, args2_no_bias);
1775 }
1776
1777 TEST(StaticRuntime, VarCat) {
1778 const auto var_cat_script = R"JIT(
1779 def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
1780 return torch.cat([inp1, inp2], dim).clone()
1781 )JIT";
1782
1783 // 2D tensors - cat dim = 0
1784 std::vector<IValue> args1 = {at::randn({4, 6}), at::randn({5, 6}), 0};
1785 testStaticRuntime(var_cat_script, args1);
1786
1787 // 3D tensors - cat dim = 1
1788 std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 8, 6}), 1};
1789 testStaticRuntime(var_cat_script, args2);
1790
1791 // 3D tensors - cat dim = 2
1792 std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2};
1793 testStaticRuntime(var_cat_script, args3);
1794
1795 // Negative dim
1796 std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), -1};
1797 testStaticRuntime(var_cat_script, args4);
1798
1799 testStaticRuntime(var_cat_script, args1, args2);
1800 }
1801
1802 TEST(StaticRuntime, LeakyReLU) {
1803 torch::jit::Module mod = getLeakyReLUConstScriptModel();
1804 auto inputs = torch::randn({2, 2});
1805
1806 // run jit graph executor
1807 std::vector<at::IValue> input_ivalues({inputs});
1808 at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
1809
1810 // run static runtime
1811 std::vector<c10::IValue> input_tensors({inputs});
1812 torch::jit::StaticModule smod(mod);
1813 at::Tensor output_2 = smod(input_tensors, {}).toTensor();
1814 smod.runtime().check_for_memory_leak();
1815 EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
1816 }
1817
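// Helpers for the ProcessedNodeInputs tests below: build a
// ProcessedNodeInputs from an ArrayRef and verify that it matches the
// source element-wise.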
1818 static ProcessedNodeInputs createProcessedNodeInputs(
1819 c10::ArrayRef<uint16_t> inputs) {
1820 ProcessedNodeInputs result(inputs.size());
1821 for (const auto idx : c10::irange(inputs.size())) {
1822 result[idx] = inputs[idx];
1823 }
1824 return result;
1825 }
1826
1827 static void checkProcessedNodeInputs(
1828 const ProcessedNodeInputs& io,
1829 c10::ArrayRef<uint16_t> inputs) {
1830 ASSERT_EQ(inputs.size(), io.size());
1831 for (const auto idx : c10::irange(inputs.size())) {
1832 EXPECT_EQ(inputs[idx], io[idx]);
1833 }
1834 }
1835
1836 static void testProcessedNodeInputsRoundTrip(c10::ArrayRef<uint16_t> inputs) {
1837 auto io = createProcessedNodeInputs(inputs);
1838 checkProcessedNodeInputs(io, inputs);
1839
1840 ProcessedNodeInputs copied(io);
1841 checkProcessedNodeInputs(copied, inputs);
1842 ProcessedNodeInputs moved(std::move(io));
1843 checkProcessedNodeInputs(moved, inputs);
1844 }
1845
1846 TEST(ProcessedNodeInputs, Basic) {
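// Per the labels below, these cases cover both representations: up to 5
// inputs are stored inline, while 6 or more spill to a heap allocation.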
1847 std::vector<std::vector<uint16_t>> testCases = {
1848 {}, // empty
1849 {0xABCD, 0x5a5a}, // inline
1850 {0x11, 0x22, 0x33, 0x44, 0x55}, // max inline size
1851 {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}, // minimum outline size
1852 std::vector<uint16_t>(100, 0x5a), // large outline size
1853 };
1854
1855 for (const auto& values : testCases) {
1856 testProcessedNodeInputsRoundTrip(values);
1857 for (const auto& values2 : testCases) {
1858 auto from = createProcessedNodeInputs(values);
1859 auto to = createProcessedNodeInputs(values2);
1860
1861 to = from;
1862 checkProcessedNodeInputs(to, values);
1863
1864 auto toMoveInto = createProcessedNodeInputs(values2);
1865 toMoveInto = std::move(from);
1866 checkProcessedNodeInputs(toMoveInto, values);
1867 }
1868 }
1869 }
1870
1871 TEST(StaticRuntime, isinstance) {
1872 const auto isinstance_int_script = R"JIT(
1873 def forward(self, a: Any):
1874 return isinstance(a, int)
1875 )JIT";
1876
1877 const auto isinstance_tensor_script = R"JIT(
1878 def forward(self, a: Any):
1879 return isinstance(a, torch.Tensor)
1880 )JIT";
1881
1882 const auto isinstance_many_types_script = R"JIT(
1883 def forward(self, a: Any):
1884 return isinstance(a, (bool, int))
1885 )JIT";
1886
1887 auto a = at::randn({2, 2});
1888 auto b = at::randn({2, 2, 2});
1889
1890 std::vector<at::IValue> args{a};
1891 std::vector<at::IValue> args2{b};
1892
1893 testStaticRuntime(isinstance_int_script, args);
1894 testStaticRuntime(isinstance_int_script, args, args2);
1895
1896 testStaticRuntime(isinstance_tensor_script, args);
1897 testStaticRuntime(isinstance_tensor_script, args, args2);
1898
1899 testStaticRuntime(isinstance_many_types_script, args);
1900 testStaticRuntime(isinstance_many_types_script, args, args2);
1901 }
1902
1903 TEST(StaticRuntime, TypeCheck) {
1904 const auto typecheck_ir = R"IR(
1905 graph(%a.1 : Tensor,
1906 %b.1 : Tensor):
1907 %t0 : Float(2, 2, strides=[2, 1], device=cpu), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu), Float(3, 3, strides=[3, 1])]](%a.1, %b.1)
1908 return (%t0, %t1, %type_matched)
1909 )IR";
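// prim::TypeCheck forwards its inputs (refined to the expected types) and
// emits a boolean that is true only if every input matches; args_incorrect
// below exercises the mismatch path.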
1910
1911 auto a = at::zeros({2, 2}, at::kFloat);
1912 a = a.to(at::kCPU);
1913 auto b = at::ones({3, 3}, at::kFloat);
1914 auto c = at::ones({2, 2, 2}, at::kFloat);
1915
1916 std::vector<IValue> args_correct = {a, b};
1917 std::vector<IValue> args_incorrect = {a, c};
1918
1919 testStaticRuntime(typecheck_ir, args_correct);
1920 testStaticRuntime(typecheck_ir, args_correct, args_incorrect);
1921 }
1922
1923 TEST(StaticRuntime, Index) {
1924 const auto index_without_none_script = R"JIT(
1925 def forward(self, a: Tensor, idx: Tensor):
1926 return a[idx].clone()
1927 )JIT";
1928
1929 // Index with boolean mask
1930 auto a = at::arange(4, at::kFloat).view({2, 2});
1931 auto idx_a = torch::tensor({{0, 1}, {0, 0}}, at::kBool);
1932 std::vector<IValue> args_a{a, idx_a};
1933
1934 // Index with tensor
1935 auto b = at::arange(27, at::kFloat).view({3, 3, 3});
1936 auto idx_b = torch::tensor({0, 1, 2}, at::kLong);
1937 std::vector<IValue> args_b{b, idx_b};
1938
1939 testStaticRuntime(index_without_none_script, args_a);
1940 testStaticRuntime(index_without_none_script, args_a, args_b);
1941
1942 const auto index_with_none_script = R"JIT(
1943 def forward(self, a: Tensor, idx: Tensor, none: Optional[Tensor]):
1944 return a[idx, none].clone()
1945 )JIT";
1946
1947 // Index with None
1948 // When indexing with none, the shape of `f` becomes [2, 1, 2],
1949 // so the mask must be reshaped appropriately.
1950 auto f = at::arange(4, at::kFloat).view({2, 1, 2});
1951 auto idx_f_reshape = torch::tensor({{{0, 1}}, {{0, 0}}}, at::kBool);
1952 std::vector<IValue> args_f_with_none{f, idx_f_reshape};
1953 args_f_with_none.emplace_back();
1954
1955 testStaticRuntime(index_with_none_script, args_f_with_none);
1956 testStaticRuntime(
1957 index_with_none_script,
1958 args_f_with_none,
1959 {IValue(b), IValue(idx_b), IValue()});
1960
1961 const auto index_with_two_tensors_script = R"JIT(
1962 def forward(self, a: Tensor, idx_a: Tensor, idx_b: Tensor):
1963 return a[idx_a, idx_b].clone()
1964 )JIT";
1965
1966 // Index with multiple tensors
1967 const auto& c = a; // 2x2 tensor
1968 auto idx_c1 = torch::tensor({0, 0}, at::kLong);
1969 auto idx_c2 = torch::tensor({0}, at::kLong);
1970 std::vector<IValue> args_c{c, idx_c1, idx_c2};
1971
1972 const auto& d = b; // 3x3x3 tensor
1973 auto idx_d1 = torch::tensor({{0, 0, 2}, {0, 1, 1}}, at::kLong);
1974 auto idx_d2 = torch::tensor({{1, 1, 0}, {1, 0, 2}}, at::kLong);
1975 std::vector<IValue> args_d{d, idx_d1, idx_d2};
1976
1977 testStaticRuntime(index_with_two_tensors_script, args_c, args_d);
1978 }
1979
1980 TEST(StaticRuntime, IndexSelect) {
1981 const std::string script = R"IR(
1982 graph(%self: Tensor, %dim: int, %index: Tensor):
1983 %bias: None = prim::Constant()
1984 %ret = aten::index_select(%self, %dim, %index)
1985 %cloned = aten::clone(%ret, %bias)
1986 return (%cloned)
1987 )IR";
1988
1989 auto self0 = at::rand({6});
1990 auto dim0 = 0;
1991 auto index0 = at::randint(0, 5, {6}, torch::kInt32);
1992 std::vector<IValue> args{self0, dim0, index0};
1993 testStaticRuntime(script, args);
1994
1995 auto self1 = at::rand({128});
1996 auto dim1 = 0;
1997 auto index1 = at::randint(0, 127, {127}, torch::kInt32);
1998 std::vector<IValue> args2{self1, dim1, index1};
1999 testStaticRuntime(script, args, args2);
2000 }
2001
2002 TEST(StaticRuntime, ClampMin) {
2003 const auto clamp_min_int_script = R"JIT(
2004 def forward(self, a: Tensor, b: int):
2005 return torch.clamp_min(a, b).clone()
2006 )JIT";
2007
2008 const auto clamp_min_float_script = R"JIT(
2009 def forward(self, a: Tensor, b: float):
2010 return torch.clamp_min(a, b).clone()
2011 )JIT";
2012
2013 auto a = at::randn({2, 2});
2014 auto b = at::randn({3, 3, 3});
2015 int scalar_int = 1;
2016 float scalar_float = 3.14;
2017
2018 std::vector<IValue> args_a_int{a, scalar_int};
2019 std::vector<IValue> args_b_int{b, scalar_int};
2020
2021 testStaticRuntime(clamp_min_int_script, args_a_int);
2022 testStaticRuntime(clamp_min_int_script, args_a_int, args_b_int);
2023
2024 std::vector<IValue> args_a_float{a, scalar_float};
2025 std::vector<IValue> args_b_float{b, scalar_float};
2026
2027 testStaticRuntime(clamp_min_float_script, args_a_float);
2028 testStaticRuntime(clamp_min_float_script, args_a_float, args_b_float);
2029 }
2030
2031 TEST(StaticRuntime, Argmin) {
2032 const auto argmin_script = R"JIT(
2033 def forward(self, a: Tensor):
2034 return torch.argmin(a).clone()
2035 )JIT";
2036
2037 const auto argmin_with_dim_script = R"JIT(
2038 def forward(self, a: Tensor, dim: int):
2039 return torch.argmin(a, dim).clone()
2040 )JIT";
2041
2042 const auto argmin_with_keep_dim_script = R"JIT(
2043 def forward(self, a: Tensor, dim: int):
2044 return torch.argmin(a, dim, True).clone()
2045 )JIT";
2046
2047 auto a = at::randn({2, 2});
2048 auto b = at::randn({17, 2, 1});
2049
2050 testStaticRuntime(argmin_script, {a});
2051 testStaticRuntime(
2052 argmin_script,
2053 {a},
2054 {b},
2055 /* use_allclose */ false,
2056 /* use_equalnan */ false,
2057 /* check_resize */ false);
2058
2059 int dim_a = 0;
2060 int dim_b = 1;
2061
2062 std::vector<IValue> args_a{a, dim_a};
2063 std::vector<IValue> args_b{b, dim_b};
2064
2065 testStaticRuntime(argmin_with_dim_script, args_a);
2066 testStaticRuntime(argmin_with_dim_script, args_a, args_b);
2067
2068 testStaticRuntime(argmin_with_keep_dim_script, args_a);
2069 testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b);
2070 }
2071
2072 TEST(StaticRuntime, Softmax) {
2073 const auto softmax_script = R"JIT(
2074 def forward(self, a: Tensor, dim: int):
2075 return torch.softmax(a, dim).clone()
2076 )JIT";
2077
2078 const auto softmax_script_with_dtype = R"JIT(
2079 def forward(self, a: Tensor, dim: int, dtype: int):
2080 return torch.softmax(a, dim, dtype=dtype).clone()
2081 )JIT";
2082
2083 auto a = at::randn({2, 3});
2084 auto b = at::randn({3, 3, 3});
2085
2086 testStaticRuntime(softmax_script, {a, 0});
2087 testStaticRuntime(softmax_script, {a, 1});
2088
2089 testStaticRuntime(softmax_script, {b, 0});
2090 testStaticRuntime(softmax_script, {b, 1});
2091 testStaticRuntime(softmax_script, {b, 2});
2092
2093 testStaticRuntime(softmax_script_with_dtype, {a, 1, at::ScalarType::Float});
2094 testStaticRuntime(softmax_script_with_dtype, {b, 1, at::ScalarType::Float});
2095 }
2096
2097 TEST(StaticRuntime, GetItem_Dict) {
2098 const auto getitem_dict_tensor_script = R"JIT(
2099 def forward(self, key: Tensor):
2100 d = {key: 1}
2101 return d[key]
2102 )JIT";
2103
2104 const auto getitem_dict_int_script = R"JIT(
2105 def forward(self, key: int):
2106 d = {key: 1}
2107 return d[key]
2108 )JIT";
2109
2110 const auto getitem_dict_str_script = R"JIT(
2111 def forward(self, key: str):
2112 d = {key: 1}
2113 return d[key]
2114 )JIT";
2115
2116 int int_key = 0;
2117 std::string str_key = "str";
2118
2119 // No need to test these multiple times since the args are not tensors
2120 testStaticRuntime(getitem_dict_int_script, {int_key});
2121 testStaticRuntime(getitem_dict_str_script, {str_key});
2122
2123 auto a = torch::tensor({1});
2124 auto b = torch::tensor({1, 1});
2125
2126 testStaticRuntime(getitem_dict_tensor_script, {a});
2127 testStaticRuntime(getitem_dict_tensor_script, {a}, {b});
2128 }
2129
2130 TEST(StaticRuntime, GetItem_List) {
2131 const auto getitem_list_int_script = R"JIT(
2132 def forward(self, idx: int):
2133 lst = [1, 2, 3]
2134 return lst[idx]
2135 )JIT";
2136
2137 const auto getitem_list_tensor_script = R"JIT(
2138 def forward(self, tensor: Tensor, idx: int):
2139 lst = [tensor, tensor]
2140 return lst[idx]
2141 )JIT";
2142
2143 testStaticRuntime(getitem_list_int_script, {1});
2144 testStaticRuntime(getitem_list_int_script, {-1});
2145
2146 auto a = torch::tensor({1});
2147 auto b = torch::tensor({1, 1});
2148
2149 testStaticRuntime(getitem_list_tensor_script, {a, 1});
2150 testStaticRuntime(getitem_list_tensor_script, {a, 1}, {b, -1});
2151 }
2152
2153 TEST(StaticRuntime, Transpose) {
2154 const auto transpose_script = R"JIT(
2155 def forward(self, a: Tensor, dim1: int, dim2: int):
2156 return torch.transpose(a, dim1, dim2).clone()
2157 )JIT";
2158
2159 auto a = at::randn({2, 2});
2160 int dim1_a = 0;
2161 int dim2_a = 1;
2162 std::vector<IValue> args_a{a, dim1_a, dim2_a};
2163
2164 auto b = at::randn({3, 3, 3});
2165 int dim1_b = 0;
2166 int dim2_b = 2;
2167 std::vector<IValue> args_b{b, dim1_b, dim2_b};
2168
2169 testStaticRuntime(transpose_script, args_a);
2170 testStaticRuntime(transpose_script, args_a, args_b);
2171 }
2172
2173 TEST(StaticRuntime, Permute) {
2174 auto permute_script = R"JIT(
2175 def forward(self, a: Tensor, dims: List[int]):
2176 return torch.permute(a, dims).clone()
2177 )JIT";
2178
2179 auto a = at::randn({2, 2});
2180 c10::List<int64_t> dims_a{1, 0};
2181 std::vector<IValue> args_a{a, dims_a};
2182
2183 auto b = at::randn({3, 3, 3});
2184 c10::List<int64_t> dims_b{0, 2, 1};
2185 std::vector<IValue> args_b{b, dims_b};
2186
2187 testStaticRuntime(permute_script, args_a);
2188 testStaticRuntime(permute_script, args_a, args_b);
2189
2190 permute_script = R"JIT(
2191 def forward(self, a: Tensor, dims: List[int], shape: List[int]):
2192 return torch.permute(a, dims).reshape(shape).clone()
2193 )JIT";
2194
2195 a = at::randn({8, 16, 4});
2196 dims_a = {0, 2, 1};
2197 dims_b = {-1, 16};
2198 testStaticRuntime(permute_script, {a, dims_a, dims_b});
2199 }
2200
2201 TEST(StaticRuntime, Slice) {
2202 const auto slice_script = R"JIT(
2203 def forward(self, a: Tensor, dim: int, start: int, end: int, step: int):
2204 return a.slice(dim, start, end, step).clone()
2205 )JIT";
2206
2207 auto a = at::randn({2, 2});
2208 int dim_a = 1;
2209 int start_a = 0;
2210 int end_a = 1;
2211 int step_a = 1;
2212 std::vector<IValue> args_a{a, dim_a, start_a, end_a, step_a};
2213
2214 auto b = at::randn({3, 3, 3});
2215 int dim_b = 2;
2216 int start_b = 0;
2217 int end_b = 1;
2218 int step_b = 2;
2219 std::vector<IValue> args_b{b, dim_b, start_b, end_b, step_b};
2220
2221 testStaticRuntime(slice_script, args_a);
2222 testStaticRuntime(slice_script, args_a, args_b);
2223
2224 const auto slice_script2 = R"JIT(
2225 def forward(self, a: Tensor, dim: int, step: int):
2226 return a.slice(dim, None, None, step).clone()
2227 )JIT";
2228 std::vector<IValue> args_c{b, dim_b, step_b};
2229 testStaticRuntime(slice_script2, args_c);
2230 }
2231
2232 TEST(StaticRuntime, Narrow) {
2233 const auto narrow_with_int_script = R"JIT(
2234 def forward(self, a: Tensor, dim: int, start: int, length: int):
2235 return a.narrow(dim, start, length).clone()
2236 )JIT";
2237
2238 auto a = at::randn({5, 5});
2239 int dim_a = 0;
2240 int start_a_int = 3;
2241 int len_a = 2;
2242 std::vector<IValue> args_a{a, dim_a, start_a_int, len_a};
2243
2244 auto b = at::randn({5, 5, 5});
2245 int dim_b = 1;
2246 int start_b_int = 2;
2247 int len_b = 3;
2248 std::vector<IValue> args_b{b, dim_b, start_b_int, len_b};
2249
2250 testStaticRuntime(narrow_with_int_script, args_a);
2251 testStaticRuntime(narrow_with_int_script, args_a, args_b);
2252 }
2253
2254 TEST(StaticRuntime, TupleUnpack) {
2255 const auto two_tuple_unpack_script = R"JIT(
2256 def forward(self, tup: Tuple[Tensor, Tensor]):
2257 a, b = tup
2258 return (a, b)
2259 )JIT";
2260
2261 const auto three_tuple_unpack_script = R"JIT(
2262 def forward(self, tup: Tuple[Tensor, Tensor, Tensor]):
2263 a, b, c = tup
2264 return (a, b, c)
2265 )JIT";
2266
2267 auto two_tup = c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})});
2268 auto two_tup_large =
2269 c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})});
2270
2271 auto three_tup = c10::ivalue::Tuple::create(
2272 {at::randn({1}), at::randn({1}), at::randn({1})});
2273 auto three_tup_large = c10::ivalue::Tuple::create(
2274 {at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})});
2275
2276 testStaticRuntime(two_tuple_unpack_script, {two_tup});
2277 testStaticRuntime(two_tuple_unpack_script, {two_tup}, {two_tup_large});
2278
2279 testStaticRuntime(three_tuple_unpack_script, {three_tup});
2280 testStaticRuntime(three_tuple_unpack_script, {three_tup}, {three_tup_large});
2281 }
2282
2283 TEST(StaticRuntime, Append) {
2284 const auto append_int_script = R"JIT(
2285 def forward(self, a: int):
2286 lst = [1, 2, 3]
2287 lst.append(a)
2288 return lst
2289 )JIT";
2290
2291 const auto append_tensor_script = R"JIT(
2292 def forward(self, a: Tensor):
2293 lst = []
2294 lst.append(a)
2295 return lst
2296 )JIT";
2297
2298 std::vector<IValue> args_int{1};
2299
2300 testStaticRuntime(append_int_script, args_int);
2301
2302 std::vector<IValue> args_tensor{at::randn({1})};
2303 std::vector<IValue> args_tensor_large{at::randn({2, 2})};
2304
2305 testStaticRuntime(append_tensor_script, args_tensor);
2306 testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large);
2307 }
2308
2309 TEST(StaticRuntime, QuantizedLinear) {
2310 const std::string quantize_script = R"IR(
2311 graph(%input: Tensor, %weights: Tensor):
2312 %scale: float = prim::Constant[value=1.]()
2313 %zero_point: int = prim::Constant[value=1]()
2314 %bias: None = prim::Constant()
2315 %packed_params = quantized::linear_prepack(%weights, %bias)
2316 %1254 = quantized::linear(%input, %packed_params, %scale, %zero_point)
2317 %1249: Tensor = aten::dequantize(%1254)
2318 return (%1249)
2319 )IR";
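// quantized::linear expects qint8 weights and quint8 activations, hence
// the dtypes below.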
2320 at::Tensor weight =
2321 at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);
2322 at::Tensor input =
2323 at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);
2324
2325 at::Tensor weight_2 =
2326 at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8);
2327 at::Tensor input_2 =
2328 at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8);
2329
2330 testStaticRuntime(quantize_script, {input, weight}, {input_2, weight_2});
2331 }
2332
2333 TEST(StaticRuntime, QuantizedLinearDynamicFp16) {
2334 const std::string quantized_linear_dynamic_fp16_script = R"IR(
2335 graph(%input: Tensor, %weights: Tensor):
2336 %bias: None = prim::Constant()
2337 %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
2338 %output = quantized::linear_dynamic_fp16(%input, %packed_params)
2339 %ret = aten::clone(%output, %bias)
2340 return (%ret)
2341 )IR";
2342 at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
2343 at::Tensor input = torch::randn({3, 2}, torch::kFloat);
2344
2345 at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
2346 at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
2347
2348 testStaticRuntime(
2349 quantized_linear_dynamic_fp16_script,
2350 {input, weight},
2351 {input_2, weight_2});
2352 }
2353
2354 TEST(StaticRuntime, QuantizedLinearReluDynamicFp16) {
2355 const std::string quantized_linear_relu_dynamic_fp16_script = R"IR(
2356 graph(%input: Tensor, %weights: Tensor):
2357 %bias: None = prim::Constant()
2358 %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
2359 %output = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
2360 %ret = aten::clone(%output, %bias)
2361 return (%ret)
2362 )IR";
2363 at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
2364 at::Tensor input = torch::randn({3, 2}, torch::kFloat);
2365
2366 at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
2367 at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
2368
2369 testStaticRuntime(
2370 quantized_linear_relu_dynamic_fp16_script,
2371 {input, weight},
2372 {input_2, weight_2});
2373 }
2374
2375 TEST(StaticRuntime, VarStack) {
2376 const auto var_stack_script = R"JIT(
2377 def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
2378 return torch.stack([inp1, inp2], dim).clone()
2379 )JIT";
2380
2381 // 2D tensors - stack dim = 0
2382 std::vector<IValue> args1 = {at::randn({6, 6}), at::randn({6, 6}), 0};
2383 testStaticRuntime(var_stack_script, args1);
2384
2385 // 3D tensors - stack dim = 1
2386 std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 1};
2387 testStaticRuntime(var_stack_script, args2);
2388
2389 // 3D tensors - stack dim = 2
2390 std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 2};
2391 testStaticRuntime(var_stack_script, args3);
2392
2393 // Negative dim
2394 std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), -1};
2395 testStaticRuntime(var_stack_script, args4);
2396
2397 // Non-serial path
2398 std::vector<IValue> args5 = {at::randn({1, 2, 3}), at::randn({1, 2, 3}), 3};
2399 testStaticRuntime(var_stack_script, args5);
2400
2401 // Fast path
2402 std::vector<IValue> args6 = {at::randn({1}), at::randn({1}), 0};
2403 testStaticRuntime(var_stack_script, args6);
2404
2405 testStaticRuntime(var_stack_script, args1, args2);
2406 }
2407
2408 TEST(StaticRuntime, FmodTensor) {
2409 const auto fmod_tensor = R"JIT(
2410 def forward(self, a: Tensor, b: Tensor):
2411 return torch.fmod(a, b).clone()
2412 )JIT";
2413
2414 // fmod tensor version
2415 auto a = at::randn({2, 3});
2416 auto b = at::randn({2, 3});
2417 std::vector<IValue> args0{a, b};
2418 testStaticRuntime(fmod_tensor, args0);
2419
2420 // check for dynamic shapes
2421 auto c = at::randn({4, 3, 2});
2422 auto d = at::randn({4, 3, 2});
2423 std::vector<IValue> args1{c, d};
2424 testStaticRuntime(fmod_tensor, args0, args1);
2425 }
2426
2427 TEST(StaticRuntime, FmodScalar) {
2428 const auto fmod_scalar = R"JIT(
2429 def forward(self, a: Tensor, b: int):
2430 return torch.fmod(a, b).clone()
2431 )JIT";
2432
2433 auto a = at::randn({2, 3});
2434
2435 // fmod scalar version
2436 std::vector<IValue> args2{a, 3};
2437 testStaticRuntime(fmod_scalar, args2);
2438
2439 // check for dynamic shapes
2440 auto c = at::randn({4, 3, 2});
2441 std::vector<IValue> args3{c, 4};
2442 testStaticRuntime(fmod_scalar, args2, args3);
2443
2444 // test int32 version
2445 a = at::randint(-100, 100, {2, 3}, at::kInt);
2446 c = at::randint(-100, 100, {4, 3, 2}, at::kInt);
2447 testStaticRuntime(fmod_scalar, {a, 3});
2448 testStaticRuntime(fmod_scalar, {a, 3}, {c, 4});
2449 }
2450
2451 TEST(StaticRuntime, QEmbeddingBagBytePrepack) {
2452 const std::string embedding_bag_byte_prepack_script = R"IR(
2453 graph(%input: Tensor):
2454 %none : None = prim::Constant()
2455 %output: Tensor = quantized::embedding_bag_byte_prepack(%input)
2456 %res: Tensor = aten::clone(%output, %none)
2457 return (%res)
2458 )IR";
2459
2460 auto a = torch::randn({8, 16}, at::ScalarType::Float);
2461 auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);
2462
2463 testStaticRuntime(embedding_bag_byte_prepack_script, {a});
2464 testStaticRuntime(embedding_bag_byte_prepack_script, {a}, {b});
2465 }
2466
2467 TEST(StaticRuntime, QEmbeddingBagByteUnpack) {
2468 const auto src = R"IR(
2469 graph(%input: Tensor):
2470 %none : None = prim::Constant()
2471 %weight: Tensor = quantized::embedding_bag_byte_prepack(%input)
2472 %output: Tensor = quantized::embedding_bag_byte_unpack(%weight)
2473 %res: Tensor = aten::clone(%output, %none)
2474 return (%res)
2475 )IR";
2476
2477 auto a = torch::randn({8, 16}, at::ScalarType::Float);
2478 auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);
2479
2480 testStaticRuntime(src, {a});
2481 testStaticRuntime(src, {a}, {b});
2482 }
2483
2484 TEST(StaticRuntime, LinalgNorm_ScalarOrd) {
2485 const auto linalg_norm_ord_scalar = R"JIT(
2486 def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int):
2487 return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
2488 )JIT";
2489
2490 auto a = at::randn({2, 3});
2491 auto dim = std::vector<int64_t>({1});
2492 auto dtype = at::ScalarType::Float;
2493
2494 std::vector<IValue> args0{a, 4, dim, true, dtype};
2495 testStaticRuntime(linalg_norm_ord_scalar, args0);
2496
2497 auto b = at::randn({3, 2, 6});
2498 std::vector<IValue> args1{b, 4, dim, true, dtype};
2499 testStaticRuntime(linalg_norm_ord_scalar, args0, args1);
2500 }
2501
2502 TEST(StaticRuntime, LinalgNorm_StringOrd) {
2503 const auto linalg_norm_ord_str = R"JIT(
2504 def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
2505 return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
2506 )JIT";
2507
2508 auto a = at::randn({2, 3});
2509 auto dim = std::vector<int64_t>({0, 1});
2510 auto dtype = at::ScalarType::Float;
2511
2512 std::vector<IValue> args0{a, "fro", dim, true, dtype};
2513 testStaticRuntime(linalg_norm_ord_str, args0);
2514
2515 auto b = at::randn({3, 2, 17});
2516 std::vector<IValue> args1{b, "fro", dim, true, dtype};
2517 testStaticRuntime(linalg_norm_ord_str, args0, args1);
2518 }
2519
2520 TEST(StaticRuntime, Index_Put) {
2521 const auto index_put_str = R"JIT(
2522 def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
2523 return torch.index_put(a, indices, values, accumulate).clone()
2524 )JIT";
2525
2526 auto a = at::randn({2});
2527 auto indices_a = std::make_tuple(torch::tensor({0}, at::kLong));
2528 auto values_a = at::randn({1});
2529
2530 std::vector<IValue> args0{a, indices_a, values_a, false};
2531 testStaticRuntime(index_put_str, args0);
2532
2533 const auto index_put_non_optional_str = R"JIT(
2534 def forward(self, a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool):
2535 return torch.index_put(a, indices, values, accumulate).clone()
2536 )JIT";
2537
2538 auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};
2539 std::vector<IValue> args1{a, indices_b, values_a, false};
2540 testStaticRuntime(index_put_non_optional_str, args1);
2541
2542 const auto index_put_list_construct = R"JIT(
2543 def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool):
2544 indices: List[Optional[Tensor]] = [indices]
2545 return torch.index_put(a, indices, values, accumulate).clone()
2546 )JIT";
2547
2548 std::vector<IValue> args2{a, torch::tensor({0}, at::kLong), values_a, false};
2549 testStaticRuntime(index_put_list_construct, args2);
2550 }
2551
2552 TEST(StaticRuntime, Item) {
2553 const auto item_str = R"JIT(
2554 def forward(self, a: Tensor):
2555 return torch.item(a)
2556 )JIT";
2557
2558 auto a = at::randn({1});
2559
2560 std::vector<IValue> args0{a};
2561 testStaticRuntime(item_str, args0);
2562 }
2563
2564 TEST(StaticRuntime, Tensor_Split) {
2565 const auto tensor_split_str1 = R"JIT(
2566 def forward(self, a: Tensor, sections: int, dim: int):
2567 return torch.tensor_split(a, sections, dim)
2568 )JIT";
2569 std::vector<IValue> args1{at::randn({8}), 3, 0};
2570
2571 const auto tensor_split_str2 = R"JIT(
2572 def forward(self, a: Tensor, sections: Tensor, dim: int):
2573 return torch.tensor_split(a, sections, dim)
2574 )JIT";
2575 std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};
2576
2577 const auto tensor_split_str3 = R"JIT(
2578 def forward(self, a: Tensor, indices: List[int], dim: int):
2579 return torch.tensor_split(a, indices, dim)
2580 )JIT";
2581 std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};
2582
2583 testStaticRuntime(tensor_split_str1, args1);
2584 testStaticRuntime(tensor_split_str2, args2);
2585 testStaticRuntime(tensor_split_str3, args3);
2586 }
2587
2588 TEST(StaticRuntime, JIT_Aten_Cpu) {
2589 const std::string script = R"IR(
2590 graph(%a: Tensor):
2591 %1 : int = prim::Constant[value=0]()
2592 %aa: Tensor = aten::add(%a, %a, %1)
2593 %ret: Tensor = aten::cpu(%aa)
2594 return (%ret)
2595 )IR";
2596
2597 auto graph = std::make_shared<Graph>();
2598 std::unordered_map<std::string, Value*> vmap;
2599 vmap.reserve(0);
2600 parseIR(script, graph.get(), vmap);
2601 torch::jit::StaticModule smodule(graph);
2602
2603 auto a = at::randn({2, 4});
2604 std::vector<IValue> args0{a};
2605
2606 testStaticRuntime(script, args0);
2607 }
2608
2609 TEST(StaticRuntime, JIT_Aten_Numel) {
2610 const std::string script = R"IR(
2611 graph(%a: Tensor):
2612 %1 : int = prim::Constant[value=0]()
2613 %aa: Tensor = aten::add(%a, %a, %1)
2614 %ret: int = aten::numel(%aa)
2615 return (%ret)
2616 )IR";
2617
2618 auto graph = std::make_shared<Graph>();
2619 std::unordered_map<std::string, Value*> vmap;
2620 vmap.reserve(0);
2621 parseIR(script, graph.get(), vmap);
2622 torch::jit::StaticModule smodule(graph);
2623
2624 auto a = at::randn({2, 4});
2625 std::vector<IValue> args0{a};
2626
2627 testStaticRuntime(script, args0);
2628 }
2629
2630 TEST(StaticRuntime, JIT_Aten_List) {
2631 const auto script_str = R"IR(
2632 graph(%a: str):
2633 %ret: str[] = aten::list(%a)
2634 return (%ret)
2635 )IR";
2636 std::string a = "abcd";
2637 std::vector<IValue> args0{a};
2638 testStaticRuntime(script_str, args0);
2639
2640 // Update the result of aten::list to ensure that a deep copy
2641 // took place
2642 const auto script_list = R"IR(
2643 graph(%a : int[]):
2644 %idx : int = prim::Constant[value=0]()
2645 %value : int = prim::Constant[value=42]()
2646 %res : int[] = aten::list(%a)
2647 %updated : int[] = aten::_set_item(%res, %idx, %value)
2648 return (%res, %a)
2649 )IR";
2650
2651 std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
2652 testStaticRuntime(script_list, args1);
2653 }
2654
2655 TEST(StaticRuntime, JIT_Aten_Range_Length) {
2656 const std::string script = R"IR(
2657 graph(%lo: int, %hi: int, %step: int):
2658 %1 : int = prim::Constant[value=0]()
2659 %ret: int = aten::__range_length(%lo, %hi, %step)
2660 return (%ret)
2661 )IR";
2662
2663 auto graph = std::make_shared<Graph>();
2664 std::unordered_map<std::string, Value*> vmap;
2665 vmap.reserve(0);
2666 parseIR(script, graph.get(), vmap);
2667 torch::jit::StaticModule smodule(graph);
2668
2669 std::vector<IValue> args0{0, 10, 2};
2670
2671 testStaticRuntime(script, args0);
2672 }
2673
2674 TEST(StaticRuntime, Cat) {
2675 const std::string cat_script = R"IR(
2676 graph(%a: Tensor, %b: Tensor, %dim: int):
2677 %ten_list: Tensor[] = prim::ListConstruct(%a, %b)
2678 %1 : int = prim::Constant[value=0]()
2679 %2 : int = prim::Constant[value=1]()
2680 %3 : int = prim::Constant[value=1]()
2681 %ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3)
2682 %ret: Tensor = aten::cat(%ten_list2, %dim)
2683 return (%ret)
2684 )IR";
2685
2686 auto graph = std::make_shared<Graph>();
2687 std::unordered_map<std::string, Value*> vmap;
2688 parseIR(cat_script, graph.get(), vmap);
2689 torch::jit::StaticModule smodule(graph);
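// The aten::slice on the tensor list means cat's input is not a plain
// ListConstruct, so the variadic-cat substitution cannot fire; verify
// that aten::cat survives in the graph.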
2690 ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat"));
2691
2692 auto a = at::randn({2, 4});
2693 auto b = at::randn({3, 4});
2694 std::vector<IValue> args0{a, b, 0};
2695
2696 testStaticRuntime(cat_script, args0);
2697
2698 auto c = at::randn({3, 4});
2699 auto d = at::randn({3, 5});
2700 std::vector<IValue> args1{c, d, 1};
2701 testStaticRuntime(cat_script, args0, args1);
2702
2703 std::vector<IValue> args_dim_negative{c, d, -1};
2704 testStaticRuntime(cat_script, args_dim_negative);
2705 }
2706
2707 TEST(StaticRuntime, Cumsum) {
2708 const auto cumsum_script = R"JIT(
2709 def forward(self, a: Tensor, dim: int):
2710 return torch.cumsum(a, dim).clone()
2711 )JIT";
2712
2713 auto a = at::randn({2, 3});
2714 std::vector<IValue> args0{a, 0};
2715 testStaticRuntime(cumsum_script, args0);
2716
2717 auto b = at::randn({3, 6});
2718 std::vector<IValue> args1{b, 1};
2719 testStaticRuntime(cumsum_script, args0, args1);
2720 }
2721
2722 TEST(StaticRuntime, CumsumDtype) {
2723 const auto cumsum_script_dtype = R"JIT(
2724 def forward(self, a: Tensor, dim: int, dtype: int):
2725 return torch.cumsum(a, dim, dtype=dtype).clone()
2726 )JIT";
2727
2728 auto a = at::randn({1, 2});
2729 auto dtype = at::ScalarType::Float;
2730 std::vector<IValue> args0{a, 0, dtype};
2731 testStaticRuntime(cumsum_script_dtype, args0);
2732
2733 auto b = at::randn({3, 6});
2734 std::vector<IValue> args1{b, 1, dtype};
2735 testStaticRuntime(cumsum_script_dtype, args0, args1);
2736 }
2737
2738 TEST(StaticRuntime, Nonzero) {
2739 const auto nonzero_tensor = R"JIT(
2740 def forward(self, input: Tensor):
2741 a = torch.nonzero(input).clone()
2742 return (a)
2743 )JIT";
2744
2745 auto a = at::randint(0, 2, {2, 3});
2746 testStaticRuntime(nonzero_tensor, {a});
2747
2748 auto b = at::randint(0, 2, {4, 3, 2});
2749 testStaticRuntime(nonzero_tensor, {a}, {b});
2750 }
2751
2752 TEST(StaticRuntime, SignedLog1p) {
2753 const std::string signed_log1p_script = R"IR(
2754 graph(%input):
2755 %0 : Tensor = aten::sign(%input)
2756 %1 : Tensor = aten::abs(%input)
2757 %2 : Tensor = aten::log1p(%1)
2758 %3 : Tensor = aten::mul(%0, %2)
2759 %none : NoneType = prim::Constant()
2760 %res : Tensor = aten::clone(%3, %none)
2761 return (%res)
2762 )IR";
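// With fast math enabled, this sign/abs/log1p/mul pattern may be fused
// into a single signed_log1p op, hence use_allclose=true below to
// tolerate small numeric differences.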
2763
2764 std::vector<IValue> args1 = {at::randn({2, 2})};
2765 testStaticRuntime(signed_log1p_script, args1, {}, /*use_allclose=*/true);
2766
2767 std::vector<IValue> args2 = {at::randn({3, 3, 3})};
2768 testStaticRuntime(signed_log1p_script, args1, args2, /*use_allclose=*/true);
2769 }
2770
2771 TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithImmutableInputDict) {
2772 const auto getitem_immutable_input_dict_script = R"JIT(
2773 def forward(self, input: Dict[int, Tensor]):
2774 a = input[0]
2775 b = input[1]
2776 c = a + b
2777 return c.clone()
2778 )JIT";
2779
2780 script::Module module("module");
2781 module.define(getitem_immutable_input_dict_script);
2782 torch::jit::StaticModule smodule(module);
2783 EXPECT_FALSE(hasNodeWithKind(smodule, "aten::__getitem__"));
2784 EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));
2785
2786 auto a = at::randn({2, 4});
2787 auto b = at::randn({2, 4});
2788 c10::Dict<c10::IValue, c10::IValue> dict(
2789 c10::IntType::get(), c10::TensorType::get());
2790 dict.insert(0, a);
2791 dict.insert(1, b);
2792 testStaticRuntime(getitem_immutable_input_dict_script, {dict});
2793
2794 c10::Dict<c10::IValue, c10::IValue> dict0(
2795 c10::IntType::get(), c10::TensorType::get());
2796 auto a0 = at::randn({3, 4});
2797 auto b0 = at::randn({3, 4});
2798 dict0.insert(0, a0);
2799 dict0.insert(1, b0);
2800 testStaticRuntime(getitem_immutable_input_dict_script, {dict0});
2801 }
2802
2803 TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithMutableInputDict) {
2804 const auto getitem_mutable_input_dict_script = R"JIT(
2805 def forward(self, input: Dict[int, Tensor]):
2806 a = input[0]
2807 input[1] = a
2808 b = input[1]
2809 c = a + b
2810 return c.clone()
2811 )JIT";
2812
2813 script::Module module("module");
2814 module.define(getitem_mutable_input_dict_script);
2815 torch::jit::StaticModule smodule(module);
2816 EXPECT_TRUE(hasNodeWithKind(smodule, "aten::__getitem__"));
2817 EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));
2818 }
2819
2820 TEST(StaticRuntime, VarTupleUnpack) {
2821 const auto var_tuple_unpack_script = R"JIT(
2822 def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
2823 a, b = input_0
2824 c, d = input_1
2825 res = a * c + b * d
2826 return res.clone()
2827 )JIT";
2828
2829 script::Module module("module");
2830 module.define(var_tuple_unpack_script);
2831 torch::jit::StaticModule smodule(module);
2832 EXPECT_FALSE(hasNodeWithKind(smodule, "prim::TupleUnpack"));
2833 EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack"));
2834
2835 auto a = at::randn({2, 2});
2836 auto b = at::randn({3, 3, 3});
2837 std::vector<IValue> args1{
2838 c10::ivalue::Tuple::create(a, a), c10::ivalue::Tuple::create(1, 2)};
2839 std::vector<IValue> args2{
2840 c10::ivalue::Tuple::create(b, b), c10::ivalue::Tuple::create(1, 2)};
2841
2842 testStaticRuntime(var_tuple_unpack_script, args1);
2843 testStaticRuntime(var_tuple_unpack_script, args1, args2);
2844 }
2845
2846 TEST(StaticRuntime, VarTupleUnpack_NotApplied) {
2847 const auto var_tuple_unpack_not_applied_script = R"JIT(
2848 def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
2849 a, b = input_0
2850 x = a + b
2851 c, d = input_1
2852 res = a * c + b * d + x
2853 return res.clone()
2854 )JIT";
2855
2856 script::Module module("module");
2857 // In this script, the optimization is not applied since there is a
2858 // computation between the TupleUnpack nodes.
2859 module.define(var_tuple_unpack_not_applied_script);
2860 torch::jit::StaticModule smodule(module);
2861 EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack"));
2862 EXPECT_TRUE(hasNodeWithKind(smodule, "prim::TupleUnpack"));
2863 }
2864
2865 TEST(StaticRuntime, RemainderTensor) {
2866 const auto remainder_tensor = R"JIT(
2867 def forward(self, x, y):
2868 return torch.remainder(x, y).clone()
2869 )JIT";
2870
2871 std::vector<IValue> args1 = {
2872 at::randint(0, 10, {2, 2}), at::randint(1, 10, {2, 2})};
2873 std::vector<IValue> args2 = {
2874 at::randint(0, 10, {3, 6}), at::randint(1, 10, {3, 6})};
2875
2876 // Use allclose and equalnan since outputs may be NaN.
2877 testStaticRuntime(
2878 remainder_tensor,
2879 args1,
2880 /*args2*/ {},
2881 /*use_allclose*/ true,
2882 /*use_equalnan*/ true);
2883 testStaticRuntime(
2884 remainder_tensor,
2885 args1,
2886 args2,
2887 /*use_allclose*/ true,
2888 /*use_equalnan*/ true);
2889 }
2890
2891 TEST(StaticRuntime, RemainderScalar) {
2892 const auto remainder_scalar = R"JIT(
2893 def forward(self, x, y: int):
2894 return torch.remainder(x, y).clone()
2895 )JIT";
2896
2897 std::vector<IValue> args1 = {at::randint(0, 10, {2, 2}), 4};
2898 std::vector<IValue> args2 = {at::randint(0, 10, {3, 6}), 4};
2899
2900 // Use allclose and equalnan since outputs may be NaN.
2901 testStaticRuntime(
2902 remainder_scalar,
2903 args1,
2904 /*args2*/ {},
2905 /*use_allclose*/ true,
2906 /*use_equalnan*/ true);
2907 testStaticRuntime(
2908 remainder_scalar,
2909 args1,
2910 args2,
2911 /*use_allclose*/ true,
2912 /*use_equalnan*/ true);
2913 }
2914
2915 TEST(StaticRuntime, Where) {
2916 const auto where_script = R"JIT(
2917 def forward(self, x, y):
2918 return torch.where(x > 0, x, y).clone()
2919 )JIT";
2920
2921 std::vector<IValue> args1 = {at::randn({2, 2}), at::randn({2, 2})};
2922 std::vector<IValue> args2 = {at::randn({8, 10}), at::randn({8, 10})};
2923
2924 testStaticRuntime(where_script, args1);
2925 testStaticRuntime(where_script, args1, args2);
2926 }
2927
2928 TEST(StaticRuntime, WhereBroadcast) {
2929 const auto where_script = R"JIT(
2930 def forward(self, cond_1d, x, y):
2931 shape = [-1] + [1] * (x.dim() - 1)
2932 cond = cond_1d.view(shape)
2933 return torch.where(cond, x, y).clone()
2934 )JIT";
2935
2936 std::vector<IValue> args1 = {
2937 at::tensor({0, 1}).to(at::kBool), at::randn({2, 2}), at::randn({2, 2})};
2938 std::vector<IValue> args2 = {
2939 at::tensor({1, 0, 0}).to(at::kBool),
2940 at::randn({3, 6}),
2941 at::randn({3, 6})};
2942
2943 testStaticRuntime(where_script, args1);
2944 testStaticRuntime(where_script, args1, args2);
2945 }
2946
2947 TEST(StaticRuntime, View) {
2948 // Note that clone is not technically necessary here since this is not
2949 // an out variant, but it suppresses warnings in testStaticRuntime
2950 // about the graph containing only one op
2951 const auto src = R"IR(
2952 graph(%input : Tensor, %shape : int[]):
2953 %none : NoneType = prim::Constant()
2954 %view : Tensor = aten::view(%input, %shape)
2955 %res : Tensor = aten::clone(%view, %none)
2956 return (%res)
2957 )IR";
2958
2959 std::vector<IValue> args1{at::randn({2, 2}), c10::List<int64_t>(4)};
2960 std::vector<IValue> args2{at::randn({2, 2, 2}), c10::List<int64_t>({4, 2})};
2961
2962 testStaticRuntime(src, args1);
2963 testStaticRuntime(src, args1, args2);
2964 }
2965
2966 TEST(StaticRuntime, Size) {
2967 const auto src_with_dim = R"JIT(
2968 def forward(self, x, dim: int):
2969 return x.size(dim)
2970 )JIT";
2971
2972 const auto src_no_dim = R"JIT(
2973 def forward(self, x):
2974 return x.size()
2975 )JIT";
2976
2977 std::vector<IValue> args1{at::randn({1}), 0};
2978 std::vector<IValue> args2{at::randn({1}), -1};
2979 std::vector<IValue> args3{at::randn({2, 4}), 1};
2980 std::vector<IValue> args_no_dim{at::randn({2, 4})};
2981
2982 testStaticRuntime(src_with_dim, args1);
2983 testStaticRuntime(src_with_dim, args2);
2984 testStaticRuntime(src_with_dim, args1, args3);
2985 testStaticRuntime(src_no_dim, args_no_dim);
2986 }
2987
2988 TEST(StaticRuntime, Squeeze) {
2989 // Note: this is a native op, not an out variant, but clone anyway
2990 // to silence warnings in testStaticRuntime
2991 const auto src = R"JIT(
2992 def forward(self, inp, dim: int):
2993 return inp.squeeze(dim).clone()
2994 )JIT";
2995
2996 const auto a = at::randn({2, 2});
2997 const auto b = at::randn({3, 2, 3});
2998
2999 testStaticRuntime(src, {a, 0});
3000 testStaticRuntime(src, {a, 1});
3001 testStaticRuntime(src, {a, -1}, {b, 2});
3002 }
3003
3004 TEST(StaticRuntime, NumToTensorScalar) {
3005 const auto num_to_tensor_ir = R"IR(
3006 graph(%1 : int):
3007 %2 : NoneType = prim::Constant()
3008 %3 : Tensor = prim::NumToTensor(%1)
3009 %4 : Tensor = aten::clone(%3, %2)
3010 return (%4)
3011 )IR";
3012
3013 IValue arg{5};
3014 std::vector<IValue> args = {arg};
3015 testStaticRuntime(num_to_tensor_ir, args);
3016 }
3017
3018 TEST(StaticRuntime, NumToTensorFalse) {
3019 const auto num_to_tensor_ir = R"IR(
3020 graph(%1 : bool):
3021 %2 : NoneType = prim::Constant()
3022 %3 : Tensor = prim::NumToTensor(%1)
3023 %4 : Tensor = aten::clone(%3, %2)
3024 return (%4)
3025 )IR";
3026
3027 IValue arg{false};
3028 std::vector<IValue> args = {arg};
3029 testStaticRuntime(num_to_tensor_ir, args);
3030 }
3031
3032 TEST(StaticRuntime, NumToTensorTrue) {
3033 const auto num_to_tensor_ir = R"IR(
3034 graph(%1 : bool):
3035 %2 : NoneType = prim::Constant()
3036 %3 : Tensor = prim::NumToTensor(%1)
3037 %4 : Tensor = aten::clone(%3, %2)
3038 return (%4)
3039 )IR";
3040
3041 IValue arg{true};
3042 std::vector<IValue> args = {arg};
3043 testStaticRuntime(num_to_tensor_ir, args);
3044 }
3045
3046 TEST(StaticRuntime, Split) {
3047 const auto src = R"JIT(
3048 def forward(self, inp, split_size: int, dim: int):
3049 return inp.split(split_size, dim)
3050 )JIT";
3051
3052 const auto a = at::randn({2, 2});
3053 const auto b = at::randn({2, 2, 2});
3054
3055 testStaticRuntime(src, {a, 1, 0});
3056 testStaticRuntime(src, {a, 1, 1});
3057 testStaticRuntime(src, {a, 2, -1}, {b, 2, 2});
3058 }
3059
3060 TEST(StaticRuntime, SplitWithSizes) {
3061 const auto src = R"JIT(
3062 def forward(self, inp, split_sizes: List[int], dim: int):
3063 return inp.split(split_sizes, dim)
3064 )JIT";
3065
3066 const auto a = at::randn({2, 2});
3067 const auto b = at::randn({2, 2, 2});
3068 const auto split_sizes = c10::List<int64_t>{1, 1};
3069
3070 testStaticRuntime(src, {a, split_sizes, 0});
3071 testStaticRuntime(src, {a, split_sizes, 1});
3072 testStaticRuntime(src, {a, split_sizes, -1}, {b, split_sizes, 2});
3073 }
3074
3075 namespace {
3076
3077 void maybe_throw(bool should_throw) {
3078 if (should_throw) {
3079 throw std::runtime_error("test exception");
3080 }
3081 }
3082
3083 TORCH_LIBRARY(static_runtime_tests, m) {
3084 // Use a conservative alias analysis kind so this op doesn't get
3085 // removed by dead code elimination
3086 m.def(torch::schema(
3087 "static_runtime_tests::maybe_throw(bool throw) -> ()",
3088 at::AliasAnalysisKind::CONSERVATIVE));
3089 m.impl("maybe_throw", maybe_throw);
3090 }
3091
3092 } // namespace
3093
3094 TEST(StaticRuntime, ModelCrashOnFirstRun) {
3095 const auto src = R"JIT(
3096 graph(%0: Tensor, %throw: bool):
3097 %1: Tensor = aten::mul(%0, %0)
3098 static_runtime_tests::maybe_throw(%throw)
3099 %2: Tensor = aten::mul(%1, %1)
3100 %3: Tensor = aten::mul(%2, %2)
3101 return (%3)
3102 )JIT";
3103
3104 auto graph = getGraphFromIR(src);
3105 auto static_module = StaticModule(graph);
3106 auto& runtime = static_module.runtime();
3107
3108 std::vector<IValue> args_crash{at::randn({1}), true};
3109 std::vector<IValue> args_no_crash{at::randn({1}), false};
3110 EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);
3111
3112 // The run didn't finish, so the memory planner was never allocated
3113 EXPECT_EQ(runtime.get_memory_planner(), nullptr);
3114 runtime.check_for_memory_leak();
3115
3116 // We guarantee that the runtime is still usable after the crash.
3117 // Run again to verify this.
3118 compareResultsWithJIT(runtime, graph, args_no_crash);
3119 EXPECT_NE(runtime.get_memory_planner(), nullptr);
3120 }
3121
3122 TEST(StaticRuntime, ModelCrashOnSecondRun) {
3123 const auto src = R"JIT(
3124 graph(%0: Tensor, %throw: bool):
3125 %1: Tensor = aten::mul(%0, %0)
3126 static_runtime_tests::maybe_throw(%throw)
3127 %2: Tensor = aten::mul(%1, %1)
3128 %3: Tensor = aten::mul(%2, %2)
3129 return (%3)
3130 )JIT";
3131
3132 auto graph = getGraphFromIR(src);
3133 auto static_module = StaticModule(graph);
3134 auto& runtime = static_module.runtime();
3135
3136 std::vector<IValue> args_crash{at::randn({1}), true};
3137 std::vector<IValue> args_no_crash{at::randn({1}), false};
3138 runtime(args_no_crash, {});
3139 EXPECT_NE(runtime.get_memory_planner(), nullptr);
3140 runtime.check_for_memory_leak();
3141
3142 EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);
3143 runtime.check_for_memory_leak();
3144
3145 // We guarantee that the runtime is still usable after the crash.
3146 // Run again to verify this.
3147 compareResultsWithJIT(runtime, graph, args_no_crash);
3148 }
3149
3150 TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) {
3151 const auto src = R"JIT(
3152 graph(%0: Tensor):
3153 %1: Tensor = aten::mul(%0, %0)
3154 %2: Tensor = aten::mul(%1, %1)
3155 %3: bool = prim::Constant[value=1]()
3156 %4: Tensor = static_runtime::select_tensor(%1, %2, %3)
3157 static_runtime_tests::maybe_throw(%3)
3158 return (%4)
3159 )JIT";
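// static_runtime::select_tensor yields a borrowed reference to one of its
// inputs; throwing right afterwards exercises cleanup of borrows on the
// error path.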
3160 auto graph = getGraphFromIR(src);
3161 auto static_module = StaticModule(graph);
3162 auto& runtime = static_module.runtime();
3163
3164 std::vector<IValue> args{at::randn({1})};
3165 EXPECT_THROW(runtime(args), std::runtime_error);
3166 }
3167
3168 TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) {
3169 const auto src = R"JIT(
3170 graph(%0: Tensor, %1: Tensor):
3171 %2: bool = prim::Constant[value=1]()
3172 %3: Tensor = static_runtime::select_tensor(%0, %1, %2)
3173 static_runtime_tests::maybe_throw(%2)
3174 return (%3)
3175 )JIT";
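// Same as above, except that the borrow refers directly to a graph input.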
3176 auto graph = getGraphFromIR(src);
3177 auto static_module = StaticModule(graph);
3178 auto& runtime = static_module.runtime();
3179
3180 std::vector<IValue> args{at::randn({1}), at::randn({1})};
3181 EXPECT_THROW(runtime(std::move(args)), std::runtime_error);
3182 }
3183
3184 TEST(StaticRuntime, ReplaceWithMaybeCopy) {
3185 const std::string to = R"IR(
3186 graph(%0 : Tensor):
3187 %1: int = prim::Constant[value=4]()
3188 %2: bool = prim::Constant[value=0]()
3189 %3: None = prim::Constant()
3190 %res : Tensor = aten::to(%0, %1, %2, %2, %3)
3191 return (%res)
3192 )IR";
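// aten::to may or may not copy depending on the requested dtype, so the
// pass replaces it with static_runtime::to_maybe_copy_out, which decides
// at runtime (verified at the end of this test).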
3193
3194 at::Tensor a = at::tensor({1.1, 2.2, 3.3, 4.0}, at::ScalarType::Float);
3195 std::vector<IValue> args{a};
3196 auto g = std::make_shared<torch::jit::Graph>();
3197 torch::jit::parseIR(to, g.get());
3198
3199 // Jit Interpreter.
3200 Stack stack(args);
3201 torch::jit::GraphExecutor graph_exec(g, "");
3202 graph_exec.run(stack);
3203 ASSERT_EQ(stack.size(), 1);
3204 auto expected = stack[0].toTensor();
3205
3206 // Static Runtime.
3207 torch::jit::StaticModule smodule(g);
3208 auto actual = smodule(args, {}).toTensor();
3209 smodule.runtime().check_for_memory_leak();
3210
3211 EXPECT_TRUE(expected.equal(actual));
3212
3213 // Make a fresh graph to ensure the pass works in isolation
3214 auto new_graph = std::make_shared<torch::jit::Graph>();
3215 torch::jit::parseIR(to, new_graph.get());
3216 ReplaceWithMaybeCopy(new_graph);
3217 EXPECT_FALSE(hasNodeWithKind(new_graph, "aten::to"));
3218 EXPECT_TRUE(
3219 hasNodeWithKind(new_graph, "static_runtime::to_maybe_copy_out"));
3220 }
3221
3222 TEST(StaticRuntime, Int) {
3223 const auto src = R"JIT(
3224 def forward(self, x):
3225 return int(x) + int(x)
3226 )JIT";
3227 std::vector<IValue> args{at::tensor({3.14})};
3228 testStaticRuntime(src, args);
3229 }
3230
3231 TEST(StaticRuntime, ReturnConstant) {
3232 const auto src = R"JIT(
3233 def forward(self):
3234 return 1
3235 )JIT";
3236
3237 testStaticRuntime(src, {});
3238 }
3239
3240 TEST(StaticRuntime, SimpleIf) {
3241 const auto src = R"JIT(
3242 def forward(self, cond: bool, x):
3243 if cond:
3244 return torch.mul(x, 42).clone()
3245 else:
3246 return x.clone()
3247 )JIT";
3248
3249 std::vector<IValue> args_false{false, at::randn({1})};
3250 std::vector<IValue> args_true{true, at::randn({1})};
3251 std::vector<IValue> args_big_tensor{true, at::randn({3, 3, 3})};
3252
3253 testStaticRuntime(src, args_false);
3254 testStaticRuntime(src, args_true);
3255 testStaticRuntime(src, args_true, args_big_tensor);
3256 }
3257
3258 TEST(StaticRuntime, NestedIf) {
3259 const auto src = R"JIT(
3260 def forward(self, cond1: bool, cond2: bool, x):
3261 y = x * 42
3262 if cond1:
3263 y = y * y
3264 if cond2:
3265 y += x
3266 else:
3267 if cond2:
3268 return x.clone()
3269
3270 return y.clone()
3271 )JIT";
3272
3273 for (auto cond1 : {true, false}) {
3274 for (auto cond2 : {true, false}) {
3275 std::vector<IValue> args1{cond1, cond2, at::randn({1})};
3276 std::vector<IValue> args2{cond1, cond2, at::randn({3, 3, 3})};
3277 testStaticRuntime(src, args1, args2);
3278 }
3279 }
3280 }
3281
3282 TEST(StaticRuntime, DeeplyNestedIf) {
3283 const auto src = R"JIT(
3284 def forward(self, cond1: bool, cond2: bool, cond3: bool, x):
3285 y = x * 42
3286 if cond1:
3287 y = y * y
3288 if cond2:
3289 y += x
3290
3291 if cond2 and cond3:
3292 y += 1
3293
3294 if cond2:
3295 if cond3:
3296 y += 2
3297 else:
3298 y = y * y
3299 y += 4
3300 else:
3301 if cond2:
3302 return x.clone()
3303 if cond3 or cond2:
3304 y += 42
3305
3306 return y.clone()
3307 )JIT";
3308
3309 for (auto cond1 : {true, false}) {
3310 for (auto cond2 : {true, false}) {
3311 for (auto cond3 : {true, false}) {
3312 std::vector<IValue> args1{cond1, cond2, cond3, at::randn({1})};
3313 std::vector<IValue> args2{cond1, cond2, cond3, at::randn({3, 3, 3})};
3314 testStaticRuntime(src, args1, args2);
3315 }
3316 }
3317 }
3318 }
3319
3320 TEST(StaticRuntime, BasicForLoop) {
3321 const auto src = R"JIT(
3322 def forward(self, x, loop_max: int):
3323 y = x.clone()
3324 for i in range(loop_max):
3325 y += 1
3326 return y
3327 )JIT";
3328
3329 std::vector<IValue> args1{at::randn({1}), 10};
3330 std::vector<IValue> args2{at::randn({3, 3, 3}), 10};
3331
3332 testStaticRuntime(src, args1, args2);
3333 }
3334
3335 TEST(StaticRuntime, BasicWhileLoop) {
3336 const auto src = R"JIT(
3337 def forward(self, x, loop_max: int):
3338 y = x.clone()
3339 loop_count = 0
3340 while loop_count < loop_max:
3341 y += 1
3342 loop_count += 1
3343 return y
3344 )JIT";
3345
3346 std::vector<IValue> args1{at::randn({1}), 10};
3347 std::vector<IValue> args2{at::randn({3, 3, 3}), 10};
3348
3349 testStaticRuntime(src, args1, args2);
3350 }
3351
3352 TEST(StaticRuntime, NestedLoops) {
3353 const auto src = R"JIT(
3354 def forward(self, x, loop_max: int):
3355 y = x.clone()
3356 even: List[int] = []
3357 odd: List[int] = []
3358
3359 for i in range(loop_max):
3360 if i % 2:
3361 odd.append(i)
3362 else:
3363 even.append(i)
3364
3365 for j in range(i):
3366 y += 1
3367
3368 return y, even, odd
3369 )JIT";
3370
3371 std::vector<IValue> args1{at::randn({1}), 10};
3372 std::vector<IValue> args2{at::randn({3, 3, 3}), 10};
3373
3374 testStaticRuntime(src, args1, args2);
3375 }
3376
3377 TEST(StaticRuntime, TupleIndex) {
3378 const auto src = R"JIT(
3379 def forward(self, idx: int, tup: Tuple[int, int]):
3380 a = tup[idx]
3381 return a * a
3382 )JIT";
3383 const auto tuple = c10::ivalue::Tuple::create({1, 2});
3384 testStaticRuntime(src, {1, tuple}, {-1, tuple});
3385
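  // Indexing with -1 above is valid Python-style negative indexing; an index
  // past the end of the tuple should surface as std::out_of_range.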
  torch::jit::Module mod("module");
  mod.define(src);
  StaticModule smod(mod);
  EXPECT_THROW(smod({100, tuple}), std::out_of_range);
}

TEST(StaticRuntime, RaiseException) {
  const auto src = R"IR(
    graph(%str: str):
        %none: NoneType = prim::Constant()
        prim::RaiseException(%str, %none)
        return (%none)
  )IR";
  auto graph = getGraphFromIR(src);
  StaticModule smod(graph);
  const auto msg = "exception message";
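  // EXPECT_THROW alone only checks the exception type; the inner try/catch
  // also verifies that the message survives the trip through
  // prim::RaiseException.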
  EXPECT_THROW(
      {
        try {
          smod({msg});
        } catch (const std::runtime_error& e) {
          EXPECT_STREQ(msg, e.what());
          throw;
        }
      },
      std::runtime_error);
}

TEST(StaticRuntime, Uninitialized) {
  const auto src = R"IR(
    graph():
        %0: int = prim::Uninitialized()
        return (%0)
  )IR";
  auto graph = getGraphFromIR(src);
  StaticModule smod(graph);
  const auto ret = smod({});
  // Two uninitialized values never compare equal, so the result can't be
  // checked against a reference; just check that the returned type is Any.
  EXPECT_EQ(ret.type()->kind(), c10::TypeKind::AnyType);
}

TEST(StaticRuntime, Format) {
  const auto src = R"JIT(
    def forward(self, arg1: int, arg2: Tensor, arg3: str):
        a = "arg1: {}, arg2: {}, arg3: {}".format(arg1, arg2, arg3)
        return a[::]
  )JIT";
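  // Slicing with [::] avoids returning `a` itself, the string analogue of the
  // .clone() calls in the tensor tests.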
  testStaticRuntime(src, {1, at::randn({3}), "str"});
}

TEST(StaticRuntime, Device) {
  const auto src = R"JIT(
    def forward(self, x):
        return x.device, x.device
  )JIT";
  testStaticRuntime(src, {at::tensor({1})});
}

TEST(StaticRuntime, Dtype) {
  const auto src = R"JIT(
    def forward(self, x, y):
        return x.dtype, y.dtype
  )JIT";
  testStaticRuntime(
      src, {at::tensor({1}, at::kLong), at::tensor({1}, at::kFloat)});
}

TEST(StaticRuntime, Dim) {
  const auto src = R"JIT(
    def forward(self, x, y):
        return x.dim(), y.dim()
  )JIT";
  testStaticRuntime(src, {at::randn({2, 2}), at::randn({1})});
}

TEST(StaticRuntime, Not) {
  const auto src = R"JIT(
    def forward(self, x: bool, y: bool):
        return not x, not y
  )JIT";
  testStaticRuntime(src, {true, false});
}

TEST(StaticRuntime, Bool) {
  const auto src = R"JIT(
    def forward(self, x: Tensor, y: int, z: float):
        return bool(x), bool(y), bool(z)
  )JIT";
  testStaticRuntime(src, {at::randn({1}), 0, 1.151}, {at::zeros({1}), 1, 0.0});
}

TEST(StaticRuntime, IsCuda) {
  const auto src = R"JIT(
    def forward(self, x: Tensor, y: Tensor):
        return x.is_cuda, y.is_cuda
  )JIT";
  testStaticRuntime(src, {at::randn({1}), at::randn({1})});
}

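// prim::tolist takes the input tensor, its dimension, and an element-type
// code; the constant 1 here selects the element type matching the float[]
// output.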
TEST(StaticRuntime, ToList) {
  const auto src = R"IR(
    graph(%x: Tensor):
        %type: int = prim::Constant[value=1]()
        %dim: int = aten::dim(%x)
        %ret: float[] = prim::tolist(%x, %dim, %type)
        return (%ret)
  )IR";
  testStaticRuntime(src, {at::randn({2, 2})});
}

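// prim::IfThenElse is a branchless select: it picks %a or %b based on %cond,
// with no sub-blocks to execute.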
TEST(StaticRuntime, IfThenElse) {
  const auto src = R"IR(
    graph(%cond: bool, %a: Tensor, %b: Tensor):
        %none: NoneType = prim::Constant()
        %c: Tensor = prim::IfThenElse(%cond, %a, %b)
        %d: Tensor = aten::clone(%c, %none)
        return (%d)
  )IR";

  std::vector<IValue> args1{true, at::randn({1}), at::randn({1})};
  std::vector<IValue> args2{false, at::randn({1}), at::randn({1})};

  testStaticRuntime(src, args1);
  testStaticRuntime(src, args2);
}

TEST(StaticRuntime, EmptyIfBlock) {
  const auto src = R"JIT(
    def forward(self, cond: bool, a: Tensor, b: Tensor):
        l = []
        if cond:
            l.append((a + b).clone())
        return l
  )JIT";

  testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})});
  testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})});
}

TEST(StaticRuntime, EmptyNestedIfBlock) {
  const auto src = R"JIT(
    def forward(self, cond: bool, a: Tensor, b: Tensor):
        l = []
        if cond:
            if cond:
                l.append((a + b).clone())
        return l
  )JIT";

  testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})});
  testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})});
}

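// The two tests below call the StaticModule directly instead of going through
// testStaticRuntime because the ops are expected to throw: stacking or
// concatenating an empty list of tensors has no defined output.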
TEST(StaticRuntime, StackEmpty) {
  const auto src = R"JIT(
    def forward(self):
        x = torch.stack([])
        return x
  )JIT";

  torch::jit::Module mod("mod");
  mod.define(src);

  torch::jit::StaticModule smod(mod);
  EXPECT_THROW(smod({}), c10::Error);
}

TEST(StaticRuntime, ConcatEmpty) {
  const auto src = R"JIT(
    def forward(self):
        x = torch.concat([])
        return x
  )JIT";

  torch::jit::Module mod("mod");
  mod.define(src);

  torch::jit::StaticModule smod(mod);
  EXPECT_THROW(smod({}), c10::Error);
}

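// aten::IntImplicit converts a 0-dim integer tensor to an int scalar; the
// squeeze() below turns the one-element tensor into the required 0-dim input.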
TEST(StaticRuntime, IntImplicit) {
  const auto src = R"IR(
    graph(%a: Tensor):
        %y: int = aten::IntImplicit(%a)
        return (%y)
  )IR";
  testStaticRuntime(src, {at::tensor({1}, at::kInt).squeeze()});
}

TEST(StaticRuntime, IntImplicit_ThrowOnBadInputs) {
  const auto src = R"IR(
    graph(%a: Tensor):
        %y: int = aten::IntImplicit(%a)
        return (%y)
  )IR";
  auto graph = getGraphFromIR(src);
  torch::jit::StaticModule smod(graph);
  // Not 0D tensor
  EXPECT_THROW(smod({at::tensor({1, 2}, at::kInt)}), std::runtime_error);
  // Wrong dtype
  EXPECT_THROW(
      smod({at::tensor({1}, at::kFloat).squeeze()}), std::runtime_error);
}

TEST(StaticRuntime, Select) {
  const auto src = R"IR(
    graph(%a: Tensor, %dim: int, %index: int):
        %none: NoneType = prim::Constant()
        %b: Tensor = aten::select(%a, %dim, %index)
        %c: Tensor = aten::clone(%b, %none)
        return (%c)
  )IR";
  testStaticRuntime(src, {at::randn({2, 2}), 0, 1});
}

TEST(StaticRuntime, ReshapeAs) {
  const auto src = R"JIT(
    def forward(self, a, b):
        return a.reshape_as(b).clone()
  )JIT";
  testStaticRuntime(src, {at::randn({2, 2}), at::randn({4})});
}

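// The move constructor must transfer all of the runtime's internal state; the
// moved-to runtime should reproduce the results captured before the move.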
TEST(StaticRuntime, MoveCtor) {
  auto mod = getDeepAndWideSciptModel();
  std::vector<IValue> args{
      at::randn({1, 1, 32}), at::randn({1, 1, 32}), at::randn({1, 50})};

  torch::jit::StaticModule smod(mod);

  torch::jit::StaticRuntime runtime(smod);
  auto expected = runtime(args);

  torch::jit::StaticRuntime new_runtime(std::move(runtime));
  auto actual = new_runtime(args);
  compareResults(expected, actual);
}

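// A list built inside an if sub-block must remain usable after control
// returns to the outer block; these two tests cover the single- and
// nested-block cases.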
TEST(StaticRuntime, SingleBlockIfReturnList) {
  const auto src = R"JIT(
    def forward(self, a, b, cond: bool):
        lst = []
        if cond:
            lst.append(a + b)
        return lst
  )JIT";
  std::vector<IValue> args1{at::randn({1}), at::randn({1}), true};
  std::vector<IValue> args2{at::randn({42, 42}), at::randn({42, 42}), false};
  testStaticRuntime(src, args1, args2);
}

TEST(StaticRuntime, NestedBlockIfReturnList) {
  const auto src = R"JIT(
    def forward(self, a, b, cond1: bool, cond2: bool):
        if cond1:
            lst = []
            if cond2:
                lst.append(a + b)
            lst.append(a * b)
            return lst
        return []
  )JIT";
  std::vector<IValue> args1{at::randn({1}), at::randn({1}), true, true};
  std::vector<IValue> args2{
      at::randn({42, 42}), at::randn({42, 42}), true, false};
  testStaticRuntime(src, args1, args2);
}

TEST(StaticRuntime, ClampNaNToNum) {
  const auto src1 = R"JIT(
    def forward(self, a):
        return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone()
  )JIT";

  const auto src2 = R"JIT(
    def forward(self, a, nan: float):
        return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone()
  )JIT";

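  // src3 exercises the min > max edge case: torch.clamp documents that every
  // element is then set to max.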
  const auto src3 = R"JIT(
    def forward(self, a):
        return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone()
  )JIT";

  auto a = at::tensor({
      std::numeric_limits<float>::quiet_NaN(),
      std::numeric_limits<float>::infinity(),
      -std::numeric_limits<float>::infinity(),
      0.0f,
      3.0f});
  auto b = a.repeat({10, 5});

  // We need use_allclose/use_equalnan even though every NaN is replaced in
  // the output: testStaticRuntime also compares the inputs afterwards to
  // verify that they were not modified, and the inputs still contain NaN.
  testStaticRuntime(
      src1, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(
      src1, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);

  testStaticRuntime(
      src2, {a, 42.0}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(
      src2, {a, 2.0}, {b, 1.0}, /*use_allclose=*/true, /*use_equalnan=*/true);

  testStaticRuntime(
      src3, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(
      src3, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);

  // Non-NNC path: double inputs fall back to the unfused implementation.
  testStaticRuntime(
      src1, {a.to(at::kDouble)}, {}, /*use_allclose=*/true,
      /*use_equalnan=*/true);
  testStaticRuntime(
      src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true,
      /*use_equalnan=*/true);
}

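// The tuple is built in both branches of the if, so tup[idx] may alias either
// input tensor.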
TEST(StaticRuntime, IfReturningTuple) {
  const auto src = R"JIT(
    def forward(self, x, y, cond: bool, idx: int):
        if cond:
            tup = (x, y)
        else:
            tup = (x, x)
        return tup[idx]
  )JIT";

  std::vector<IValue> args{at::randn({3}), at::randn({3}), true, 0};
  testStaticRuntime(src, args);
}
