1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom typing import NamedTuple, Tuple 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 11*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 12*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 13*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 14*da0073e9SAndroid Build Coastguard Worker "instead." 15*da0073e9SAndroid Build Coastguard Worker ) 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workerclass TestGetDefaultAttr(JitTestCase): 19*da0073e9SAndroid Build Coastguard Worker def test_getattr_with_default(self): 20*da0073e9SAndroid Build Coastguard Worker class A(torch.nn.Module): 21*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 22*da0073e9SAndroid Build Coastguard Worker super().__init__() 23*da0073e9SAndroid Build Coastguard Worker self.init_attr_val = 1.0 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 26*da0073e9SAndroid Build Coastguard Worker y = getattr(self, "init_attr_val") # noqa: B009 27*da0073e9SAndroid Build Coastguard Worker w: list[float] = [1.0] 28*da0073e9SAndroid Build Coastguard Worker z = getattr(self, "missing", w) # noqa: B009 29*da0073e9SAndroid Build Coastguard Worker z.append(y) 30*da0073e9SAndroid Build Coastguard Worker return z 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker result = A().forward(0.0) 33*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, len(result)) 34*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(A()).graph 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker # The "init_attr_val" attribute exists 37*da0073e9SAndroid Build Coastguard Worker FileCheck().check('prim::GetAttr[name="init_attr_val"]').run(graph) 38*da0073e9SAndroid Build Coastguard Worker # The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST 39*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("missing").run(graph) 40*da0073e9SAndroid Build Coastguard Worker # instead the getattr call will emit the default value, which is a list with one float element 41*da0073e9SAndroid Build Coastguard Worker FileCheck().check("float[] = prim::ListConstruct").run(graph) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker def test_getattr_named_tuple(self): 44*da0073e9SAndroid Build Coastguard Worker global MyTuple 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker class MyTuple(NamedTuple): 47*da0073e9SAndroid Build Coastguard Worker x: str 48*da0073e9SAndroid Build Coastguard Worker y: torch.Tensor 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]: 51*da0073e9SAndroid Build Coastguard Worker return ( 52*da0073e9SAndroid Build Coastguard Worker getattr(x, "x", "fdsa"), 53*da0073e9SAndroid Build Coastguard Worker getattr(x, "y", torch.ones((3, 3))), 54*da0073e9SAndroid Build Coastguard Worker getattr(x, "z", 7), 55*da0073e9SAndroid Build Coastguard Worker ) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker inp = MyTuple(x="test", y=torch.ones(3, 3) * 2) 58*da0073e9SAndroid Build Coastguard Worker ref = fn(inp) 59*da0073e9SAndroid Build Coastguard Worker fn_s = torch.jit.script(fn) 60*da0073e9SAndroid Build Coastguard Worker res = fn_s(inp) 61*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, ref) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker def test_getattr_tuple(self): 64*da0073e9SAndroid Build Coastguard Worker def fn(x: Tuple[str, int]) -> int: 65*da0073e9SAndroid Build Coastguard Worker return getattr(x, "x", 2) 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"): 68*da0073e9SAndroid Build Coastguard Worker torch.jit.script(fn) 69