xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/models/inception_v4.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11from timm.models import inception_v4
12
13
class TestInceptionV4(unittest.TestCase):
    """XNNPACK lowering tests for timm's Inception-v4.

    Each test runs the model through the ``Tester`` pipeline, in this order:

    1. export;
    2. ``to_edge_transform_and_lower``;
    3. check that an ``executorch_call_delegate`` node was emitted;
    4. check that none of the operator names in ``all_operators`` remain in
       the lowered graph;
    5. convert to an ExecuTorch program and serialize it;
    6. call ``run_method_and_compare_outputs``.
    """

    # Edge-dialect operator names that must not appear in the lowered graph,
    # i.e. operators expected to be taken by the XNNPACK delegate.
    all_operators = {
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        # TODO: avg_pool2d is currently not partitioned, so it is deliberately
        # left out of this set. Re-enable once it is delegated:
        # "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
        "executorch_exir_dialects_edge__ops_aten_cat_default",
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default",
        "executorch_exir_dialects_edge__ops_aten_mean_dim",
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
    }

    # Populated by setUpClass (not in the class body) so that merely importing
    # or collecting this module does not instantiate the full network.
    ic4: torch.nn.Module
    model_inputs: tuple

    @classmethod
    def setUpClass(cls) -> None:
        """Build the randomly initialised model and its example input once per class."""
        super().setUpClass()
        cls.ic4 = inception_v4(pretrained=False).eval()
        # A single example input of shape (1, 3, 299, 299).
        cls.model_inputs = (torch.randn(1, 3, 299, 299),)

    def test_fp32_ic4(self):
        """Lower the unquantized model and verify the delegated program."""
        (
            Tester(self.ic4, self.model_inputs)
            .export()
            .to_edge_transform_and_lower()
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .check_not(list(self.all_operators))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_ic4(self):
        """Quantize via ``Tester.quantize()`` (default config), lower, and verify."""
        # Quantization fuses away batchnorm, so it is no longer in the graph
        # and must not be part of the "must be gone" check.
        ops_after_quantization = self.all_operators - {
            "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        }

        (
            Tester(self.ic4, self.model_inputs)
            .quantize()
            .export()
            .to_edge_transform_and_lower()
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .check_not(list(ops_after_quantization))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )
60