# Copyright (c) MediaTek Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import random

import numpy as np

import torch
from executorch.backends.mediatek import Precision
from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
    build_executorch_binary,
)
from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model


class NhwcWrappedModel(torch.nn.Module):
    """Wrap the eager DeepLabV3-ResNet101 model behind a channels-last interface.

    ``forward`` permutes its input with ``(0, 3, 1, 2)`` before calling the
    wrapped model and permutes the result with ``(0, 2, 3, 1)`` afterwards,
    i.e. NHWC -> NCHW on the way in and NCHW -> NHWC on the way out.
    """

    def __init__(self):
        super(NhwcWrappedModel, self).__init__()
        self.deeplabv3 = DeepLabV3ResNet101Model().get_eager_model()

    def forward(self, input1):
        """Run the wrapped model on a 4-D NHWC tensor and return NHWC output."""
        nchw_input1 = input1.permute(0, 3, 1, 2)
        # NOTE(review): assumes the eager model returns a single 4-D tensor
        # (not a dict) -- confirm against DeepLabV3ResNet101Model.
        nchw_output = self.deeplabv3(nchw_input1)
        return nchw_output.permute(0, 2, 3, 1)


def get_dataset(data_size, dataset_dir, download):
    """Load a random subset of the VOC 2009 segmentation validation split.

    Args:
        data_size: Maximum number of samples to return.
        dataset_dir: Directory under which the ``voc_image`` dataset root
            lives (and is downloaded to when ``download`` is true).
        download: Forwarded to ``torchvision.datasets.VOCSegmentation``.

    Returns:
        A tuple ``(inputs, targets, input_list)`` where
        ``inputs`` is a list of 1-tuples, each holding one preprocessed image
        tensor permuted to ``(1, H, W, C)``;
        ``targets`` is a list of numpy arrays built from the segmentation
        masks resized to 224x224; and
        ``input_list`` is a newline-terminated listing of the
        ``input_<index>_0.bin`` file names, one per returned sample.
    """
    # Imported lazily so torchvision is only required when data is requested.
    from torchvision import datasets, transforms

    input_size = (224, 224)
    preprocess = transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    dataset = datasets.VOCSegmentation(
        root=os.path.join(dataset_dir, "voc_image"),
        year="2009",
        image_set="val",
        transform=preprocess,
        download=download,
    )

    # Pick the random subset by index first, then load only those samples.
    # Materializing the whole dataset would decode and preprocess every image
    # just to discard all but `data_size` of them.
    total = len(dataset)
    sample_count = max(0, min(data_size, total))
    chosen_indices = random.sample(range(total), sample_count)

    # prepare input data
    inputs, targets, file_names = [], [], []
    for index, dataset_index in enumerate(chosen_indices):
        image, target = dataset[dataset_index]
        # CHW tensor -> add batch dim -> NHWC, matching NhwcWrappedModel.
        inputs.append((image.unsqueeze(0).permute(0, 2, 3, 1),))
        targets.append(np.array(target.resize(input_size)))
        file_names.append(f"input_{index}_0.bin\n")

    return inputs, targets, "".join(file_names)


def main():
    """Dump calibration inputs and golden masks, then export the model."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. \n"
        "Default ./deeplab_v3",
        default="./deeplab_v3",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--download",
        help="If specified, download VOCSegmentation dataset by torchvision API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()

    # ensure the working directory exist.
    os.makedirs(args.artifact, exist_ok=True)

    data_num = 100
    inputs, targets, input_list = get_dataset(
        data_size=data_num, dataset_dir=args.artifact, download=args.download
    )

    # save data to inference on device
    input_list_file = f"{args.artifact}/input_list.txt"
    with open(input_list_file, "w") as f:
        f.write(input_list)

    for idx, data in enumerate(inputs):
        for i, d in enumerate(data):
            file_name = f"{args.artifact}/input_{idx}_{i}.bin"
            # Convert once; tofile() writes the elements in C order even
            # though the permuted tensor is not contiguous in memory.
            array = d.detach().numpy()
            array.tofile(file_name)
            if idx == 0:
                print("inp shape: ", array.shape)
                print("inp type: ", array.dtype)

    for idx, data in enumerate(targets):
        file_name = f"{args.artifact}/golden_{idx}_0.bin"
        data.tofile(file_name)
        if idx == 0:
            print("golden shape: ", data.shape)
            print("golden type: ", data.dtype)

    # build pte
    pte_filename = "deeplabV3Resnet101_mtk"
    instance = NhwcWrappedModel()
    build_executorch_binary(
        instance.eval(),
        (torch.randn(1, 224, 224, 3),),
        f"{args.artifact}/{pte_filename}",
        inputs,
        quant_dtype=Precision.A8W8,
    )


if __name__ == "__main__":
    main()