# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
from executorch.exir import to_edge
from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.schema import Tensor, TensorList

from executorch.exir.verification.interpreter import Interpreter
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch._export.verifier import SpecViolationError
from torch.export import export


class WrapperModule(torch.nn.Module):
    """Wrap a plain callable in an ``nn.Module`` so it can be passed to ``export``.

    ``forward`` hands every positional and keyword argument to the wrapped
    callable unchanged and returns whatever it returns.
    """

    def __init__(self, fn):
        super().__init__()
        # The callable that ``forward`` delegates to.
        self.fn = fn

    def forward(self, *args, **kwargs):
        wrapped = self.fn
        return wrapped(*args, **kwargs)
class TestVerification(unittest.TestCase):
    def test_constant_buffer(self) -> None:
        """After ConstPropPass, exactly two tensors equal to torch.ones(2) are loaded."""

        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.ones(2) + x + torch.ones(2)

        # Generate program
        program = (
            to_edge(export(WrapperModule(f), (torch.randn(2),)))
            .transform(
                [
                    ConstPropPass(),
                ]
            )
            .to_executorch()
            ._emitter_output.program
        )

        test = Interpreter(program)
        for val_idx, evalue in enumerate(test.execution_plan.values):
            val = evalue.val
            # Skip tensor lists and tensors with data_buffer_idx == 0
            # (presumably the reserved "no constant data" slot -- TODO confirm
            # against the emitter); load every other value.
            if not (
                isinstance(val, Tensor) and val.data_buffer_idx == 0
            ) and not isinstance(val, TensorList):
                test.load_value(val_idx)

        vlist = test.get_value_list()
        tensors = [e for e in vlist if isinstance(e, torch.Tensor)]
        for e in tensors:
            self.assertTrue(torch.allclose(e, torch.ones(2)))

        # asserting only 2 constant Tensors exist in value list
        self.assertEqual(len(tensors), 2)

    def test_operator_list(self) -> None:
        """Interpreter.get_operators_list returns exactly the out-variant ops used."""

        class Op1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a * x  # mul
                    y = z - self.b  # sub
                return y

        class Op2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a % x  # remainder
                    y = z / self.b  # div
                    z = z + z  # add
                return y + z

        # Generate a program with Op1's operations (mul, sub)
        model1 = Op1()
        inputs = (torch.ones(2, 2),)
        program = (
            to_edge(export(model1, inputs)).to_executorch()._emitter_output.program
        )

        # Initialize and test Interpreter -- assert that the operators are same as above
        test = Interpreter(program)
        self.assertEqual(
            set(test.get_operators_list()),
            {torch.ops.aten.mul.out, torch.ops.aten.sub.out},
        )

        # Generate a program with Op2's operations (remainder, div, add)
        model2 = Op2()
        inputs = (torch.ones(2, 2),)
        program = (
            to_edge(export(model2, inputs)).to_executorch()._emitter_output.program
        )

        # Initialize and test Interpreter -- assert that the operators are same as above
        test = Interpreter(program)
        self.assertEqual(
            set(test.get_operators_list()),
            {
                torch.ops.aten.remainder.Tensor_out,
                torch.ops.aten.div.out,
                torch.ops.aten.add.out,
            },
        )

    def test_verification(self) -> None:
        """Lowering through export/to_edge/to_executorch and running succeeds."""

        class Op2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a % x  # remainder
                    y = z / self.b  # div
                    z = z + z  # add
                return y + z

        # Generate a program with Op2's operations (remainder, div, add)
        model2 = Op2()
        inputs = torch.ones(2, 2)
        exec_prog = to_edge(export(model2, (inputs,))).to_executorch()

        exported_prog = exec_prog.exported_program()
        # Execute the lowered program once; only successful execution matters.
        exported_prog.module()(inputs)
        # Verifiers are run internally in to_edge, export, and to_executorch.
        # If we make it this far then no errors were thrown in verification


class TestEdgeVerification(unittest.TestCase):
    def test_edge_happy(self) -> None:
        """A program lowered with to_edge passes EXIREdgeDialectVerifier."""

        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                b = self.a + x
                return torch._to_cpu([b, x])

        m = TestModel()
        egm = (
            to_edge(
                export(
                    m,
                    (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
                )
            )
            .exported_program()
            .graph_module
        )
        verifier = EXIREdgeDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_happy_with_optional_tensor_input(self) -> None:
        """Edge verification accepts ops that take optional tensor arguments."""

        class TestModel(torch.nn.Module):
            def forward(self, x, weight, bias):
                # weight and bias here are optional tensor inputs.
                return torch.group_norm(x, 4, weight, bias)

        m = TestModel()
        egm = (
            to_edge(
                export(
                    m,
                    (torch.rand(16, 8, 32, 32), torch.rand(8), torch.rand(8)),
                )
            )
            .exported_program()
            .graph_module
        )
        verifier = EXIREdgeDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_happy_with_empty_tensorlist_input(self) -> None:
        """Edge verification accepts a graph whose input is an empty list."""

        class TestModel(torch.nn.Module):
            def forward(self, x):
                return torch._to_cpu(x)

        m = TestModel()
        egm = (
            to_edge(
                export(
                    m,
                    ([],),
                )
            )
            .exported_program()
            .graph_module
        )
        verifier = EXIREdgeDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_sad(self) -> None:
        """A graph that was only exported (never lowered to edge) is rejected."""

        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("a", torch.randn(1, 3, 100, 100))

            def forward(self, x):
                b = self.a + x
                return torch._to_cpu([b, x])

        m = TestModel()
        # Deliberately skip to_edge: this graph is not in the edge dialect.
        gm = export(
            m,
            (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
        ).graph_module
        verifier = EXIREdgeDialectVerifier()
        with self.assertRaises(SpecViolationError):
            verifier(gm)

    def test_edge_happy_with_edge_ops(self) -> None:
        """A simple add lowered with to_edge passes EXIREdgeDialectVerifier."""

        class TestModel(torch.nn.Module):
            def forward(self, x):
                return x + x

        m = TestModel()
        egm = (
            to_edge(
                export(
                    m,
                    (torch.randn(1, 3, 100, 100).to(dtype=torch.int),),
                )
            )
            .exported_program()
            .graph_module
        )
        verifier = EXIREdgeDialectVerifier()
        verifier(egm)
        self.assertTrue(verifier.is_valid(egm))

    def test_edge_sad_with_edge_ops(self) -> None:
        """to_edge raises SpecViolationError for log_softmax on bfloat16 input."""
        # log_softmax only takes float or double Tensor
        m = torch.nn.LogSoftmax(dim=1)
        with self.assertRaises(SpecViolationError):
            _ = (
                to_edge(
                    export(
                        m,
                        (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),),
                    )
                )
                .exported_program()
                .graph_module
            )