xref: /aosp_15_r20/external/pytorch/test/jit/test_attr.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom typing import NamedTuple, Tuple
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
# Refuse to run as a standalone script: the raised message tells the caller
# to go through test/test_jit.py with a test name instead.
if __name__ == "__main__":
    _usage = "".join(
        (
            "This test file is not meant to be run directly, use:\n\n",
            "\tpython test/test_jit.py TESTNAME\n\n",
            "instead.",
        )
    )
    raise RuntimeError(_usage)
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Workerclass TestGetDefaultAttr(JitTestCase):
19*da0073e9SAndroid Build Coastguard Worker    def test_getattr_with_default(self):
20*da0073e9SAndroid Build Coastguard Worker        class A(torch.nn.Module):
21*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
22*da0073e9SAndroid Build Coastguard Worker                super().__init__()
23*da0073e9SAndroid Build Coastguard Worker                self.init_attr_val = 1.0
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
26*da0073e9SAndroid Build Coastguard Worker                y = getattr(self, "init_attr_val")  # noqa: B009
27*da0073e9SAndroid Build Coastguard Worker                w: list[float] = [1.0]
28*da0073e9SAndroid Build Coastguard Worker                z = getattr(self, "missing", w)  # noqa: B009
29*da0073e9SAndroid Build Coastguard Worker                z.append(y)
30*da0073e9SAndroid Build Coastguard Worker                return z
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker        result = A().forward(0.0)
33*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, len(result))
34*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(A()).graph
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker        # The "init_attr_val" attribute exists
37*da0073e9SAndroid Build Coastguard Worker        FileCheck().check('prim::GetAttr[name="init_attr_val"]').run(graph)
38*da0073e9SAndroid Build Coastguard Worker        # The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST
39*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("missing").run(graph)
40*da0073e9SAndroid Build Coastguard Worker        # instead the getattr call will emit the default value, which is a list with one float element
41*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("float[] = prim::ListConstruct").run(graph)
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker    def test_getattr_named_tuple(self):
44*da0073e9SAndroid Build Coastguard Worker        global MyTuple
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker        class MyTuple(NamedTuple):
47*da0073e9SAndroid Build Coastguard Worker            x: str
48*da0073e9SAndroid Build Coastguard Worker            y: torch.Tensor
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker        def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]:
51*da0073e9SAndroid Build Coastguard Worker            return (
52*da0073e9SAndroid Build Coastguard Worker                getattr(x, "x", "fdsa"),
53*da0073e9SAndroid Build Coastguard Worker                getattr(x, "y", torch.ones((3, 3))),
54*da0073e9SAndroid Build Coastguard Worker                getattr(x, "z", 7),
55*da0073e9SAndroid Build Coastguard Worker            )
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker        inp = MyTuple(x="test", y=torch.ones(3, 3) * 2)
58*da0073e9SAndroid Build Coastguard Worker        ref = fn(inp)
59*da0073e9SAndroid Build Coastguard Worker        fn_s = torch.jit.script(fn)
60*da0073e9SAndroid Build Coastguard Worker        res = fn_s(inp)
61*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, ref)
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker    def test_getattr_tuple(self):
64*da0073e9SAndroid Build Coastguard Worker        def fn(x: Tuple[str, int]) -> int:
65*da0073e9SAndroid Build Coastguard Worker            return getattr(x, "x", 2)
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"):
68*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(fn)
69