xref: /aosp_15_r20/external/pytorch/test/fx/test_future.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: fx"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations  # type: ignore[attr-defined]
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport typing
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch.fx import symbolic_trace
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
class A:
    """Plain (non-Module) callable used as an extra argument to the test modules."""

    def __call__(self, x: torch.Tensor):
        """Return ``x`` added to itself via ``torch.add``."""
        doubled = torch.add(x, x)
        return doubled
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker# No forward references
class M1(torch.nn.Module):
    """Module whose ``forward`` is annotated with plain (unquoted) names."""

    def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
        """Call ``a`` on ``x`` and return whatever it produces."""
        result = a(x)
        return result
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker# Forward references
class M2(torch.nn.Module):
    """Module whose ``forward`` annotations are explicit string forward references.

    The quoted annotations are the whole point of this case: without them the
    class is identical to the unquoted variant and the forward-reference path
    is not exercised.
    """

    # Keep the quotes; UP037 ("remove quotes from type annotation") must not
    # be applied here or this case stops testing forward references.
    def forward(self, x: "torch.Tensor", a: "A") -> "torch.Tensor":  # noqa: UP037
        """Call ``a`` on ``x`` and return whatever it produces."""
        return a(x)
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker# Non-torch annotation with no internal forward references
class M3(torch.nn.Module):
    """Module taking a ``typing.List`` annotation with no quoted names inside."""

    def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
        """Call ``a`` on the first element of ``x`` and return the result."""
        first = x[0]
        return a(first)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker# Non-torch annotation with internal forward references
class M4(torch.nn.Module):
    """Module taking a ``typing.List`` annotation that contains a quoted name.

    The string inside ``typing.List[...]`` (and the quoted return type) is
    what distinguishes this case from the fully unquoted ``typing.List``
    variant; without the quotes the two classes are identical.
    """

    # Keep the quotes; UP037 ("remove quotes from type annotation") must not
    # be applied here or this case stops testing internal forward references.
    def forward(self, x: typing.List["torch.Tensor"], a: A) -> "torch.Tensor":  # noqa: UP037
        """Call ``a`` on the first element of ``x`` and return the result."""
        return a(x[0])
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
# Shared random input and the expected output (x added to itself) that every
# traced module below must reproduce.
x = torch.rand(2, 3)
ref = torch.add(x, x)

# M1 and M2 receive the tensor directly.
traced1 = symbolic_trace(M1())
res1 = traced1(x, A())
assert torch.eq(ref, res1).all()

traced2 = symbolic_trace(M2())
res2 = traced2(x, A())
assert torch.eq(ref, res2).all()

# M3 and M4 receive the tensor wrapped in a one-element list.
traced3 = symbolic_trace(M3())
res3 = traced3([x], A())
assert torch.eq(ref, res3).all()

traced4 = symbolic_trace(M4())
res4 = traced4([x], A())
assert torch.eq(ref, res4).all()
59