xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/models/mobilenet_v3.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11from executorch.backends.xnnpack.test.tester.tester import Quantize
12from torchvision import models
13
14
class TestMobileNetV3(unittest.TestCase):
    """XNNPACK lowering tests for torchvision's MobileNetV3-Small.

    Each test builds a ``Tester`` pipeline that exports the model with dynamic
    height/width, lowers it via ``to_edge_transform_and_lower``, checks that a
    delegate call is present and that none of ``all_operators`` remain in the
    graph, then serializes the program and compares outputs over 5 runs.
    """

    # NOTE: this runs when the class body executes (import/collection time),
    # so the pretrained weights are loaded once for all tests.
    # IMAGENET1K_V1 is the checkpoint the deprecated ``pretrained=True`` flag
    # selected; naming it explicitly avoids torchvision's deprecation warning.
    mv3 = models.mobilenetv3.mobilenet_v3_small(
        weights=models.mobilenetv3.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    )
    mv3 = mv3.eval()

    # Example input for export: a single 3-channel 224x224 tensor.
    model_inputs = (torch.randn(1, 3, 224, 224),)

    # Dynamic-shape spec for the single input: dims 2 (height) and 3 (width)
    # may each range over [224, 455].
    dynamic_shapes = (
        {
            2: torch.export.Dim("height", min=224, max=455),
            3: torch.export.Dim("width", min=224, max=455),
        },
    )

    # Edge-dialect operators that must NOT appear in the lowered graph
    # (passed to ``check_not``); the tests expect all of them to be absorbed
    # into the XNNPACK delegate call.
    all_operators = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        "executorch_exir_dialects_edge__ops_aten_clamp_default",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
        "executorch_exir_dialects_edge__ops_aten_div_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mean_dim",
    }

    def test_fp32_mv3(self):
        """Lower the FP32 model to XNNPACK and compare runtime outputs."""
        (
            Tester(self.mv3, self.model_inputs, dynamic_shapes=self.dynamic_shapes)
            .export()
            .to_edge_transform_and_lower()
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .check_not(list(self.all_operators))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(num_runs=5)
        )

    @unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
    def _test_qs8_mv3(self):
        """QS8 variant using the default ``quantize()`` stage (calibrated).

        Disabled pending T187799178. The leading underscore additionally keeps
        unittest's default loader from collecting it.
        """
        (
            Tester(self.mv3, self.model_inputs, dynamic_shapes=self.dynamic_shapes)
            .quantize()
            .export()
            .to_edge_transform_and_lower()
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .check_not(list(self.all_operators))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(num_runs=5)
        )

    # TODO: Delete this and keep only the calibrated test once T187799178 is
    # resolved.
    def test_qs8_mv3_no_calibration(self):
        """QS8 variant with calibration disabled (``Quantize(calibrate=False)``).

        Stand-in for ``_test_qs8_mv3`` while that test is skipped.
        """
        (
            Tester(self.mv3, self.model_inputs, dynamic_shapes=self.dynamic_shapes)
            .quantize(Quantize(calibrate=False))
            .export()
            .to_edge_transform_and_lower()
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .check_not(list(self.all_operators))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(num_runs=5)
        )
82