# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest
from contextlib import contextmanager
from typing import Any, Iterator

import torch
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.exir.dialects._ops import ops
from torch._export.verifier import SpecViolationError
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.export import export

from ..verifier import EXIREdgeDialectVerifier


class TestEdgeDialectVerifier(unittest.TestCase):
    """Tests for ``EXIREdgeDialectVerifier`` under various ``EdgeCompileConfig``s."""

    @contextmanager
    def assertNotRaises(self, exc_type: Any) -> Iterator[None]:
        """Fail the test if ``exc_type`` is raised inside the ``with`` body.

        Args:
            exc_type: The exception class that must not escape the body.

        Raises:
            self.failureException: If ``exc_type`` is raised; the original
                exception is chained as the cause so its traceback is kept.
        """
        try:
            yield None
        except exc_type as err:
            raise self.failureException(
                "{} raised".format(exc_type.__name__)
            ) from err

    def test_edge_verifier_check_valid_op_succeed_given_custom_op(self) -> None:
        """A custom edge op passes both ``check_valid_edge_op`` and ``check_valid_op``."""
        edge_op = ops.edge.quantized_decomposed.quantize_per_tensor.default
        verifier = EXIREdgeDialectVerifier()
        with self.assertNotRaises(SpecViolationError):
            verifier.check_valid_edge_op(edge_op)
            verifier.check_valid_op(edge_op)

    def test_edge_verifier_enablement(self) -> None:
        """``_check_ir_validity=False`` lets an otherwise-invalid graph verify."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                z = y.item()
                torch._check(z > 0)
                torch._check(z < 4)
                return x[z : z + y.shape[0]]

        ep = export(M(), (torch.randn(10), torch.tensor([3])))

        compile_config_with_disable_ir_validity = EdgeCompileConfig(
            _check_ir_validity=False
        )
        edge_manager = to_edge(
            ep, compile_config=compile_config_with_disable_ir_validity
        )

        normal_verifier = EXIREdgeDialectVerifier()
        disable_ir_validity_verifier = EXIREdgeDialectVerifier(
            compile_config_with_disable_ir_validity
        )

        # The exported model cannot pass the normal verifier because
        # aten.sym_constrain_range.default is not a legal edge op.
        with self.assertRaises(SpecViolationError):
            normal_verifier(edge_manager.exported_program())

        # The exported model passes disable_ir_validity_verifier because the
        # verifier is disabled by compile_config_with_disable_ir_validity
        # (_check_ir_validity=False). Note that this verification has already
        # been done when calling `to_edge`; the verifier is called explicitly
        # here only for demonstration and is unnecessary in real-world IR
        # verification.
        disable_ir_validity_verifier(edge_manager.exported_program())

    def test_edge_verifier_check_edge_op(self) -> None:
        """``_use_edge_ops=False`` disables the edge-op check that otherwise fails."""

        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.transpose(0, 1)

        m = Model().eval()

        example_input = (torch.zeros([2, 2]),)

        export_model = export(m, example_input)

        compile_config_without_edge_op = EdgeCompileConfig(
            _use_edge_ops=False, _skip_dim_order=False
        )

        edge_manager = to_edge(
            export_model, compile_config=compile_config_without_edge_op
        )

        normal_verifier = EXIREdgeDialectVerifier()
        disable_edge_op_check_verifier = EXIREdgeDialectVerifier(
            compile_config_without_edge_op
        )

        # The exported model cannot pass the normal verifier because tensors
        # with a non-contiguous memory layout are not supported in ET.
        with self.assertRaises(SpecViolationError):
            normal_verifier(edge_manager.exported_program())

        # The exported model passes disable_edge_op_check_verifier because the
        # non-contiguous memory layout check is disabled by
        # compile_config_without_edge_op (_use_edge_ops=False). Note that this
        # verification has already been done when calling `to_edge`; the
        # verifier is called explicitly here only for demonstration and is
        # unnecessary in real-world IR verification.
        disable_edge_op_check_verifier(edge_manager.exported_program())

    def test_edge_verifier_check_valid_dim_order_graph(self) -> None:
        """Each verifier accepts only graphs lowered with its own dim-order setting."""

        class Model(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                t1 = x.to(dtype=torch.double, memory_format=torch.channels_last)
                t2 = t1 + t1
                return t1 * t2

        m = Model().eval()

        example_input = (
            torch.rand_like(
                torch.zeros([2, 2, 2, 2]),
                dtype=torch.float32,
                memory_format=torch.contiguous_format,
            ),
        )

        export_model = export(m, example_input)

        compile_config_with_dim_order = EdgeCompileConfig(_skip_dim_order=False)
        compile_config_with_stride = EdgeCompileConfig(_skip_dim_order=True)

        dim_order_edge_model = to_edge(
            export_model, compile_config=compile_config_with_dim_order
        )
        stride_edge_model = to_edge(
            export_model, compile_config=compile_config_with_stride
        )

        dim_order_verifier = EXIREdgeDialectVerifier(
            edge_compile_config=compile_config_with_dim_order
        )
        stride_verifier = EXIREdgeDialectVerifier(
            edge_compile_config=compile_config_with_stride
        )

        # Matching config: each verifier accepts the graph lowered with it.
        dim_order_verifier(dim_order_edge_model.exported_program())
        stride_verifier(stride_edge_model.exported_program())

        # Mismatched config: each verifier rejects the other lowering.
        with self.assertRaises(SpecViolationError):
            dim_order_verifier(stride_edge_model.exported_program())
        with self.assertRaises(SpecViolationError):
            stride_verifier(dim_order_edge_model.exported_program())