xref: /aosp_15_r20/external/pytorch/test/test_jit_disabled.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport sys
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport contextlib
6*da0073e9SAndroid Build Coastguard Workerimport subprocess
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager
11*da0073e9SAndroid Build Coastguard Workerdef _jit_disabled():
12*da0073e9SAndroid Build Coastguard Worker    cur_env = os.environ.get("PYTORCH_JIT", "1")
13*da0073e9SAndroid Build Coastguard Worker    os.environ["PYTORCH_JIT"] = "0"
14*da0073e9SAndroid Build Coastguard Worker    try:
15*da0073e9SAndroid Build Coastguard Worker        yield
16*da0073e9SAndroid Build Coastguard Worker    finally:
17*da0073e9SAndroid Build Coastguard Worker        os.environ["PYTORCH_JIT"] = cur_env
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
class TestJitDisabled(TestCase):
    """
    These tests are separate from the rest of the JIT tests because we need to
    run a new subprocess and `import torch` with the correct environment
    variables set.
    """

    def compare_enabled_disabled(self, src):
        """
        Runs the script in `src` with PYTORCH_JIT enabled and disabled and
        compares their stdout for equality.

        Raises ``subprocess.CalledProcessError`` if either run exits with a
        non-zero status.
        """
        # Write `src` out to a temporary file so our source inspection logic
        # works correctly.
        with TemporaryFileName() as fname:
            # Fully write and close the file BEFORE launching the child
            # interpreters. While the handle is still open, the text may sit
            # in Python's write buffer. The children would then execute an
            # empty script, and both runs would trivially print the same
            # (empty) output. UTF-8 matches Python's default source encoding.
            with open(fname, 'w', encoding='utf-8') as f:
                f.write(src)

            with _jit_disabled():
                out_disabled = subprocess.check_output([sys.executable, fname])
            out_enabled = subprocess.check_output([sys.executable, fname])
            self.assertEqual(out_disabled, out_enabled)

    def test_attribute(self):
        # A ScriptModule holding a torch.jit.Attribute must print the same
        # value whether or not the JIT is enabled.
        _program_string = """
import torch

class Foo(torch.jit.ScriptModule):
    def __init__(self, x):
        super().__init__()
        self.x = torch.jit.Attribute(x, torch.Tensor)

    def forward(self, input):
        return input

s = Foo(torch.ones(2, 3))
print(s.x)
"""
        self.compare_enabled_disabled(_program_string)

    def test_script_module_construction(self):
        # Constructing a ScriptModule with a @script_method must succeed in
        # both modes.
        _program_string = """
import torch

class AModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, input):
        pass

AModule()
print("Didn't throw exception")
"""
        self.compare_enabled_disabled(_program_string)

    def test_recursive_script(self):
        # torch.jit.script on a plain nn.Module must succeed in both modes.
        _program_string = """
import torch

class AModule(torch.nn.Module):
    def forward(self, input):
        pass

sm = torch.jit.script(AModule())
print("Didn't throw exception")
"""
        self.compare_enabled_disabled(_program_string)
89*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed directly (not imported): hand off to the test runner imported
    # from torch.testing._internal.common_utils.
    run_tests()
92