1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport sys 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport contextlib 6*da0073e9SAndroid Build Coastguard Workerimport subprocess 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 11*da0073e9SAndroid Build Coastguard Workerdef _jit_disabled(): 12*da0073e9SAndroid Build Coastguard Worker cur_env = os.environ.get("PYTORCH_JIT", "1") 13*da0073e9SAndroid Build Coastguard Worker os.environ["PYTORCH_JIT"] = "0" 14*da0073e9SAndroid Build Coastguard Worker try: 15*da0073e9SAndroid Build Coastguard Worker yield 16*da0073e9SAndroid Build Coastguard Worker finally: 17*da0073e9SAndroid Build Coastguard Worker os.environ["PYTORCH_JIT"] = cur_env 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workerclass TestJitDisabled(TestCase): 21*da0073e9SAndroid Build Coastguard Worker """ 22*da0073e9SAndroid Build Coastguard Worker These tests are separate from the rest of the JIT tests because we need 23*da0073e9SAndroid Build Coastguard Worker run a new subprocess and `import torch` with the correct environment 24*da0073e9SAndroid Build Coastguard Worker variables set. 25*da0073e9SAndroid Build Coastguard Worker """ 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker def compare_enabled_disabled(self, src): 28*da0073e9SAndroid Build Coastguard Worker """ 29*da0073e9SAndroid Build Coastguard Worker Runs the script in `src` with PYTORCH_JIT enabled and disabled and 30*da0073e9SAndroid Build Coastguard Worker compares their stdout for equality. 31*da0073e9SAndroid Build Coastguard Worker """ 32*da0073e9SAndroid Build Coastguard Worker # Write `src` out to a temporary so our source inspection logic works 33*da0073e9SAndroid Build Coastguard Worker # correctly. 34*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 35*da0073e9SAndroid Build Coastguard Worker with open(fname, 'w') as f: 36*da0073e9SAndroid Build Coastguard Worker f.write(src) 37*da0073e9SAndroid Build Coastguard Worker with _jit_disabled(): 38*da0073e9SAndroid Build Coastguard Worker out_disabled = subprocess.check_output([ 39*da0073e9SAndroid Build Coastguard Worker sys.executable, 40*da0073e9SAndroid Build Coastguard Worker fname]) 41*da0073e9SAndroid Build Coastguard Worker out_enabled = subprocess.check_output([ 42*da0073e9SAndroid Build Coastguard Worker sys.executable, 43*da0073e9SAndroid Build Coastguard Worker fname]) 44*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_disabled, out_enabled) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker def test_attribute(self): 47*da0073e9SAndroid Build Coastguard Worker _program_string = """ 48*da0073e9SAndroid Build Coastguard Workerimport torch 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Workerclass Foo(torch.jit.ScriptModule): 51*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 52*da0073e9SAndroid Build Coastguard Worker super().__init__() 53*da0073e9SAndroid Build Coastguard Worker self.x = torch.jit.Attribute(x, torch.Tensor) 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 56*da0073e9SAndroid Build Coastguard Worker return input 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Workers = Foo(torch.ones(2, 3)) 59*da0073e9SAndroid Build Coastguard Workerprint(s.x) 60*da0073e9SAndroid Build Coastguard Worker""" 61*da0073e9SAndroid Build Coastguard Worker self.compare_enabled_disabled(_program_string) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker def test_script_module_construction(self): 64*da0073e9SAndroid Build Coastguard Worker _program_string = """ 65*da0073e9SAndroid Build Coastguard Workerimport torch 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Workerclass AModule(torch.jit.ScriptModule): 68*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 69*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 70*da0073e9SAndroid Build Coastguard Worker pass 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard WorkerAModule() 73*da0073e9SAndroid Build Coastguard Workerprint("Didn't throw exception") 74*da0073e9SAndroid Build Coastguard Worker""" 75*da0073e9SAndroid Build Coastguard Worker self.compare_enabled_disabled(_program_string) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker def test_recursive_script(self): 78*da0073e9SAndroid Build Coastguard Worker _program_string = """ 79*da0073e9SAndroid Build Coastguard Workerimport torch 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Workerclass AModule(torch.nn.Module): 82*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 83*da0073e9SAndroid Build Coastguard Worker pass 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Workersm = torch.jit.script(AModule()) 86*da0073e9SAndroid Build Coastguard Workerprint("Didn't throw exception") 87*da0073e9SAndroid Build Coastguard Worker""" 88*da0073e9SAndroid Build Coastguard Worker self.compare_enabled_disabled(_program_string) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 91*da0073e9SAndroid Build Coastguard Worker run_tests() 92