1# Owner(s): ["oncall: export"] 2import unittest 3 4import torch 5from functorch.experimental import control_flow 6from torch import Tensor 7from torch._dynamo.eval_frame import is_dynamo_supported 8from torch._export.verifier import SpecViolationError, Verifier 9from torch.export import export 10from torch.export.exported_program import InputKind, InputSpec, TensorArgument 11from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase 12 13 14@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported") 15class TestVerifier(TestCase): 16 def test_verifier_basic(self) -> None: 17 class Foo(torch.nn.Module): 18 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 19 return x + y 20 21 f = Foo() 22 23 ep = export(f, (torch.randn(100), torch.randn(100))) 24 25 verifier = Verifier() 26 verifier.check(ep) 27 28 def test_verifier_call_module(self) -> None: 29 class M(torch.nn.Module): 30 def __init__(self) -> None: 31 super().__init__() 32 self.linear = torch.nn.Linear(10, 10) 33 34 def forward(self, x: Tensor) -> Tensor: 35 return self.linear(x) 36 37 gm = torch.fx.symbolic_trace(M()) 38 39 verifier = Verifier() 40 with self.assertRaises(SpecViolationError): 41 verifier._check_graph_module(gm) 42 43 def test_verifier_no_functional(self) -> None: 44 class Foo(torch.nn.Module): 45 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 46 return x + y 47 48 f = Foo() 49 50 ep = export(f, (torch.randn(100), torch.randn(100))) 51 for node in ep.graph.nodes: 52 if node.target == torch.ops.aten.add.Tensor: 53 node.target = torch.ops.aten.add_.Tensor 54 55 verifier = Verifier() 56 with self.assertRaises(SpecViolationError): 57 verifier.check(ep) 58 59 @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") 60 def test_verifier_higher_order(self) -> None: 61 class Foo(torch.nn.Module): 62 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 63 def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 64 return x + y 65 66 def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 67 return x - y 68 69 return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y]) 70 71 f = Foo() 72 73 ep = export(f, (torch.randn(3, 3), torch.randn(3, 3))) 74 75 verifier = Verifier() 76 verifier.check(ep) 77 78 @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") 79 def test_verifier_nested_invalid_module(self) -> None: 80 class Foo(torch.nn.Module): 81 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 82 def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 83 return x + y 84 85 def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 86 return x - y 87 88 return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y]) 89 90 f = Foo() 91 92 ep = export(f, (torch.randn(3, 3), torch.randn(3, 3))) 93 for node in ep.graph_module.true_graph_0.graph.nodes: 94 if node.target == torch.ops.aten.add.Tensor: 95 node.target = torch.ops.aten.add_.Tensor 96 97 verifier = Verifier() 98 with self.assertRaises(SpecViolationError): 99 verifier.check(ep) 100 101 def test_ep_verifier_basic(self) -> None: 102 class M(torch.nn.Module): 103 def __init__(self) -> None: 104 super().__init__() 105 self.linear = torch.nn.Linear(10, 10) 106 107 def forward(self, x: Tensor) -> Tensor: 108 return self.linear(x) 109 110 ep = export(M(), (torch.randn(10, 10),)) 111 ep.validate() 112 113 def test_ep_verifier_invalid_param(self) -> None: 114 class M(torch.nn.Module): 115 def __init__(self) -> None: 116 super().__init__() 117 self.register_parameter( 118 name="a", param=torch.nn.Parameter(torch.randn(100)) 119 ) 120 121 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 122 return x + y + self.a 123 124 ep = export(M(), (torch.randn(100), torch.randn(100))) 125 126 # Parameter doesn't exist in the state dict 127 ep.graph_signature.input_specs[0] = InputSpec( 128 kind=InputKind.PARAMETER, arg=TensorArgument(name="p_a"), target="bad_param" 129 ) 130 with self.assertRaisesRegex(SpecViolationError, "not in the state dict"): 131 ep.validate() 132 133 # Add non-torch.nn.Parameter parameter to the state dict 134 ep.state_dict["bad_param"] = torch.randn(100) 135 with self.assertRaisesRegex( 136 SpecViolationError, "not an instance of torch.nn.Parameter" 137 ): 138 ep.validate() 139 140 def test_ep_verifier_invalid_buffer(self) -> None: 141 class M(torch.nn.Module): 142 def __init__(self) -> None: 143 super().__init__() 144 self.a = torch.tensor(3.0) 145 146 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 147 return x + y + self.a 148 149 ep = export(M(), (torch.randn(100), torch.randn(100))) 150 151 # Buffer doesn't exist in the state dict 152 ep.graph_signature.input_specs[0] = InputSpec( 153 kind=InputKind.BUFFER, 154 arg=TensorArgument(name="c_a"), 155 target="bad_buffer", 156 persistent=True, 157 ) 158 with self.assertRaisesRegex(SpecViolationError, "not in the state dict"): 159 ep.validate() 160 161 def test_ep_verifier_buffer_mutate(self) -> None: 162 class M(torch.nn.Module): 163 def __init__(self) -> None: 164 super().__init__() 165 166 self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) 167 168 self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0)) 169 self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0)) 170 171 def forward(self, x1, x2): 172 # Use the parameter, buffers, and both inputs in the forward method 173 output = ( 174 x1 + self.my_parameter 175 ) * self.my_buffer1 + x2 * self.my_buffer2 176 177 # Mutate one of the buffers (e.g., increment it by 1) 178 self.my_buffer2.add_(1.0) 179 return output 180 181 ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0))) 182 ep.validate() 183 184 def test_ep_verifier_invalid_output(self) -> None: 185 class M(torch.nn.Module): 186 def __init__(self) -> None: 187 super().__init__() 188 189 self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) 190 191 self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0)) 192 self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0)) 193 194 def forward(self, x1, x2): 195 # Use the parameter, buffers, and both inputs in the forward method 196 output = ( 197 x1 + self.my_parameter 198 ) * self.my_buffer1 + x2 * self.my_buffer2 199 200 # Mutate one of the buffers (e.g., increment it by 1) 201 self.my_buffer2.add_(1.0) 202 return output 203 204 ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0))) 205 206 output_node = list(ep.graph.nodes)[-1] 207 output_node.args = ( 208 ( 209 output_node.args[0][0], 210 next(iter(ep.graph.nodes)), 211 output_node.args[0][1], 212 ), 213 ) 214 215 with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"): 216 ep.validate() 217 218 219if __name__ == "__main__": 220 run_tests() 221