# xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/models/mobilenet_v2.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from torchvision import models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights


class TestMobileNetV2(unittest.TestCase):
    """Tests lowering torchvision's MobileNetV2 to the XNNPACK delegate.

    Covers the fp32 flow and the qs8 (quantized) flow with dynamic input
    height/width, checking that the whole model is captured by a single
    delegate call and that the serialized program's outputs match eager
    mode at runtime.
    """

    # Pretrained MobileNetV2 in eval mode. Pass the explicit DEFAULT weight
    # member: torchvision's `weights=` parameter expects a WeightsEnum
    # *member*, not the bare enum class the original code passed.
    mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
    mv2 = mv2.eval()
    # One NCHW image at the model's nominal 224x224 input resolution.
    model_inputs = (torch.randn(1, 3, 224, 224),)

    # Edge-dialect ops produced by the fp32 model; after lowering, none of
    # these may remain in the top-level graph.
    all_operators = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        "executorch_exir_dialects_edge__ops_aten_mean_dim",
        "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
    }

    @staticmethod
    def _dynamic_shapes():
        """Return a dynamic-shape spec letting H and W range over [224, 455].

        Built fresh per call so each test gets its own torch.export.Dim
        objects. Keys 2 and 3 are the H and W dims of the NCHW input.
        """
        return (
            {
                2: torch.export.Dim("height", min=224, max=455),
                3: torch.export.Dim("width", min=224, max=455),
            },
        )

    @classmethod
    def _ops_after_quantization(cls):
        # Quantization fuses away batchnorm, so it is no longer in the graph.
        return cls.all_operators - {
            "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        }

    def _lower_and_run(self, tester, absent_ops):
        """Export, lower, serialize, and execute `tester`.

        Asserts the graph is handled by a delegate call, that none of
        `absent_ops` survive in the top-level graph, and that runtime
        outputs match the reference over 10 runs.
        """
        (
            tester.export()
            .to_edge_transform_and_lower()
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .check_not(list(absent_ops))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(num_runs=10)
        )

    def test_fp32_mv2(self):
        tester = Tester(
            self.mv2, self.model_inputs, dynamic_shapes=self._dynamic_shapes()
        )
        self._lower_and_run(tester, self.all_operators)

    @unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
    def _test_qs8_mv2(self):
        tester = Tester(
            self.mv2, self.model_inputs, dynamic_shapes=self._dynamic_shapes()
        ).quantize()
        self._lower_and_run(tester, self._ops_after_quantization())

    # TODO: Delete and only used calibrated test after T187799178
    def test_qs8_mv2_no_calibration(self):
        tester = Tester(
            self.mv2, self.model_inputs, dynamic_shapes=self._dynamic_shapes()
        ).quantize(Quantize(calibrate=False))
        self._lower_and_run(tester, self._ops_after_quantization())
101