# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from timm.models import inception_v4


class TestInceptionV4(unittest.TestCase):
    """Lowering tests for timm's InceptionV4 using the XNNPACK test ``Tester``.

    Each test exports the model, lowers it, verifies that a delegate call is
    present and that the listed edge operators are absent from the lowered
    graph, then serializes the program and compares its outputs.
    """

    # Untrained model in eval mode, shared by every test in this class.
    ic4 = inception_v4(pretrained=False).eval()
    # One random input: a (3, 299, 299) tensor with a leading dim of size 1.
    model_inputs = (torch.randn(3, 299, 299).unsqueeze(0),)

    # Edge-dialect operator names that must not appear after lowering.
    # avg_pool2d ("executorch.exir.dialects.edge._ops.aten.avg_pool2d.default")
    # is intentionally left out: it is currently not partitioned.
    all_operators = {
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        "executorch_exir_dialects_edge__ops_aten_cat_default",
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default",
        "executorch_exir_dialects_edge__ops_aten_mean_dim",
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
    }

    def _lower_and_validate(self, pipeline, absent_ops):
        """Run the stages both tests share, starting from ``export()``.

        Args:
            pipeline: A ``Tester`` that has not been exported yet (it may
                already have been quantized).
            absent_ops: Operator names passed to ``check_not`` after lowering.
        """
        delegate_call = "torch.ops.higher_order.executorch_call_delegate"

        pipeline = pipeline.export()
        pipeline = pipeline.to_edge_transform_and_lower()
        pipeline = pipeline.check([delegate_call])
        pipeline = pipeline.check_not([*absent_ops])
        pipeline = pipeline.to_executorch()
        pipeline = pipeline.serialize()
        pipeline.run_method_and_compare_outputs()

    def test_fp32_ic4(self):
        """Lower the unquantized model; none of ``all_operators`` may remain."""
        pipeline = Tester(self.ic4, self.model_inputs)
        self._lower_and_validate(pipeline, self.all_operators)

    def test_qs8_ic4(self):
        """Quantize first, then lower and run the same checks."""
        # Quantization fuses away batchnorm, so it is no longer in the graph
        # and is dropped from the set of operators to check for.
        fused_away = {
            "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        }
        ops_after_quantization = self.all_operators.difference(fused_away)

        pipeline = Tester(self.ic4, self.model_inputs).quantize()
        self._lower_and_validate(pipeline, ops_after_quantization)