1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport inspect 4*da0073e9SAndroid Build Coastguard Workerimport io 5*da0073e9SAndroid Build Coastguard Workerimport os 6*da0073e9SAndroid Build Coastguard Workerimport tempfile 7*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.test_case import run_tests, TestCase 11*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import CompileCounter 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerclass ToyModel(torch.nn.Module): 15*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 16*da0073e9SAndroid Build Coastguard Worker super().__init__() 17*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(10, 10) 18*da0073e9SAndroid Build Coastguard Worker self.relu = torch.nn.ReLU() 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 21*da0073e9SAndroid Build Coastguard Worker return self.relu(self.linear(x)) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workerclass InPlaceCompilationTests(TestCase): 25*da0073e9SAndroid Build Coastguard Worker def test_compilation(self): 26*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 27*da0073e9SAndroid Build Coastguard Worker model = ToyModel() 28*da0073e9SAndroid Build Coastguard Worker cnt = CompileCounter() 29*da0073e9SAndroid Build Coastguard Worker model.compile(backend=cnt) 30*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 31*da0073e9SAndroid Build Coastguard Worker model(x) 32*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker def test_overwrite_call_impl(self): 35*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 36*da0073e9SAndroid Build Coastguard Worker model = ToyModel() 37*da0073e9SAndroid Build Coastguard Worker self.assertTrue(model._compiled_call_impl is None) 38*da0073e9SAndroid Build Coastguard Worker model.compile() 39*da0073e9SAndroid Build Coastguard Worker self.assertTrue(model._compiled_call_impl is not None) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def test_save(self): 42*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 43*da0073e9SAndroid Build Coastguard Worker model = ToyModel() 44*da0073e9SAndroid Build Coastguard Worker model.compile() 45*da0073e9SAndroid Build Coastguard Worker model(torch.randn(1, 10)) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryDirectory() as tmpdirname: 48*da0073e9SAndroid Build Coastguard Worker torch.save(model, os.path.join(tmpdirname, "model.pt")) 49*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) 50*da0073e9SAndroid Build Coastguard Worker loaded_model(torch.randn(1, 10)) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker def test_state_dict_save(self): 53*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 54*da0073e9SAndroid Build Coastguard Worker model = ToyModel() 55*da0073e9SAndroid Build Coastguard Worker model.compile() 56*da0073e9SAndroid Build Coastguard Worker model(torch.randn(1, 10)) 57*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryDirectory() as tmpdirname: 58*da0073e9SAndroid Build Coastguard Worker torch.save(model.state_dict(), os.path.join(tmpdirname, "model.pt")) 59*da0073e9SAndroid Build Coastguard Worker loaded_model = ToyModel() 60*da0073e9SAndroid Build Coastguard Worker loaded_model.load_state_dict( 61*da0073e9SAndroid Build Coastguard Worker torch.load(os.path.join(tmpdirname, "model.pt")) 62*da0073e9SAndroid Build Coastguard Worker ) 63*da0073e9SAndroid Build Coastguard Worker loaded_model(torch.randn(1, 10)) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def test_jit_save(self): 66*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 67*da0073e9SAndroid Build Coastguard Worker model = ToyModel() 68*da0073e9SAndroid Build Coastguard Worker model.compile() 69*da0073e9SAndroid Build Coastguard Worker model(torch.randn(1, 10)) 70*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(model) 71*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryDirectory() as tmpdirname: 72*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt")) 73*da0073e9SAndroid Build Coastguard Worker loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt")) 74*da0073e9SAndroid Build Coastguard Worker loaded_model(torch.randn(1, 10)) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker def test_compilation_callback(self): 77*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.on_compile_start 80*da0073e9SAndroid Build Coastguard Worker def start_callback(): 81*da0073e9SAndroid Build Coastguard Worker print("Compilation started.") 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.on_compile_end 84*da0073e9SAndroid Build Coastguard Worker def end_callback(): 85*da0073e9SAndroid Build Coastguard Worker print("Compilation ended.") 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker mod = ToyModel() 88*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: 91*da0073e9SAndroid Build Coastguard Worker opt_mod = torch.compile(backend="eager", fullgraph=True)(mod) 92*da0073e9SAndroid Build Coastguard Worker opt_mod(x) 93*da0073e9SAndroid Build Coastguard Worker printed_output = mock_stdout.getvalue().strip() 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker self.assertEqual(printed_output, "Compilation started.\nCompilation ended.") 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker def test_compile_eager_options(self): 98*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", options={"foo": 2}) 99*da0073e9SAndroid Build Coastguard Worker def f(x): 100*da0073e9SAndroid Build Coastguard Worker return x + x 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker f(torch.randn(3)) 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="aot_eager", options={"foo": 2}) 105*da0073e9SAndroid Build Coastguard Worker def g(x): 106*da0073e9SAndroid Build Coastguard Worker return x + x 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker g(torch.randn(3)) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker def test_compilation_callback_with_graph_break(self): 111*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 112*da0073e9SAndroid Build Coastguard Worker counter = 0 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.on_compile_start 115*da0073e9SAndroid Build Coastguard Worker def start_callback(): 116*da0073e9SAndroid Build Coastguard Worker nonlocal counter 117*da0073e9SAndroid Build Coastguard Worker counter += 1 118*da0073e9SAndroid Build Coastguard Worker print(f"Counter = {counter}") 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.on_compile_end 121*da0073e9SAndroid Build Coastguard Worker def end_callback(): 122*da0073e9SAndroid Build Coastguard Worker nonlocal counter 123*da0073e9SAndroid Build Coastguard Worker counter += 1 124*da0073e9SAndroid Build Coastguard Worker print(f"Counter = {counter}") 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 127*da0073e9SAndroid Build Coastguard Worker def fn(x): 128*da0073e9SAndroid Build Coastguard Worker x = x + 1 129*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 130*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: 135*da0073e9SAndroid Build Coastguard Worker fn(x) 136*da0073e9SAndroid Build Coastguard Worker printed_output = mock_stdout.getvalue().strip() 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 139*da0073e9SAndroid Build Coastguard Worker printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4" 140*da0073e9SAndroid Build Coastguard Worker ) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker# The private variants of the below functions are extensively tested 144*da0073e9SAndroid Build Coastguard Worker# So as long as the signatures match we're good 145*da0073e9SAndroid Build Coastguard Workerclass PublicTorchCompilerTests(TestCase): 146*da0073e9SAndroid Build Coastguard Worker def check_signature(self, public_fn_name, private_fn_name, private_namespace): 147*da0073e9SAndroid Build Coastguard Worker public_fn = getattr(torch.compiler, public_fn_name) 148*da0073e9SAndroid Build Coastguard Worker private_fn = getattr(private_namespace, private_fn_name) 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker public_sig = inspect.signature(public_fn) 151*da0073e9SAndroid Build Coastguard Worker private_sig = inspect.signature(private_fn) 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 154*da0073e9SAndroid Build Coastguard Worker public_sig, 155*da0073e9SAndroid Build Coastguard Worker private_sig, 156*da0073e9SAndroid Build Coastguard Worker f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}", 157*da0073e9SAndroid Build Coastguard Worker ) 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker def test_dynamo_signatures(self): 160*da0073e9SAndroid Build Coastguard Worker function_names = [ 161*da0073e9SAndroid Build Coastguard Worker "reset", 162*da0073e9SAndroid Build Coastguard Worker "allow_in_graph", 163*da0073e9SAndroid Build Coastguard Worker "list_backends", 164*da0073e9SAndroid Build Coastguard Worker "assume_constant_result", 165*da0073e9SAndroid Build Coastguard Worker "disable", 166*da0073e9SAndroid Build Coastguard Worker ] 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker for fn_name in function_names: 169*da0073e9SAndroid Build Coastguard Worker self.check_signature(fn_name, fn_name, torch._dynamo) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 173*da0073e9SAndroid Build Coastguard Worker run_tests() 174