xref: /aosp_15_r20/external/pytorch/test/jit/test_warn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerimport warnings
7*da0073e9SAndroid Build Coastguard Workerfrom contextlib import redirect_stderr
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
# Make the helper files in test/ importable: walk up from this file's real
# location (test/jit/test_warn.py) to the test/ directory and append that
# directory to the module search path.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # .../test/jit
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # .../test
sys.path.append(pytorch_test_dir)
16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Refuse direct execution (e.g. `python test/jit/test_warn.py`); the
    # message points the user at the test/test_jit.py driver instead.
    raise RuntimeError(
        "".join(
            (
                "This test file is not meant to be run directly, use:\n\n",
                "\tpython test/test_jit.py TESTNAME\n\n",
                "instead.",
            )
        )
    )
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
class TestWarn(JitTestCase):
    """How often a TorchScript ``warnings.warn`` call is surfaced to Python.

    Every test scripts a function that warns, runs it, and asserts the exact
    number of ``UserWarning`` records that were produced.
    """

    @staticmethod
    def _record_warnings(*calls):
        """Invoke each callable in ``calls`` in order and return what they warned.

        Warnings are recorded in memory under an ``"always"`` filter rather
        than scraped from ``sys.stderr``. The result therefore does not depend
        on the process-wide warning filters (``-W error`` would turn the
        warning into an exception, ``-W ignore`` would drop it) or on where the
        current ``warnings.showwarning`` sends its output. Python's own
        once-per-location de-duplication is also disabled, so only the scripted
        code decides how many warnings are emitted.

        Returns:
            One ``"<Category>: <message>"`` line per recorded warning, joined
            with newlines, for use with ``FileCheck``.
        """
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            for call in calls:
                call()
        return "\n".join(f"{w.category.__name__}: {w.message}" for w in caught)

    def _check_warning_count(self, count, *calls):
        """Assert ``calls`` emit exactly ``count`` "I am warning you" warnings."""
        FileCheck().check_count(
            str="UserWarning: I am warning you", count=count, exactly=True
        ).run(self._record_warnings(*calls))

    def test_warn(self):
        """A single warn in a scripted function is reported once."""

        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        self._check_warning_count(1, fn)

    def test_warn_only_once(self):
        """A warn executed 10 times in a loop is reported once per call."""

        @torch.jit.script
        def fn():
            for _ in range(10):
                warnings.warn("I am warning you")

        self._check_warning_count(1, fn)

    def test_warn_only_once_in_loop_func(self):
        """Same as above, with the warn inside a helper called from the loop."""

        def w():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w()

        self._check_warning_count(1, fn)

    def test_warn_once_per_func(self):
        """Two distinct warn sites with identical messages each report once."""

        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            w1()
            w2()

        self._check_warning_count(2, fn)

    def test_warn_once_per_func_in_loop(self):
        """Distinct warn sites inside a loop still report once each per call."""

        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w1()
                w2()

        self._check_warning_count(2, fn)

    def test_warn_multiple_calls_multiple_warnings(self):
        """Each separate call of the scripted function reports its warning again."""

        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        self._check_warning_count(2, fn, fn)

    def test_warn_multiple_calls_same_func_diff_stack(self):
        """One helper reached from two scripted functions warns once for each."""

        def warn(caller: str):
            warnings.warn("I am warning you from " + caller)

        @torch.jit.script
        def foo():
            warn("foo")

        @torch.jit.script
        def bar():
            warn("bar")

        recorded = self._record_warnings(foo, bar)

        # Chained checks are matched in order: foo's warning, then bar's.
        FileCheck().check_count(
            str="UserWarning: I am warning you from foo", count=1, exactly=True
        ).check_count(
            str="UserWarning: I am warning you from bar", count=1, exactly=True
        ).run(recorded)
151