xref: /aosp_15_r20/external/executorch/examples/mediatek/model_export_scripts/resnet50.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) MediaTek Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import os
9
10import torch
11from executorch.backends.mediatek import Precision
12from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
13    build_executorch_binary,
14)
15from executorch.examples.models.resnet import ResNet50Model
16
17
class NhwcWrappedModel(torch.nn.Module):
    """ResNet50 eager model wrapped so that it accepts NHWC-ordered input.

    The wrapper reorders the axes of its input with ``permute(0, 3, 1, 2)``
    (NHWC -> NCHW) before passing it to the underlying network.
    """

    def __init__(self):
        super().__init__()
        # Eager-mode network provided by the executorch ResNet50 example model.
        self.resnet = ResNet50Model().get_eager_model()

    def forward(self, input1):
        """Permute ``input1`` from NHWC to NCHW and return the network output."""
        return self.resnet(input1.permute(0, 3, 1, 2))
27
28
def get_dataset(dataset_path, data_size):
    """Collect preprocessed image batches from an image-folder dataset.

    Args:
        dataset_path: Root directory passed to
            ``torchvision.datasets.ImageFolder``.
        data_size: Maximum number of loader batches to collect. The loader
            is created without an explicit batch size.

    Returns:
        A tuple ``(inputs, targets, input_list)`` where ``inputs`` is a list
        of one-element tuples, each holding a feature tensor permuted with
        ``permute(0, 2, 3, 1)`` (NCHW -> NHWC); ``targets`` is the list of
        matching targets as yielded by the loader; and ``input_list`` is a
        string with one ``input_<index>_0.bin`` line per collected batch.
    """
    from torchvision import datasets, transforms

    # Resize, center-crop to 224, convert to a tensor, then normalize
    # each channel with the fixed mean/std below.
    pipeline = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(dataset_path, transform=pipeline),
        shuffle=True,
    )

    # Gather batches until the requested count is reached or the loader
    # runs out, whichever comes first.
    inputs, targets = [], []
    for count, batch in enumerate(loader):
        if count >= data_size:
            break
        feature, target = batch
        inputs.append((feature.permute(0, 2, 3, 1),))  # NHWC
        targets.append(target)

    # One file name per collected batch, each terminated by a newline.
    input_list = "".join(f"input_{i}_0.bin\n" for i in range(len(inputs)))
    return inputs, targets, input_list
62
63
if __name__ == "__main__":
    # Command-line interface: dataset location (required) and output folder.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        required=True,
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
    )
    arg_parser.add_argument(
        "-a",
        "--artifact",
        type=str,
        default="./resnet50",
        help="path for storing generated artifacts by this example. "
        "Default ./resnet50",
    )
    cli_args = arg_parser.parse_args()
    artifact_dir = cli_args.artifact

    # Make sure the output directory exists before anything is written to it.
    os.makedirs(artifact_dir, exist_ok=True)

    sample_count = 100
    sample_inputs, labels, listing = get_dataset(
        dataset_path=cli_args.dataset,
        data_size=sample_count,
    )

    # Persist the input listing, the input tensors and the targets so they
    # can be used for inference on the device.
    with open(f"{artifact_dir}/input_list.txt", "w") as list_file:
        list_file.write(listing)
        list_file.flush()
    for sample_idx, sample in enumerate(sample_inputs):
        for tensor_idx, tensor in enumerate(sample):
            tensor.detach().numpy().tofile(
                f"{artifact_dir}/input_{sample_idx}_{tensor_idx}.bin"
            )
    for label_idx, label in enumerate(labels):
        label.detach().numpy().tofile(f"{artifact_dir}/golden_{label_idx}_0.bin")

    # Compile to pte. The collected dataset tensors are passed along as the
    # fourth positional argument -- presumably as quantization calibration
    # data for A8W8; confirm against build_executorch_binary.
    model = NhwcWrappedModel().eval()
    build_executorch_binary(
        model,
        (torch.randn(1, 224, 224, 3),),
        f"{artifact_dir}/resnet50_mtk",
        sample_inputs,
        quant_dtype=Precision.A8W8,
    )
122