# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
from executorch.exir import to_edge
from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.schema import Tensor, TensorList

from executorch.exir.verification.interpreter import Interpreter
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch._export.verifier import SpecViolationError
from torch.export import export


class WrapperModule(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class TestVerification(unittest.TestCase):
    def test_constant_buffer(self) -> None:
        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.ones(2) + x + torch.ones(2)

        # Generate program
        program = (
            to_edge(export(WrapperModule(f), (torch.randn(2),)))
            .transform(
                [
                    ConstPropPass(),
                ]
            )
            .to_executorch()
            ._emitter_output.program
        )

        test = Interpreter(program)
        for val_idx in range(len(test.execution_plan.values)):
            val = test.execution_plan.values[val_idx].val
            if not (
                isinstance(val, Tensor) and val.data_buffer_idx == 0
            ) and not isinstance(val, TensorList):
                test.load_value(val_idx)
        vlist = test.get_value_list()
        for e in vlist:
            if isinstance(e, torch.Tensor):
                self.assertTrue(torch.allclose(e, torch.ones(2)))

        # Assert that only 2 constant Tensors exist in the value list
        self.assertEqual(len([e for e in vlist if isinstance(e, torch.Tensor)]), 2)

    def test_operator_list(self) -> None:
        class Op1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a * x  # mul
                    y = z - self.b  # sub
                return y

        class Op2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a % x  # remainder
                    y = z / self.b  # div
                    z = z + z  # add
                return y + z

        # Generate a program with Op1's operations (mul, sub)
        model1 = Op1()
        inputs = (torch.ones(2, 2),)
        program = (
            to_edge(export(model1, inputs)).to_executorch()._emitter_output.program
        )

        # Initialize the Interpreter and assert that its operator list matches the ops above
        test = Interpreter(program)
        self.assertEqual(
            set(test.get_operators_list()),
            {torch.ops.aten.mul.out, torch.ops.aten.sub.out},
        )

        # Generate a program with Op2's operations (remainder, div, add)
        model2 = Op2()
        inputs = (torch.ones(2, 2),)
        program = (
            to_edge(export(model2, inputs)).to_executorch()._emitter_output.program
        )

        # Initialize the Interpreter and assert that its operator list matches the ops above
        test = Interpreter(program)
        self.assertEqual(
            set(test.get_operators_list()),
            {
                torch.ops.aten.remainder.Tensor_out,
                torch.ops.aten.div.out,
                torch.ops.aten.add.out,
            },
        )

    def test_verification(self) -> None:
        class Op2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a % x  # remainder
                    y = z / self.b  # div
                    z = z + z  # add
                return y + z

        # Generate a program with Op2's operations (remainder, div, add)
        model2 = Op2()
        inputs = torch.ones(2, 2)
        exec_prog = to_edge(export(model2, (inputs,))).to_executorch()

        exported_prog = exec_prog.exported_program()
        res = exported_prog.module()(inputs)[0]  # noqa
        # Verifiers are run internally in export, to_edge, and to_executorch.
        # If we make it this far, no errors were raised during verification.


class TestEdgeVerification(unittest.TestCase):
    def test_edge_happy(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                b = self.a + x
                return torch._to_cpu([b, x])

        m = TestModel()
        egm = (
            to_edge(
                export(
                    m,
                    (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
                )
            )
            .exported_program()
            .graph_module
        )
        verifier = EXIREdgeDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_happy_with_optional_tensor_input(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight, bias):
                # weight and bias here are optional tensor inputs.
                return torch.group_norm(x, 4, weight, bias)

        m = TestModel()
        egm = (
            to_edge(
                export(
                    m,
                    (torch.rand(16, 8, 32, 32), torch.rand(8), torch.rand(8)),
                )
            )
            .exported_program()
            .graph_module
        )
        verifier = EXIREdgeDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_happy_with_empty_tensorlist_input(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch._to_cpu(x)

        m = TestModel()
        egm = (
            to_edge(
                export(
                    m,
                    ([],),
                )
            )
            .exported_program()
            .graph_module
        )
        verifier = EXIREdgeDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_sad(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                b = self.a + x
                return torch._to_cpu([b, x])

        m = TestModel()
        # Export without to_edge: the graph still contains core ATen ops,
        # so the edge dialect verifier should reject it.
        egm = export(
            m,
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        ).graph_module
        verifier = EXIREdgeDialectVerifier()
        with self.assertRaises(SpecViolationError):
            verifier(egm)

    def test_edge_happy_with_edge_ops(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x + x

        m = TestModel()
        egm = (
            to_edge(
                export(
                    m,
                    (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
                )
            )
            .exported_program()
            .graph_module
        )
        verifier = EXIREdgeDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_sad_with_edge_ops(self) -> None:
        # log_softmax only takes float or double Tensors
        m = torch.nn.LogSoftmax(dim=1)
        with self.assertRaises(SpecViolationError):
            _ = (
                to_edge(
                    export(
                        m,
                        (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),),
                    )
                )
                .exported_program()
                .graph_module
            )
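

# Optional entry point for running this file directly with `python`; assumes
# the executorch Python package is importable in the current environment.
# Test runners such as pytest discover the TestCase classes without it.
if __name__ == "__main__":
    unittest.main()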