# Owner(s): ["oncall: export"]

import unittest
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._dynamo.test_case import TestCase
from torch._export.converter import TS2EPConverter
from torch.export import ExportedProgram
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
from torch.testing._internal.torchbind_impls import (
    _empty_tensor_queue,
    init_torchbind_implementations,
)


requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


class TestConverter(TestCase):
    def setUp(self):
        init_torchbind_implementations()

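        # Register a Python fake class for _TorchScriptTesting::_TensorQueue so the
        # torchbind ops used below can be traced without the real C++ object.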
        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
        class FakeTensorQueue:
            def __init__(self, queue):
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flattened_ctx):
                return cls(**dict(flattened_ctx))

            def push(self, x):
                self.queue.append(x)

            def pop(self):
                if self.is_empty():
                    return torch.empty([])
                return self.queue.pop(0)

            def size(self):
                return len(self.queue)

            def is_empty(self):
                return len(self.queue) == 0

            def float_size(self):
                return float(len(self.queue))

        self.torch_bind_ops = [
            torch.ops._TorchScriptTesting.queue_pop,
            torch.ops._TorchScriptTesting.queue_push,
            torch.ops._TorchScriptTesting.queue_size,
        ]

    def tearDown(self):
        torch._library.fake_class_registry.deregister_fake_class(
            "_TorchScriptTesting::_TensorQueue"
        )

    def _check_equal_ts_ep_converter(
        self,
        M,
        inp,
        option: Optional[List[str]] = None,
        check_persistent=False,
        lifted_tensor_constants=None,
    ) -> List[ExportedProgram]:
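        """Convert M to an ExportedProgram via TS2EPConverter for each jit option
        ("trace" and/or "script") and check that the converted module matches the
        original TorchScript module on outputs and state_dict keys."""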
        # By default, it tests both jit.trace and jit.script.
        if option is None:
            option = ["trace", "script"]

        if check_persistent:
            num_iterations = 10
        else:
            num_iterations = 1

        ep_list = []
        for opt in option:
            if opt == "script":
                # Use two separate models for testing non-functional (side) effects.
                if check_persistent:
                    original_ts_model = torch.jit.script(M())
                    ts_model = torch.jit.script(M())
                    eager_model = M()
                else:
                    original_ts_model = torch.jit.script(M)
                    ts_model = torch.jit.script(M)
                    eager_model = M
            elif opt == "trace":
                if check_persistent:
                    original_ts_model = torch.jit.trace(M(), inp)
                    ts_model = torch.jit.trace(M(), inp)
                    eager_model = M()
                else:
                    original_ts_model = torch.jit.trace(M, inp)
                    ts_model = torch.jit.trace(M, inp)
                    eager_model = M
            else:
                raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}")

            converter = TS2EPConverter(ts_model, inp)
            ep = converter.convert()
            ep_list.append(ep)

            for _ in range(num_iterations):
                orig_out, _ = pytree.tree_flatten(original_ts_model(*inp))
                ep_out, _ = pytree.tree_flatten(ep.module()(*inp))

                # Check the module's state_dict keys.
                if isinstance(eager_model, torch.nn.Module):
                    expected_state_dict = OrderedDict()
                    expected_state_dict.update(ts_model.state_dict())
                    if lifted_tensor_constants:
                        expected_state_dict.update(lifted_tensor_constants)
                    self.assertEqual(
                        ep.state_dict.keys(),
                        expected_state_dict.keys(),
                    )

                # Check results.
                self._check_tensor_list_equal(ep_out, orig_out)
        return ep_list

    def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]):
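        """Assert that two flattened output lists match elementwise: tensors are
        compared with allclose, everything else with strict equality."""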
        self.assertEqual(len(xs), len(ys))
        for x, y in zip(xs, ys):
            if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
                self.assertEqual(x.shape, y.shape)
                self.assertTrue(torch.allclose(x, y))
            else:
                self.assertEqual(type(x), type(y))
                self.assertEqual(x, y)

    def test_ts2ep_converter_basic(self):
        class MSingle(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        class MMulti(torch.nn.Module):
            def forward(self, x, y):
                x = x.cos() + 1
                y = y.sin() - 1
                return x, y

        inp = (torch.ones(1, 3), torch.ones(1, 3))
        self._check_equal_ts_ep_converter(MSingle(), inp)
        self._check_equal_ts_ep_converter(MMulti(), inp)

    def test_ts2ep_converter_container_output(self):
        # Output is a List.
        class MOutputList(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return [a, b]

        # Output is a Tuple.
        class MOutputTuple(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return (a, b)

        # Output is a Dict.
        class MOutputDict(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return {"data": {"mul": a, "add": b}}

        inp = (torch.tensor(4), torch.tensor(4))

        # Traced functions must use an immutable structure as output, so the List
        # and Dict outputs are only tested with script.
        self._check_equal_ts_ep_converter(MOutputList(), inp, ["script"])
        self._check_equal_ts_ep_converter(MOutputTuple(), inp)
        self._check_equal_ts_ep_converter(MOutputDict(), inp, ["script"])

    def test_aten_dim(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                num_dim = x.dim()
                return torch.ones(num_dim)

        inp = (torch.ones(1, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten_len(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                length = len(x)
                return torch.ones(length)

        # aten::len.Tensor
        inp = (torch.ones(2, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)

        class Module(torch.nn.Module):
            def forward(self, x: List[int]):
                length = len(x)
                return torch.ones(length)

        # aten::len.t
        inp = ([1, 2, 3],)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[int, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_int
        inp = ({1: "a", 2: "b", 3: "c"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[bool, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_bool
        inp = ({True: "a", False: "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[float, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_float
        inp = ({1.2: "a", 3.4: "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[torch.Tensor, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_Tensor
        inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # aten::len.str and aten::len.Dict_str are not supported
        # since torch._C._jit_flatten does not support str.
        # inp = ("abcdefg",)
        # self._check_equal_ts_ep_converter(Module(), inp)
        # inp = ({"a": 1, "b": 2},)
        # self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten_add_t(self):
        # Python list concatenation (aten::add.t).
        class Module(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                out = []
                out = out + x
                a = torch.cat(out)
                out = out + x
                b = torch.cat(out)
                return a, b

        inp = ([torch.ones(2, 3), torch.ones(2, 3)],)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten_to_dtype_with_mutating_storage(self):
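        # aten.to followed by an in-place index_put_; checks that the mutation on
        # the converted tensor is preserved by the converter.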
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                x = x.to(y.dtype)
                torch.ops.aten.index_put_(x, [torch.tensor([0])], y)
                return x

        inp = (torch.ones(2, 3), torch.tensor([0, 0, 0]))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_min(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::min.int
                len_int = min(x_len, y_len)

                # prim::min.float
                len_float = int(min(x_len * 2.0, y_len * 2.0))

                # prim::min.self_int
                len_self_int = min([x_len, y_len])

                # prim::min.self_float
                len_self_float = int(min([x_len * 2.0, y_len * 2.0]))

                # prim::min.float_int
                len_float_int = int(min(x_len * 2.0, y_len))

                # prim::min.int_float
                len_int_float = int(min(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_max(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::max.int
                len_int = max(x_len, y_len)

                # prim::max.float
                len_float = int(max(x_len * 2.0, y_len * 2.0))

                # prim::max.self_int
                len_self_int = max([x_len, y_len])

                # prim::max.self_float
                len_self_float = int(max([x_len * 2.0, y_len * 2.0]))

                # prim::max.float_int
                len_float_int = int(max(x_len * 2.0, y_len))

                # prim::max.int_float
                len_int_float = int(max(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten___getitem___list(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                return y[0]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten___getitem___dict(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                d_int = {0: y[0], 1: y[1]}
                d_str = {"0": y[0], "1": y[1]}
                d_bool = {True: y[0], False: y[1]}
                d_float = {0.1: y[0], 2.3: y[1]}
                return d_int[0], d_str["0"], d_bool[True], d_float[0.1]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_device(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand(3, 4),)
        self._check_equal_ts_ep_converter(Module(), inp)

    @requires_cuda
    def test_prim_device_cuda(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand((3, 4), device="cuda:0"),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_dtype(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                dtype = x.dtype
                return torch.ones(2, 3, dtype=dtype)

        for dtype in [
            torch.float32,
            torch.double,
        ]:
            inp = (torch.rand((3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

        for dtype in [
            torch.uint8,
            torch.int8,
            torch.int32,
        ]:
            inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

    def test_convert_if_basic(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return y * y
                else:
                    return y + y

        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_convert_if_tuple_out(self):
        class M(torch.nn.Module):
            def true_fn(self, y, z):
                return (z * z, z + z)

            def false_fn(self, y, z):
                return (y * y * y, y + y)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                z = y * y

                if x:
                    res = self.true_fn(y, z)
                else:
                    res = self.false_fn(y, z)

                return res[0] + res[1]

        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_convert_if_multiple_out(self):
        class M(torch.nn.Module):
            def true_fn(self, y, z):
                return z * z

            def false_fn(self, y, z):
                return y * y * y

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                z = y * y

                if x:
                    res1 = self.true_fn(y, z)
                    res2 = y
                else:
                    res1 = z
                    res2 = self.false_fn(y, z)

                return res1 + res2

        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_profiler__record_function(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                handle = torch.ops.profiler._record_function_enter_new("foo", None)
                y = x * 2 + 4
                torch.ops.profiler._record_function_exit(handle)
                return y

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))

    def test_aten_floordiv(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x // 2

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))

    def test_aten___is__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is y, z

        # Traced functions must return outputs that contain tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten___isnot__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is not y, z

        # Traced functions must return outputs that contain tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten___not__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return not (x is not y), z

        # Traced functions must return outputs that contain tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_ts2ep_converter_unpack(self):
        class MUnpackList(torch.nn.Module):
            def forward(self, x):
                x, y = torch.split(x, 2)
                return x + y

        class MUnpackTuple(torch.nn.Module):
            def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]):
                x, y = x_tuple
                x = x.cos()
                return x + y

        inp = (torch.ones(4),)
        self._check_equal_ts_ep_converter(MUnpackList(), inp)
        inp = ((torch.zeros(1, 4), torch.ones(1, 4)),)
        self._check_equal_ts_ep_converter(MUnpackTuple(), inp)

    @unittest.skipIf(
        IS_WINDOWS,
        "torch.cond doesn't go through torch.compile on Windows, "
        "causing the output not to be normalized as a list",
    )
    def test_convert_retrace_nested_scripted_modules(self):
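        # Wraps already-scripted submodules (shared between attributes) inside eager
        # modules; the converter has to retrace through the nested scripted modules.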
        class Wrapper(torch.nn.Module):
            def __init__(self, mod) -> None:
                super().__init__()
                self.mod = mod

            def forward(self, x, y):
                return self.mod(x, y)

        class LinearM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x, y):
                return self.linear(y)

        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                m = LinearM(dim)
                m = torch.jit.script(m)
                self.mod1 = m
                self.mod2 = Wrapper(m)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return -self.mod1(x, y) - self.mod2(x, y)
                else:
                    return -self.mod1(x, y) + self.mod2(x, y)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                m = M(dim)
                m = torch.jit.script(m)
                self.mod1 = m
                self.mod2 = Wrapper(m)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return self.mod1(x, y) + self.mod2(x, y)
                else:
                    return self.mod1(x, y) - self.mod2(x, y)

        inp = (
            torch.tensor(True),
            torch.randn([3, 3]),
        )
        self._check_equal_ts_ep_converter(NestedM(3), inp)

    def test_convert_nn_module_with_nested_param(self):
        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                return self.linear(x)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)
                self.m = M(dim)

            def forward(self, x: torch.Tensor):
                return self.linear(self.m(x))

        class SuperNestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)
                self.m = NestedM(dim)

            def forward(self, x: torch.Tensor):
                return self.linear(self.m(x))

        inp = (torch.ones(3),)
        orig_m = NestedM(3)
        self._check_equal_ts_ep_converter(orig_m, inp)
        orig_m = SuperNestedM(3)
        self._check_equal_ts_ep_converter(orig_m, inp)

    def test_convert_nn_module_with_nested_buffer(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + x

        class NestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = M()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + self.m(x)

        class SuperNestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = NestedM()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + self.m(x)

        inp = (torch.ones(1),)
        orig_m = NestedM()
        self._check_equal_ts_ep_converter(orig_m, inp)
        orig_m = SuperNestedM()
        self._check_equal_ts_ep_converter(orig_m, inp)

    def test_convert_nn_module_with_nested_if_and_buffer(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Buffer(torch.randn(1))
                self.count = 1

            def forward(self, x: torch.Tensor):
                return self.w + x + self.count

        class NestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M()
                self.m2 = M()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.w + self.m1(x)
                else:
                    return self.w + self.m2(x)

        # Super nested, parameters need to be lifted multiple times.
        class SuperNestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = NestedM()
                self.m2 = NestedM()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                if torch.max(x) > 1:
                    return self.w + self.m1(x)
                else:
                    return self.w + self.m2(x)

        # Super nested module testing.
        inp = (torch.ones(1),)
        orig_m = SuperNestedM()
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 1
        for ep in ep_list:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

    @unittest.skipIf(
        IS_WINDOWS,
        "torch.cond doesn't go through torch.compile on Windows, "
        "causing the output not to be normalized as a list",
    )
    def test_convert_nn_module_with_nested_if_and_param(self):
        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                return self.linear(x)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = M(dim)
                self.m2 = M(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Super nested, parameters need to be lifted multiple times.
        class SuperNestedM1(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = NestedM(dim)
                self.m2 = NestedM(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.max(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Super nested; even the input needs to be lifted recursively due to the
        # value propagation optimization.
        class SuperNestedM2(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = NestedM(dim)
                self.m2 = NestedM(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Basic module testing.
        inp = (torch.ones(3),)
        orig_m = M(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Nested module testing.
        inp = (torch.ones(3),)
        orig_m = NestedM(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip the jit.trace result because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Super nested module testing.
        inp = (torch.ones(3),)
        orig_m = SuperNestedM1(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip the jit.trace result because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Super nested module testing.
        inp = (torch.ones(3),)
        orig_m = SuperNestedM2(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip the jit.trace result because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

    def test_ts2ep_converter_contains(self):
        class MIn(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x.dtype in [torch.float32, torch.float64]

        class MNotIn(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x.dtype in [torch.int8]

        class MTensorIn(torch.nn.Module):
            def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
                return x in x_dict

        # Traced functions must return outputs that contain tensors.
        inp = (torch.tensor(4),)
        self._check_equal_ts_ep_converter(MIn(), inp, ["script"])
        self._check_equal_ts_ep_converter(MNotIn(), inp, ["script"])

        # TODO: update test to use reference for in.
        inp = (torch.tensor(4), {torch.tensor(4): "foo"})
        self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])
        inp = (torch.tensor(1), {torch.tensor(4): "foo"})
        self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])

    def test_ts2ep_converter_custom_op(self):
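        # Defines a scoped custom op (mylib::foo) with a CompositeExplicitAutograd
        # kernel and a meta kernel so the converter can trace calls to it.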
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.capture_dynamic_output_shape_ops = True

            torch.library.define(
                "mylib::foo",
                "(Tensor x) -> Tensor",
                lib=lib,
            )

            # PyTorch custom op implementation.
            @torch.library.impl(
                "mylib::foo",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def foo_impl(x):
                return x + x

            # Meta function of the custom op.
            @torch.library.impl_abstract(
                "mylib::foo",
                lib=lib,
            )
            def foo_meta(x):
                return x + x

            class M(torch.nn.Module):
                def forward(self, x):
                    return torch.ops.mylib.foo(x)

            inp = (torch.randn(3, 3),)
            m = M()
            self._check_equal_ts_ep_converter(m, inp)

    def test_convert_func_without_param(self):
        def func1(x, y):
            return x + y

        def func2(x, y):
            if x.sum() > 0:
                return x + y
            else:
                return x - y

        inp = (
            torch.tensor(1),
            torch.tensor(1),
        )
        self._check_equal_ts_ep_converter(func1, inp)

        ep_list = self._check_equal_ts_ep_converter(func2, inp)

        t = inp[0]
        t -= 1
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                func2(*inp),
            )

    def test_implicit_constant_to_tensor_handling(self):
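        # Exercises implicit conversion of Python constants (ints, floats, inf,
        # sizes) to tensors in TorchScript graphs.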
        def func1(x):
            return x + 2

        def func2(x, y):
            return x * y / (x - 2 * y) + y

        def func3(x):
            return x + torch.tensor([3])

        def func4():
            val = torch.tensor(float("inf"))
            return torch.full((10, 10), val)

        def func5():
            x = -1
            return x * torch.ones(1, dtype=torch.float), torch.zeros(
                1, dtype=torch.float
            )

        def func6(x1, x2, x3, x4):
            return (
                x1.numel(),
                x1.size(),
                x2.numel(),
                x2.size(),
                x3.numel(),
                x3.size(),
                x4.numel(),
                x4.size(),
                torch.ones(x1.numel()),  # Just make sure downstream ops still work.
                torch.ones(x1.size()),  # Just make sure downstream ops still work.
            )

        class M1(torch.nn.Module):
            def __init__(self, value):
                super().__init__()
                self.x = torch.tensor(value)

            def forward(self):
                return self.x.clone()

        class M2(torch.nn.Module):
            def forward(self, x):
                return torch.tensor(4) + x

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func1, inp)
        inp = (torch.randn([2, 2]), torch.randn([2, 2]))
        self._check_equal_ts_ep_converter(func2, inp)

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func3, inp)

        self._check_equal_ts_ep_converter(func4, ())
        self._check_equal_ts_ep_converter(M1(5), ())

        inp = (torch.randn(2),)
        self._check_equal_ts_ep_converter(M2(), inp)

        self._check_equal_ts_ep_converter(func5, ())
        inp = (
            torch.randn([2, 3, 4]).to(torch.int8),
            torch.randn([2, 3, 4]).to(torch.int32),
            torch.randn([2, 3, 4]).to(torch.float32),
            torch.randn([2, 3, 4]).to(torch.float64),
        )
        ep_list = self._check_equal_ts_ep_converter(func6, inp)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     self.assertEqual(
        #         ep.module()(
        #             torch.randn([1, 1, 1]).to(torch.int8),
        #             torch.randn([1, 1, 1]).to(torch.int32),
        #             torch.randn([1, 1, 1]).to(torch.float32),
        #             torch.randn([1, 1, 1]).to(torch.float64),
        #         )[0], 1
        #     )

    def test_aten_tensor_dtype_int(self):
        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor(1, dtype=torch.int32)
                return y + x

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 1)

    def test_aten_tensor_prim_dtype(self):
        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor(1, dtype=x.dtype)
                return y + x

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 1)

    def test_aten_tensor_dynamic(self):
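        # torch.tensor built from a (potentially dynamic) shape value should not be
        # baked in as a constant, hence the check that ep.constants is empty.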
        class M(torch.nn.Module):
            def forward(self, x):
                s = x.shape[0]
                y = torch.tensor(s)
                return y

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 0)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     torch.testing.assert_close(
        #         ep.module()(torch.ones(4)),
        #         M()(torch.ones(4)),
        #     )

        class M(torch.nn.Module):
            def forward(self, x):
                s = x.shape[0]
                y = torch.tensor([s, s * 2, 1])
                return y

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
        # jit.trace directly inlines the tensor as a constant, so only check the
        # scripted ExportedProgram.
        for ep in ep_list[1:]:
            self.assertEqual(len(ep.constants), 0)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     torch.testing.assert_close(
        #         ep.module()(torch.ones(4)),
        #         M()(torch.ones(4)),
        #     )

    def test_prim_tolist(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[int]:
                return x.tolist()

        inp = (torch.tensor([1, 2, 3]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[List[int]]:
                return x.tolist()

        inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_get_tensor_constants(self):
        # Since self.data is only read and never written, it is lifted as a
        # constant tensor.
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.randn(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data

        class Goo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.randn(3, 2)
                self.foo = Foo()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data + self.foo.data + self.foo(x)

        inp = (torch.randn(3, 2),)
        goo = Goo()
        self._check_equal_ts_ep_converter(goo, inp)

    def test_prim_SetAttr(self):
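        # prim::SetAttr mutates module state; check_persistent=True runs the
        # converted module repeatedly, so the mutation must persist across calls.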
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.nn.Buffer(torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + x

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.nn.Buffer(torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        # export lifts a tensor constant (self.data) as an input if it is never
        # assigned to. If it is assigned to, export errors and asks users to register
        # it as a buffer. The converter instead registers assigned tensor constants
        # as buffers automatically, since registering them manually can be hard.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.ones(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module,
            inp,
            ["script"],
            check_persistent=True,
            lifted_tensor_constants=OrderedDict([("data", torch.ones(3, 2))]),
        )

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.count = 0

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.count += 1
                return x + self.count

        # check_persistent is False since export specializes on non-tensor constants.
        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module(), inp, ["script"], check_persistent=False
        )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.count = 0

            def forward(self, x):
                count1 = self.count
                self.count += 1
                count2 = self.count
                self.count += 1
                count3 = self.count
                return x + count1 + count2 + count3

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M(), inp, ["script"], check_persistent=False)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w2 = torch.nn.Buffer(torch.ones(1))

            def forward(self, x: torch.Tensor):
                self.w2 += 1
                return self.w2

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M, inp, ["script"], check_persistent=True)

    def test_raise_exception(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
                if y > 0:
                    raise RuntimeError("test")
                return x + y

        # Match non-strict export behavior, which errors when the given input
        # leads to RaiseException.
        with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
            inp = (torch.randn(3, 2), 1)
            self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # Match non-strict export behavior, which only executes one if-branch
        # according to the given input.
        inp = (torch.randn(3, 2), 0)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
                z = x
                if y > 0:
                    raise RuntimeError("test")
                    # z = x
                else:
                    z = x + y
                return x + y + z

        # Match non-strict export behavior, which errors when the given input
        # leads to RaiseException.
        with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
            inp = (torch.randn(3, 2), 1)
            self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # Match non-strict export behavior, which only executes one if-branch
        # according to the given input.
        inp = (torch.randn(3, 2), 0)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_context_manager(self):
        class ContextManager:
            def __init__(self) -> None:
                self.count = 0
                return

            def __enter__(self):
                self.count += 1
                return

            def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
                self.count -= 1
                return

        class M(torch.nn.Module):
            def forward(self, x, y):
                with ContextManager():
                    res = x + y
                return res

        inp = (torch.ones(3, 3), torch.ones(3, 3))
        self._check_equal_ts_ep_converter(M(), inp)

    def test_hidden_input_name(self):
        @torch.jit.script
        def func1(x):
            return x + 1

        def func2(*args):
            v = torch.cat(args, dim=1)
            return v * v

        inp = (torch.randn([1, 1]),)
        self._check_equal_ts_ep_converter(func1, inp)

        inp = (torch.ones(5, 5),)
        # Cannot be scripted again, so only trace is tested.
        self._check_equal_ts_ep_converter(torch.ops.aten.relu, inp, ["trace"])

        M = 2
        Ns = [4, 2, 1]
        empty = torch.tensor([], dtype=torch.double)
        values = [empty] + [torch.randn(M, N) for N in Ns]
        # Cannot script variable-length inputs, so only trace is tested.
        self._check_equal_ts_ep_converter(func2, tuple(values), ["trace"])

    def test_ts2ep_multi_outputs_on_call_ops(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return (
                    torch.max(x, dim=0),
                    torch.topk(x, 3),
                    torch.sort(x, dim=0),
                    self.pool(y),
                )

        inp = (torch.randn([4, 4]), torch.randn([1, 1, 10, 10]))
        self._check_equal_ts_ep_converter(M(), inp)

    def test_aten_append_t(self):
        class M(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                out = []
                out.append(x[0] + x[1])
                out.append(x[0] - x[1])
                out1 = torch.cat(out)
                out.append(x[0] * x[1])
                out2 = torch.cat(out)
                return out, out1, out2

        inp = ([torch.ones(2, 3), torch.ones(2, 3)],)
        # Trace already unrolls the list, so only script is tested.
        self._check_equal_ts_ep_converter(M(), inp, ["script"])

    def test_convert_script_object(self):
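        # Uses a torchbind script object (TensorQueue) both through its methods and
        # through the _TorchScriptTesting custom ops registered in setUp.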
        class M1(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tq = _empty_tensor_queue()

            def forward(self, x: torch.Tensor):
                self.tq.push(x)
                torch.ops._TorchScriptTesting.queue_push(self.tq, x.cos())
                return torch.ops._TorchScriptTesting.queue_pop(self.tq), self.tq.pop()

        inp = (torch.randn(2, 3),)
        self._check_equal_ts_ep_converter(M1(), inp, ["script"])

    def test_ts2ep_with_loop(self):
        def func1(x, x_list: List[torch.Tensor]):
            a, b, c = x, x, x
            for i in range(1, 5, 2):
                for k in range(5):
                    a = a + a + k
                    b = b + b - k
                    x_list.append(x_list[k] + x_list[k + 1])
                for k in range(5):
                    b = b + b - k
                    c = c + c * k
                    x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2])
            return x, x_list

        def func2(x):
            for i in range(x.size(0)):
                x = x * x * i
            return x

        def func3(x):
            while x.sum() < 10:
                x += x.sin()
            return x

        inp = (
            torch.tensor(1),
            [torch.ones([2, 2]), torch.ones([2, 2]) * 2],
        )
        # Trace unrolls the loop, so only script is tested.
        self._check_equal_ts_ep_converter(func1, inp, ["script"])

        # TODO: (2/N)
        # Trace unrolls the loop.
        # self._check_equal_ts_ep_converter(func2, inp, ["script"])

        # TODO: (3/N)
        # Trace unrolls the loop.
        # self._check_equal_ts_ep_converter(func3, inp, ["script"])

    @unittest.skipIf(
        IS_WINDOWS,
        "Windows does not support qnnpack",
    )
    def test_ts2ep_convert_quantized_model(self):
        class Standalone(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.ao.quantization.fuse_modules(
                    self, [["conv2", "relu"]], inplace=True
                )

        with override_quantized_engine("qnnpack"):
            model = Standalone()
            model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
            model.fuse_model()
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)

            # Use customized checking here, because the state_dict is modified by
            # the quantization pass.
            inp = (torch.randn(4, 1, 4, 4),)
            original_ts_model = torch.jit.script(model)
            ts_model = torch.jit.script(model)
            converter = TS2EPConverter(ts_model, inp)
            ep = converter.convert()

            orig_out, _ = pytree.tree_flatten(original_ts_model(*inp))
            ep_out, _ = pytree.tree_flatten(ep.module()(*inp))
            self._check_tensor_list_equal(orig_out, ep_out)

    def test_ts2ep_convert_quantized_model_with_opcontext(self):
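        # The prepacked linear op context is a script object stored on the module;
        # checks that the converter handles it when converting the scripted model.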
        class M(torch.nn.Module):
            def __init__(self, linear_op):
                super().__init__()
                self.linear_op = linear_op

            def forward(self, x):
                x = torch.ops.prepacked.linear_clamp_run(x, self.linear_op)
                return x

        linear_op = torch.ops.prepacked.linear_clamp_prepack(
            torch.randn(10, 10), torch.randn(10)
        )
        m = M(linear_op)
        inp = (torch.randn(1, 10),)
        self._check_equal_ts_ep_converter(m, inp, ["script"])


if __name__ == "__main__":
    run_tests()