1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport io 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport sys 6*da0073e9SAndroid Build Coastguard Workerimport warnings 7*da0073e9SAndroid Build Coastguard Workerfrom contextlib import redirect_stderr 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 14*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 20*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 21*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 22*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 23*da0073e9SAndroid Build Coastguard Worker "instead." 24*da0073e9SAndroid Build Coastguard Worker ) 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Workerclass TestWarn(JitTestCase): 28*da0073e9SAndroid Build Coastguard Worker def test_warn(self): 29*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 30*da0073e9SAndroid Build Coastguard Worker def fn(): 31*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you") 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker f = io.StringIO() 34*da0073e9SAndroid Build Coastguard Worker with redirect_stderr(f): 35*da0073e9SAndroid Build Coastguard Worker fn() 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count( 38*da0073e9SAndroid Build Coastguard Worker str="UserWarning: I am warning you", count=1, exactly=True 39*da0073e9SAndroid Build Coastguard Worker ).run(f.getvalue()) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def test_warn_only_once(self): 42*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 43*da0073e9SAndroid Build Coastguard Worker def fn(): 44*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 45*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you") 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker f = io.StringIO() 48*da0073e9SAndroid Build Coastguard Worker with redirect_stderr(f): 49*da0073e9SAndroid Build Coastguard Worker fn() 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count( 52*da0073e9SAndroid Build Coastguard Worker str="UserWarning: I am warning you", count=1, exactly=True 53*da0073e9SAndroid Build Coastguard Worker ).run(f.getvalue()) 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker def test_warn_only_once_in_loop_func(self): 56*da0073e9SAndroid Build Coastguard Worker def w(): 57*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you") 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 60*da0073e9SAndroid Build Coastguard Worker def fn(): 61*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 62*da0073e9SAndroid Build Coastguard Worker w() 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker f = io.StringIO() 65*da0073e9SAndroid Build Coastguard Worker with redirect_stderr(f): 66*da0073e9SAndroid Build Coastguard Worker fn() 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count( 69*da0073e9SAndroid Build Coastguard Worker str="UserWarning: I am warning you", count=1, exactly=True 70*da0073e9SAndroid Build Coastguard Worker ).run(f.getvalue()) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker def test_warn_once_per_func(self): 73*da0073e9SAndroid Build Coastguard Worker def w1(): 74*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you") 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker def w2(): 77*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you") 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 80*da0073e9SAndroid Build Coastguard Worker def fn(): 81*da0073e9SAndroid Build Coastguard Worker w1() 82*da0073e9SAndroid Build Coastguard Worker w2() 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker f = io.StringIO() 85*da0073e9SAndroid Build Coastguard Worker with redirect_stderr(f): 86*da0073e9SAndroid Build Coastguard Worker fn() 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count( 89*da0073e9SAndroid Build Coastguard Worker str="UserWarning: I am warning you", count=2, exactly=True 90*da0073e9SAndroid Build Coastguard Worker ).run(f.getvalue()) 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker def test_warn_once_per_func_in_loop(self): 93*da0073e9SAndroid Build Coastguard Worker def w1(): 94*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you") 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker def w2(): 97*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you") 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 100*da0073e9SAndroid Build Coastguard Worker def fn(): 101*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 102*da0073e9SAndroid Build Coastguard Worker w1() 103*da0073e9SAndroid Build Coastguard Worker w2() 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker f = io.StringIO() 106*da0073e9SAndroid Build Coastguard Worker with redirect_stderr(f): 107*da0073e9SAndroid Build Coastguard Worker fn() 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count( 110*da0073e9SAndroid Build Coastguard Worker str="UserWarning: I am warning you", count=2, exactly=True 111*da0073e9SAndroid Build Coastguard Worker ).run(f.getvalue()) 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker def test_warn_multiple_calls_multiple_warnings(self): 114*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 115*da0073e9SAndroid Build Coastguard Worker def fn(): 116*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you") 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker f = io.StringIO() 119*da0073e9SAndroid Build Coastguard Worker with redirect_stderr(f): 120*da0073e9SAndroid Build Coastguard Worker fn() 121*da0073e9SAndroid Build Coastguard Worker fn() 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count( 124*da0073e9SAndroid Build Coastguard Worker str="UserWarning: I am warning you", count=2, exactly=True 125*da0073e9SAndroid Build Coastguard Worker ).run(f.getvalue()) 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker def test_warn_multiple_calls_same_func_diff_stack(self): 128*da0073e9SAndroid Build Coastguard Worker def warn(caller: str): 129*da0073e9SAndroid Build Coastguard Worker warnings.warn("I am warning you from " + caller) 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 132*da0073e9SAndroid Build Coastguard Worker def foo(): 133*da0073e9SAndroid Build Coastguard Worker warn("foo") 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 136*da0073e9SAndroid Build Coastguard Worker def bar(): 137*da0073e9SAndroid Build Coastguard Worker warn("bar") 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker f = io.StringIO() 140*da0073e9SAndroid Build Coastguard Worker with redirect_stderr(f): 141*da0073e9SAndroid Build Coastguard Worker foo() 142*da0073e9SAndroid Build Coastguard Worker bar() 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count( 145*da0073e9SAndroid Build Coastguard Worker str="UserWarning: I am warning you from foo", count=1, exactly=True 146*da0073e9SAndroid Build Coastguard Worker ).check_count( 147*da0073e9SAndroid Build Coastguard Worker str="UserWarning: I am warning you from bar", count=1, exactly=True 148*da0073e9SAndroid Build Coastguard Worker ).run( 149*da0073e9SAndroid Build Coastguard Worker f.getvalue() 150*da0073e9SAndroid Build Coastguard Worker ) 151