# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from transformers import MobileBertConfig, MobileBertModel  # @manual


class TestMobilebert(unittest.TestCase):
    """End-to-end XNNPACK lowering tests for a default-config MobileBERT.

    Each test drives the XNNPACK ``Tester`` pipeline:
    export -> ``to_edge_transform_and_lower`` -> assert that none of
    ``supported_ops`` remain in the lowered graph -> ``to_executorch`` ->
    serialize -> run the serialized program and compare its outputs against
    the reference produced by the pipeline.
    """

    # MobileBERT built directly from a default ``MobileBertConfig`` (no
    # pretrained checkpoint is loaded) and switched to eval mode. It is
    # constructed once, when the class body executes, and shared by both
    # tests.
    # pyre-ignore
    mobilebert = MobileBertModel(MobileBertConfig()).eval()

    # A single batch of 8 token ids, shape (1, 8), passed positionally to the
    # model. NOTE(review): 101/102 look like BERT's [CLS]/[SEP] ids; confirm
    # if the exact ids matter.
    example_inputs = (torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]]),)

    # Edge-dialect ops that must not appear in the graph after
    # ``to_edge_transform_and_lower``. ``check_not`` fails if any of them is
    # still present, presumably because each should have been absorbed into
    # the XNNPACK delegate.
    supported_ops = {
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
        "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
        "executorch_exir_dialects_edge__ops_aten_div_Tensor",
        "executorch_exir_dialects_edge__ops_aten_cat_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten__softmax_default",
        "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
    }

    def test_fp32_mobilebert(self):
        """Lower the fp32 model and compare runtime outputs to the reference."""
        (
            Tester(self.mobilebert, self.example_inputs)
            .export()
            .to_edge_transform_and_lower()
            # None of the listed edge ops may survive lowering.
            .check_not(list(self.supported_ops))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(inputs=self.example_inputs)
        )

    def test_qs8_mobilebert(self):
        """Quantize (calibration disabled), lower, and compare runtime outputs."""
        (
            Tester(self.mobilebert, self.example_inputs)
            # Quantize before export. calibrate=False skips running calibration
            # data through the observers.
            .quantize(Quantize(calibrate=False))
            .export()
            .to_edge_transform_and_lower()
            # Same delegation requirement as the fp32 path.
            .check_not(list(self.supported_ops))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(inputs=self.example_inputs)
        )