xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/models/mobilebert.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Quantize, Tester
11from transformers import MobileBertConfig, MobileBertModel  # @manual
12
13
class TestMobilebert(unittest.TestCase):
    """Drive a default-configured MobileBERT through the XNNPACK ``Tester`` flow.

    The model, its example input and the op list are class attributes, so
    they are built once when the class is defined and shared by both tests.
    """

    # pyre-ignore
    mobilebert = MobileBertModel(MobileBertConfig()).eval()
    # One 1x8 tensor of integer ids, wrapped in a tuple of positional inputs.
    example_inputs = (torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]]),)
    # Edge-dialect op names handed to ``check_not`` once the graph is lowered.
    # NOTE(review): presumably these must all have been delegated to XNNPACK
    # and therefore no longer appear in the graph -- confirm against Tester.
    supported_ops = {
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
        "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
        "executorch_exir_dialects_edge__ops_aten_div_Tensor",
        "executorch_exir_dialects_edge__ops_aten_cat_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten__softmax_default",
        "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
    }

    def _lower_and_compare(self, tester):
        """Run the stages shared by both tests on an already-built ``tester``.

        Stages, in order: export, edge transform + lower, ``check_not`` on
        ``supported_ops``, conversion to ExecuTorch, serialization, and a
        run that compares outputs for ``example_inputs``.
        """
        lowered = tester.export().to_edge_transform_and_lower()
        checked = lowered.check_not(list(self.supported_ops))
        serialized = checked.to_executorch().serialize()
        serialized.run_method_and_compare_outputs(inputs=self.example_inputs)

    def test_fp32_mobilebert(self):
        """Lower the model as-is (no quantize stage) and compare outputs."""
        fp32_tester = Tester(self.mobilebert, self.example_inputs)
        self._lower_and_compare(fp32_tester)

    def test_qs8_mobilebert(self):
        """Insert a ``Quantize(calibrate=False)`` stage before the shared flow."""
        quantize_stage = Quantize(calibrate=False)
        qs8_tester = Tester(self.mobilebert, self.example_inputs).quantize(
            quantize_stage
        )
        self._lower_and_compare(qs8_tester)
53