1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: fx"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations # type: ignore[attr-defined] 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport typing 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerfrom torch.fx import symbolic_trace 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerclass A: 12*da0073e9SAndroid Build Coastguard Worker def __call__(self, x: torch.Tensor): 13*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker# No forward references 17*da0073e9SAndroid Build Coastguard Workerclass M1(torch.nn.Module): 18*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, a: A) -> torch.Tensor: 19*da0073e9SAndroid Build Coastguard Worker return a(x) 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker# Forward references 23*da0073e9SAndroid Build Coastguard Workerclass M2(torch.nn.Module): 24*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, a: A) -> torch.Tensor: 25*da0073e9SAndroid Build Coastguard Worker return a(x) 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker# Non-torch annotation with no internal forward references 29*da0073e9SAndroid Build Coastguard Workerclass M3(torch.nn.Module): 30*da0073e9SAndroid Build Coastguard Worker def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor: 31*da0073e9SAndroid Build Coastguard Worker return a(x[0]) 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker# Non-torch annotation with internal forward references 35*da0073e9SAndroid Build Coastguard Workerclass M4(torch.nn.Module): 36*da0073e9SAndroid Build Coastguard Worker def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor: 37*da0073e9SAndroid Build Coastguard Worker return a(x[0]) 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Workerx = torch.rand(2, 3) 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Workerref = torch.add(x, x) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Workertraced1 = symbolic_trace(M1()) 45*da0073e9SAndroid Build Coastguard Workerres1 = traced1(x, A()) 46*da0073e9SAndroid Build Coastguard Workerassert torch.all(torch.eq(ref, res1)) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Workertraced2 = symbolic_trace(M2()) 49*da0073e9SAndroid Build Coastguard Workerres2 = traced2(x, A()) 50*da0073e9SAndroid Build Coastguard Workerassert torch.all(torch.eq(ref, res2)) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Workertraced3 = symbolic_trace(M3()) 53*da0073e9SAndroid Build Coastguard Workerres3 = traced3([x], A()) 54*da0073e9SAndroid Build Coastguard Workerassert torch.all(torch.eq(ref, res3)) 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Workertraced4 = symbolic_trace(M4()) 57*da0073e9SAndroid Build Coastguard Workerres4 = traced4([x], A()) 58*da0073e9SAndroid Build Coastguard Workerassert torch.all(torch.eq(ref, res4)) 59