# Copyright (c) MediaTek Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
from typing import Optional

import torch
from executorch import exir
from executorch.backends.mediatek import (
    NeuropilotPartitioner,
    NeuropilotQuantizer,
    Precision,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def build_executorch_binary(
    model,
    inputs,
    file_name,
    dataset,
    quant_dtype: Optional[Precision] = None,
):
    """Export ``model``, lower it with the Neuropilot partitioner, and save it.

    When ``quant_dtype`` is given, the model is first quantized through the
    PT2E flow (``prepare_pt2e`` -> calibration -> ``convert_pt2e``) using a
    ``NeuropilotQuantizer`` configured with that precision. Otherwise the
    model is exported as-is. The resulting program is written to
    ``f"{file_name}.pte"``.

    Args:
        model: Module handed to ``torch.export``.
        inputs: Example inputs handed to ``torch.export`` alongside ``model``.
        file_name: Output path without the ``.pte`` suffix.
        dataset: Iterable of calibration samples; each sample is unpacked as
            positional arguments to the annotated model. Only read when
            ``quant_dtype`` is not ``None``.
        quant_dtype: Precision to quantize with, or ``None`` to skip
            quantization.

    Raises:
        AssertionError: If ``quant_dtype`` is not a member of ``Precision``.
    """
    if quant_dtype is not None:
        # Validate before the quantizer ever sees the value, so an unsupported
        # precision is rejected up front rather than after it has been applied.
        if quant_dtype not in Precision:
            raise AssertionError(f"No support for Precision {quant_dtype}.")

        quantizer = NeuropilotQuantizer()
        quantizer.setup_precision(quant_dtype)

        captured_model = torch.export.export_for_training(model, inputs).module()
        annotated_model = prepare_pt2e(captured_model, quantizer)
        print("Quantizing the model...")
        # calibration
        for data in dataset:
            annotated_model(*data)
        quantized_model = convert_pt2e(annotated_model, fold_quantize=False)
        aten_dialect = torch.export.export(quantized_model, inputs)
    else:
        aten_dialect = torch.export.export(model, inputs)

    from executorch.exir.program._program import to_edge_transform_and_lower

    edge_compile_config = exir.EdgeCompileConfig(_check_ir_validity=False)
    # skipped op names are used for deeplabV3 model
    neuro_partitioner = NeuropilotPartitioner(
        [],
        op_names_to_skip={
            "aten_convolution_default_106",
            "aten_convolution_default_107",
        },
    )
    edge_prog = to_edge_transform_and_lower(
        aten_dialect,
        compile_config=edge_compile_config,
        partitioner=[neuro_partitioner],
    )

    exec_prog = edge_prog.to_executorch(config=exir.ExecutorchBackendConfig())

    with open(f"{file_name}.pte", "wb") as file:
        file.write(exec_prog.buffer)


def make_output_dir(path: str):
    """Ensure ``path`` exists as an empty directory.

    Any existing directory at ``path`` is deleted together with its contents
    and then recreated.

    Args:
        path: Directory to (re)create.
    """
    if os.path.exists(path):
        # rmtree handles nested subdirectories and, unlike os.removedirs,
        # never prunes parent directories that happen to become empty.
        shutil.rmtree(path)
    os.makedirs(path)