# Copyright (c) MediaTek Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import random

import numpy as np

import torch
from executorch.backends.mediatek import Precision
from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
    build_executorch_binary,
)
from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model


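# Wraps the eager DeepLabV3-ResNet101 model so the exported graph consumes and
# produces NHWC (channels-last) tensors, matching the layout of the input data
# saved for the on-device runner below.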
class NhwcWrappedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.deeplabv3 = DeepLabV3ResNet101Model().get_eager_model()

    def forward(self, input1):
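        # NHWC -> NCHW for the convolutional backbone, then back to NHWC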
        nchw_input1 = input1.permute(0, 3, 1, 2)
        nchw_output = self.deeplabv3(nchw_input1)
        return nchw_output.permute(0, 2, 3, 1)


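# Returns NHWC input tensors, resized golden masks, and the contents of
# input_list.txt for `data_size` randomly chosen VOCSegmentation validation
# samples.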
def get_dataset(data_size, dataset_dir, download):
    from torchvision import datasets, transforms

    input_size = (224, 224)
    preprocess = transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.ToTensor(),
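            # standard ImageNet mean/std used by torchvision pretrained weights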
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    dataset = list(
        datasets.VOCSegmentation(
            root=os.path.join(dataset_dir, "voc_image"),
            year="2009",
            image_set="val",
            transform=preprocess,
            download=download,
        )
    )

    # prepare input data: shuffle, keep the first `data_size` samples, and
    # store each input as a single-batch NHWC tensor
    random.shuffle(dataset)
    inputs, targets, input_list = [], [], ""
    for index, data in enumerate(dataset):
        if index >= data_size:
            break
        image, target = data
        inputs.append((image.unsqueeze(0).permute(0, 2, 3, 1),))
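        # resize the golden mask to the model input size so it lines up with
        # the model output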
        targets.append(np.array(target.resize(input_size)))
        input_list += f"input_{index}_0.bin\n"

    return inputs, targets, input_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. Default: ./deeplab_v3",
        default="./deeplab_v3",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--download",
        help="If specified, download the VOCSegmentation dataset via the torchvision API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    data_num = 100
    inputs, targets, input_list = get_dataset(
        data_size=data_num, dataset_dir=args.artifact, download=args.download
    )

    # save input data for inference on device
    input_list_file = f"{args.artifact}/input_list.txt"
    with open(input_list_file, "w") as f:
        f.write(input_list)
        f.flush()
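    # write each NHWC input tensor as a raw .bin file referenced by input_list.txt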
    for idx, data in enumerate(inputs):
        for i, d in enumerate(data):
            file_name = f"{args.artifact}/input_{idx}_{i}.bin"
            arr = d.detach().numpy()
            arr.tofile(file_name)
            if idx == 0:
                print("inp shape: ", arr.shape)
                print("inp type: ", arr.dtype)
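    # write the golden masks so on-device results can be checked against them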
    for idx, data in enumerate(targets):
        file_name = f"{args.artifact}/golden_{idx}_0.bin"
        data.tofile(file_name)
        if idx == 0:
            print("golden shape: ", data.shape)
            print("golden type: ", data.dtype)

    # build the .pte program
    pte_filename = "deeplabV3Resnet101_mtk"
    instance = NhwcWrappedModel()
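    # export, quantize (A8W8), and lower to a MediaTek-backend .pte; the sample
    # `inputs` are passed in for quantization calibration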
    build_executorch_binary(
        instance.eval(),
        (torch.randn(1, 224, 224, 3),),
        f"{args.artifact}/{pte_filename}",
        inputs,
        quant_dtype=Precision.A8W8,
    )
125