xref: /aosp_15_r20/external/executorch/examples/qualcomm/scripts/edsr.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9import re
10from multiprocessing.connection import Client
11
12import numpy as np
13import piq
14import torch
15from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16from executorch.examples.models.edsr import EdsrModel
17from executorch.examples.qualcomm.utils import (
18    build_executorch_binary,
19    make_output_dir,
20    parse_skip_delegation_node,
21    setup_common_args_and_variables,
22    SimpleADB,
23)
24
25from PIL import Image
26from torch.utils.data import Dataset
27from torchsr.datasets import B100
28from torchvision.transforms.functional import to_pil_image, to_tensor
29
30
class SrDataset(Dataset):
    """Paired low-/high-resolution image set held entirely in memory.

    Every file in ``lr_dir`` is resized to ``input_size`` and every file in
    ``hr_dir`` to twice ``input_size``.  Each image is stored as a tensor with
    a leading dimension of 1 added by ``unsqueeze(0)``.  Files are visited in
    sorted filename order, so HR/LR pairs are matched purely by their position
    in that order.
    """

    def __init__(self, hr_dir: str, lr_dir: str):
        """Load and resize all images from both directories.

        Args:
            hr_dir: Directory containing the high resolution reference images.
            lr_dir: Directory containing the low resolution input images.

        Raises:
            AssertionError: If the two directories hold a different number of
                files.
        """
        self.input_size = np.asanyarray([224, 224])
        # HR references are loaded at 2x the LR size.
        self.hr = [
            self._resize_img(os.path.join(hr_dir, name), 2)
            for name in sorted(os.listdir(hr_dir))
        ]
        self.lr = [
            self._resize_img(os.path.join(lr_dir, name), 1)
            for name in sorted(os.listdir(lr_dir))
        ]

        if len(self.hr) != len(self.lr):
            raise AssertionError(
                "The number of high resolution pics is not equal to low "
                "resolution pics"
            )

    def __getitem__(self, idx: int):
        """Return the ``(high_resolution, low_resolution)`` pair at ``idx``."""
        return self.hr[idx], self.lr[idx]

    def __len__(self) -> int:
        """Return the number of low resolution samples."""
        return len(self.lr)

    def _resize_img(self, file: str, scale: int):
        """Open ``file``, resize it to ``input_size * scale`` and return a tensor.

        The result of ``to_tensor`` gets an extra leading dimension via
        ``unsqueeze(0)``.
        """
        # Hand PIL plain Python ints rather than numpy integer scalars.
        size = tuple(int(side) for side in self.input_size * scale)
        with Image.open(file) as img:
            return to_tensor(img.resize(size)).unsqueeze(0)

    def get_input_list(self) -> str:
        """Return one ``input_<i>_0.raw`` line per low resolution sample.

        Every entry, including the last, is terminated by a newline; an empty
        dataset yields an empty string.
        """
        return "".join(f"input_{i}_0.raw\n" for i in range(len(self.lr)))
64
65
def get_b100(
    dataset_dir: str,
):
    """Return the B100 benchmark as an ``SrDataset``, fetching it if missing.

    The HR and X2 bicubic LR folders are looked up under
    ``<dataset_dir>/sr_bm_dataset/SRBenchmarks/benchmark/B100``.  If either
    folder is not an existing directory, ``B100(..., download=True)`` is
    invoked with ``<dataset_dir>/sr_bm_dataset`` as its root before the
    dataset is built.

    Args:
        dataset_dir: Directory under which ``sr_bm_dataset`` lives or will be
            created.

    Returns:
        An ``SrDataset`` built from the B100 HR and LR directories.
    """
    # os.path.join keeps an empty dataset_dir relative to the working
    # directory instead of producing a path rooted at "/".
    root = os.path.join(dataset_dir, "sr_bm_dataset")
    b100_dir = os.path.join(root, "SRBenchmarks", "benchmark", "B100")
    hr_dir = os.path.join(b100_dir, "HR")
    lr_dir = os.path.join(b100_dir, "LR_bicubic", "X2")

    # isdir (not exists): a stray regular file at either path must not be
    # mistaken for an already-downloaded dataset.
    if not (os.path.isdir(hr_dir) and os.path.isdir(lr_dir)):
        B100(root=root, scale=2, download=True)

    return SrDataset(hr_dir, lr_dir)
76
77
def get_dataset(hr_dir: str, lr_dir: str, default_dataset: bool, dataset_dir: str):
    """Build the evaluation dataset from custom folders or the default B100 set.

    Exactly one source must be selected: either both ``hr_dir`` and ``lr_dir``
    are non-empty (custom dataset), or ``default_dataset`` is truthy.

    Args:
        hr_dir: Directory of high resolution images; empty when unused.
        lr_dir: Directory of low resolution images; empty when unused.
        default_dataset: When truthy, use the B100 dataset via ``get_b100``.
        dataset_dir: Directory passed to ``get_b100`` for the default dataset.

    Returns:
        The dataset returned by ``get_b100`` or an ``SrDataset`` over the
        custom directories.

    Raises:
        RuntimeError: If no source is selected, or if both a complete custom
            dataset and the default dataset are requested.
    """
    # A custom dataset only counts when BOTH directories are given.
    has_custom = bool(lr_dir and hr_dir)

    if not has_custom and not default_dataset:
        raise RuntimeError(
            "Neither custom dataset is provided nor using default dataset."
        )

    if has_custom and default_dataset:
        raise RuntimeError("Either use custom dataset, or use default dataset.")

    if default_dataset:
        return get_b100(dataset_dir)

    return SrDataset(hr_dir, lr_dir)
91
92
def main(args):
    """Compile EDSR for the QNN backend, run it on a device, report PSNR/SSIM.

    Steps performed:
      1. Load the dataset selected by the command line arguments.
      2. Build ``<artifact>/edsr_qnn_q8`` with ``build_executorch_binary``
         using ``QuantDtype.use_8a8w`` and all LR images as the dataset
         argument.
      3. Unless ``compile_only`` is set, push the inputs through ``SimpleADB``,
         execute, and pull the outputs back.
      4. Decode the outputs, save them as PNG files and compare them to the HR
         references with ``piq.psnr`` / ``piq.ssim``.
      5. Send the averaged metrics as JSON to ``(ip, port)`` when both are
         configured, otherwise print them.

    Args:
        args: Parsed command line namespace.  Reads ``artifact``,
            ``compile_only``, ``device``, ``hr_ref_dir``, ``lr_dir``,
            ``default_dataset``, ``model``, ``shared_buffer``,
            ``build_folder``, ``host``, ``ip`` and ``port``.

    Raises:
        RuntimeError: If no device serial is given while not in compile-only
            mode, or if fewer outputs were collected than there are reference
            images.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # Ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    dataset = get_dataset(
        args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
    )

    inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list()
    pte_filename = "edsr_qnn_q8"
    instance = EdsrModel()

    build_executorch_binary(
        instance.get_eager_model().eval(),
        (inputs[0],),
        args.model,
        f"{args.artifact}/{pte_filename}",
        [(sample,) for sample in inputs],
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    output_pic_folder = f"{args.artifact}/output_pics"
    make_output_dir(output_data_folder)
    make_output_dir(output_pic_folder)

    output_raws = []
    # Files named output_<n>_<1-9>.raw are deleted rather than decoded
    # (presumably extra outputs this example does not evaluate -- confirm).
    discard_pattern = re.compile(r"^output_[0-9]+_[1-9]\.raw$")

    def post_process():
        """Decode pulled raw outputs into tensors and save them as PNG files."""
        cnt = 0
        # Every output is reshaped to the shape of the first HR reference.
        output_shape = tuple(targets[0].size())
        # Sort numerically on the second "_"-separated field so that, e.g.,
        # index 10 follows 9 instead of 1.
        raw_names = sorted(
            os.listdir(output_data_folder),
            key=lambda name: int(name.split("_")[1]),
        )
        for raw_name in raw_names:
            filename = os.path.join(output_data_folder, raw_name)
            if discard_pattern.match(raw_name):
                os.remove(filename)
                continue
            output = np.fromfile(filename, dtype=np.float32)
            output = torch.tensor(output).reshape(output_shape).clamp(0, 1)
            output_raws.append(output)
            to_pil_image(output.squeeze(0)).save(
                os.path.join(output_pic_folder, str(cnt) + ".png")
            )
            cnt += 1

    adb.pull(output_path=args.artifact, callback=post_process)

    # Fail with a clear message instead of an IndexError when the device
    # produced fewer results than there are references.
    if len(output_raws) < len(targets):
        raise RuntimeError(
            f"Expected {len(targets)} outputs but only collected "
            f"{len(output_raws)} from {output_data_folder}."
        )

    psnr_list = []
    ssim_list = []
    for hr, sr in zip(targets, output_raws):
        psnr_list.append(piq.psnr(hr, sr))
        ssim_list.append(piq.ssim(hr, sr))

    avg_PSNR = sum(psnr_list).item() / len(psnr_list)
    avg_SSIM = sum(ssim_list).item() / len(ssim_list)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"PSNR": avg_PSNR, "SSIM": avg_SSIM}))
    else:
        print(f"Average of PSNR is: {avg_PSNR}")
        print(f"Average of SSIM is: {avg_SSIM}")
183
184
if __name__ == "__main__":
    # Script entry point: extend the shared argument parser with the options
    # specific to this example, then run main().
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./edsr",
        default="./edsr",
        type=str,
    )

    parser.add_argument(
        "-r",
        "--hr_ref_dir",
        help="Path to the high resolution images",
        default="",
        type=str,
    )

    parser.add_argument(
        "-l",
        "--lr_dir",
        help="Path to the low resolution image inputs",
        default="",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--default_dataset",
        help="If specified, download and use B100 dataset by torchSR API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # A listener is configured: report the failure to it as JSON
            # instead of propagating the exception.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare raise keeps the original exception type and traceback;
            # wrapping it in Exception(e) would discard the type.
            raise
229