# Copyright (c) MediaTek Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Export a ResNet-50 that takes NHWC input to an ExecuTorch ``.pte`` file.

The script samples images from an ImageNet-style validation folder, writes
them (and their class labels) as raw binary files for on-device inference,
and then calls ``build_executorch_binary`` with ``Precision.A8W8``.
"""

import argparse
import itertools
import os

import torch
from executorch.backends.mediatek import Precision
from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
    build_executorch_binary,
)
from executorch.examples.models.resnet import ResNet50Model


class NhwcWrappedModel(torch.nn.Module):
    """ResNet-50 wrapper whose ``forward`` accepts channels-last (NHWC) input.

    The input is permuted to NCHW before being handed to the eager model
    returned by ``ResNet50Model().get_eager_model()``.
    """

    def __init__(self):
        super().__init__()
        self.resnet = ResNet50Model().get_eager_model()

    def forward(self, input1):
        # (N, H, W, C) -> (N, C, H, W) before calling the wrapped ResNet.
        nchw_input1 = input1.permute(0, 3, 1, 2)
        return self.resnet(nchw_input1)


def get_dataset(dataset_path, data_size):
    """Load up to ``data_size`` preprocessed images from an ImageFolder tree.

    Each image is resized to 256, center-cropped to 224x224, converted to a
    tensor, normalized with the standard ImageNet mean/std values and then
    permuted to NHWC. Samples come from a shuffled DataLoader (default batch
    size 1), so the selected images differ between runs.

    Args:
        dataset_path: Root directory passed to
            ``torchvision.datasets.ImageFolder``.
        data_size: Maximum number of samples to return. Values <= 0 yield
            no samples.

    Returns:
        A tuple ``(inputs, targets, input_list)``. ``inputs`` is a list of
        1-tuples, each holding one NHWC image tensor. ``targets`` is the list
        of matching label tensors. ``input_list`` is a newline-terminated
        listing of the names ``input_{i}_0.bin``, one per sample.
    """
    # Local import: torchvision is only required when data is actually loaded.
    from torchvision import datasets, transforms

    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess)
    data_loader = torch.utils.data.DataLoader(imagenet_data, shuffle=True)

    inputs, targets = [], []
    # islice stops after exactly ``data_size`` batches. The previous
    # enumerate/break loop decoded one extra image before breaking.
    for feature, target in itertools.islice(data_loader, max(data_size, 0)):
        inputs.append((feature.permute(0, 2, 3, 1),))  # NCHW -> NHWC
        targets.append(target)

    input_list = "".join(f"input_{index}_0.bin\n" for index in range(len(inputs)))
    return inputs, targets, input_list


def _save_io_files(artifact_dir, inputs, targets, input_list):
    """Write the input list, raw input tensors and golden labels to disk.

    Files written into ``artifact_dir``:
      - ``input_list.txt``
      - ``input_{idx}_{i}.bin``, one per tensor in each input tuple
      - ``golden_{idx}_0.bin``, one per target
    """
    with open(os.path.join(artifact_dir, "input_list.txt"), "w") as f:
        f.write(input_list)
    for idx, data in enumerate(inputs):
        for i, d in enumerate(data):
            file_name = os.path.join(artifact_dir, f"input_{idx}_{i}.bin")
            # ndarray.tofile always writes in C order, so the permuted
            # (non-contiguous) tensor is serialized as dense NHWC data.
            d.detach().numpy().tofile(file_name)
    for idx, data in enumerate(targets):
        file_name = os.path.join(artifact_dir, f"golden_{idx}_0.bin")
        data.detach().numpy().tofile(file_name)


def main():
    """Parse CLI arguments, dump sample data, and compile the model to .pte."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset "
            "(e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./resnet50",
        default="./resnet50",
        type=str,
    )

    args = parser.parse_args()

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    data_num = 100
    inputs, targets, input_list = get_dataset(
        dataset_path=args.dataset,
        data_size=data_num,
    )

    # save data for inference on device
    _save_io_files(args.artifact, inputs, targets, input_list)

    # compile to pte; ``inputs`` is passed as the dataset argument
    # (presumably used for quantization calibration — confirm in utils).
    pte_filename = "resnet50_mtk"
    instance = NhwcWrappedModel()
    build_executorch_binary(
        instance.eval(),
        (torch.randn(1, 224, 224, 3),),
        os.path.join(args.artifact, pte_filename),
        inputs,
        quant_dtype=Precision.A8W8,
    )


if __name__ == "__main__":
    main()