xref: /aosp_15_r20/external/pytorch/test/fx/test_future.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: fx"]
2
3from __future__ import annotations  # type: ignore[attr-defined]
4
5import typing
6
7import torch
8from torch.fx import symbolic_trace
9
10
class A:
    """Plain (non-Module) callable that returns its tensor argument added to itself."""

    def __call__(self, x: torch.Tensor):
        """Return ``torch.add(x, x)`` for the given tensor."""
        doubled = torch.add(x, x)
        return doubled
14
15
16# No forward references
class M1(torch.nn.Module):
    """Module whose ``forward`` annotations are written as plain, unquoted names."""

    def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
        """Apply the callable ``a`` to ``x`` and return whatever it produces."""
        result = a(x)
        return result
20
21
22# Forward references
class M2(torch.nn.Module):
    """Module whose ``forward`` annotations are quoted forward references.

    Every annotation is written as a string literal, so this class covers the
    forward-reference case and is not a duplicate of the unquoted variant.
    """

    def forward(self, x: "torch.Tensor", a: "A") -> "torch.Tensor":
        """Apply the callable ``a`` to ``x`` and return whatever it produces."""
        return a(x)
26
27
28# Non-torch annotation with no internal forward references
class M3(torch.nn.Module):
    """Module taking a ``typing.List`` of tensors, annotated without quoted names."""

    def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
        """Apply the callable ``a`` to the first element of ``x``."""
        first = x[0]
        return a(first)
32
33
34# Non-torch annotation with internal forward references
class M4(torch.nn.Module):
    """Module taking a ``typing.List`` whose element type is a quoted forward reference.

    The element type inside ``typing.List[...]`` is a string literal, so this
    class covers the "internal forward reference" case and is not a duplicate
    of the fully unquoted variant.
    """

    def forward(self, x: typing.List["torch.Tensor"], a: A) -> torch.Tensor:
        """Apply the callable ``a`` to the first element of ``x``."""
        return a(x[0])
38
39
# Shared random input and the reference result (x + x) that each traced
# module is expected to reproduce.
x = torch.rand(2, 3)
ref = torch.add(x, x)

# For every module: symbolically trace it, run the traced graph with a fresh
# A instance, and require an element-wise match with the reference.

# M1 receives the tensor directly.
traced1 = symbolic_trace(M1())
res1 = traced1(x, A())
assert torch.eq(ref, res1).all()

# M2 receives the tensor directly.
traced2 = symbolic_trace(M2())
res2 = traced2(x, A())
assert torch.eq(ref, res2).all()

# M3 receives the tensor wrapped in a one-element list.
traced3 = symbolic_trace(M3())
res3 = traced3([x], A())
assert torch.eq(ref, res3).all()

# M4 receives the tensor wrapped in a one-element list.
traced4 = symbolic_trace(M4())
res4 = traced4([x], A())
assert torch.eq(ref, res4).all()
59