xref: /aosp_15_r20/external/pytorch/test/jit/test_script_profile.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5
6import torch
7from torch import nn
8
9
# This file lives in test/jit/, so the test/ directory is two path components
# above it. Appending that directory to sys.path lets the shared helper
# modules kept in test/ be imported.
pytorch_test_dir = os.path.split(
    os.path.split(os.path.realpath(__file__))[0]
)[0]
sys.path.append(pytorch_test_dir)
13from torch.testing._internal.jit_utils import JitTestCase
14
15
if __name__ == "__main__":
    # Direct execution is refused. The message tells the user to go through
    # test/test_jit.py with a test name instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
22
23
class Sequence(nn.Module):
    """Two stacked ``LSTMCell`` layers followed by a ``Linear`` head.

    This small recurrent model is the workload these profiler tests run
    (scripted or called from scripted code).
    """

    def __init__(self) -> None:
        super().__init__()
        # Submodules are created in this order (lstm1, lstm2, linear), which
        # fixes the order in which their random initialization is drawn.
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input):
        """Unroll both cells over dim 1 of ``input``, one slice per step.

        Each dim-1 slice feeds ``lstm1``, whose input size is 1, so ``input``
        is expected to be 2-D ``(batch, steps)``. The per-step outputs of
        ``linear`` are concatenated along dim 1.
        """
        batch = input.size(0)
        # Hidden and cell states of both layers start at zero (width 51).
        hidden1 = torch.zeros(batch, 51)
        cell1 = torch.zeros(batch, 51)
        hidden2 = torch.zeros(batch, 51)
        cell2 = torch.zeros(batch, 51)

        steps = []
        for column in input.split(1, dim=1):
            hidden1, cell1 = self.lstm1(column, (hidden1, cell1))
            hidden2, cell2 = self.lstm2(hidden1, (hidden2, cell2))
            steps.append(self.linear(hidden2))
        return torch.cat(steps, dim=1)
45
46
class TestScriptProfile(JitTestCase):
    """Tests for ``torch.jit._ScriptProfile``.

    Each test enables one or more profiles, runs TorchScript code, disables
    the profile(s), and inspects the textual report from ``dump_string()``.
    """

    def test_basic(self):
        """A profile enabled around a scripted module call produces a report."""
        seq = torch.jit.script(Sequence())
        p = torch.jit._ScriptProfile()
        p.enable()
        seq(torch.rand((10, 100)))
        p.disable()
        self.assertNotEqual(p.dump_string(), "")

    def test_script(self):
        """Code scripted after the profile is enabled is also reported.

        ``seq`` is an eager module that is referenced from inside ``fn``.
        ``fn`` is scripted only after ``p.enable()``.
        """
        seq = Sequence()

        p = torch.jit._ScriptProfile()
        p.enable()

        @torch.jit.script
        def fn():
            _ = seq(torch.rand((10, 100)))

        fn()
        p.disable()

        self.assertNotEqual(p.dump_string(), "")

    def test_multi(self):
        """Several simultaneously enabled profiles give distinct reports."""
        seq = torch.jit.script(Sequence())
        profiles = [torch.jit._ScriptProfile() for _ in range(5)]
        for p in profiles:
            p.enable()

        last = None
        # Profiles are disabled in reverse creation order, one per forward
        # pass. The k-th profile disabled has therefore been enabled for
        # exactly k passes, so consecutive reports must differ.
        for p in reversed(profiles):
            seq(torch.rand((10, 10)))
            p.disable()
            stats = p.dump_string()
            self.assertNotEqual(stats, "")
            if last is not None:
                self.assertNotEqual(stats, last)
            last = stats

    def test_section(self):
        """Only runs made while the profile is enabled change its report."""
        seq = Sequence()

        # `width` (rather than `max`) avoids shadowing the builtin.
        @torch.jit.script
        def fn(width: int):
            _ = seq(torch.rand((10, width)))

        p = torch.jit._ScriptProfile()
        p.enable()
        fn(100)
        p.disable()
        s0 = p.dump_string()

        # This run happens while the profile is disabled. The second disable()
        # is deliberately redundant: it must be a harmless no-op, otherwise
        # s1 would not equal s0.
        fn(10)
        p.disable()
        s1 = p.dump_string()

        # After re-enabling, a further run must change the report again.
        p.enable()
        fn(10)
        p.disable()
        s2 = p.dump_string()

        self.assertEqual(s0, s1)
        self.assertNotEqual(s1, s2)

    def test_empty(self):
        """A profile that observed no scripted execution dumps an empty string."""
        p = torch.jit._ScriptProfile()
        p.enable()
        p.disable()
        self.assertEqual(p.dump_string(), "")
118