# Owner(s): ["module: fx"]

# Import-time regression checks: torch.fx.symbolic_trace must handle modules
# whose annotations are postponed (PEP 563, enabled below), including
# annotations that were already written as string forward references.
# Every check below is a module-level assert, so importing this file runs it.
from __future__ import annotations  # type: ignore[attr-defined]

import typing

import torch
from torch.fx import symbolic_trace


class A:
    """Plain (non-Module) callable passed as an argument to the traced modules."""

    def __call__(self, x: torch.Tensor):
        return torch.add(x, x)


# No forward references.
class M1(torch.nn.Module):
    def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
        return a(x)


# Forward references. The explicit quotes are intentional: combined with
# `from __future__ import annotations` they are stored doubly quoted
# (e.g. "'A'"), which is a different input to FX than M1's annotations.
# Do not let a linter strip them, or M2 silently becomes a copy of M1.
class M2(torch.nn.Module):
    def forward(self, x: torch.Tensor, a: "A") -> "torch.Tensor":  # noqa: UP037
        return a(x)


# Non-torch annotation with no internal forward references.
class M3(torch.nn.Module):
    def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
        return a(x[0])


# Non-torch annotation with internal forward references. As with M2, the
# quotes are what distinguish this case from M3 and must be kept.
class M4(torch.nn.Module):
    def forward(self, x: typing.List["torch.Tensor"], a: A) -> "torch.Tensor":  # noqa: UP037
        return a(x[0])


x = torch.rand(2, 3)

# Expected output for every module: A()(x) computes torch.add(x, x).
ref = torch.add(x, x)

traced1 = symbolic_trace(M1())
res1 = traced1(x, A())
assert torch.all(torch.eq(ref, res1))

traced2 = symbolic_trace(M2())
res2 = traced2(x, A())
assert torch.all(torch.eq(ref, res2))

# M3/M4 take a list of tensors and only use its first element.
traced3 = symbolic_trace(M3())
res3 = traced3([x], A())
assert torch.all(torch.eq(ref, res3))

traced4 = symbolic_trace(M4())
res4 = traced4([x], A())
assert torch.all(torch.eq(ref, res4))