xref: /aosp_15_r20/external/executorch/examples/qualcomm/scripts/edsr.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport json
8*523fa7a6SAndroid Build Coastguard Workerimport os
9*523fa7a6SAndroid Build Coastguard Workerimport re
10*523fa7a6SAndroid Build Coastguard Workerfrom multiprocessing.connection import Client
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport numpy as np
13*523fa7a6SAndroid Build Coastguard Workerimport piq
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.models.edsr import EdsrModel
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.qualcomm.utils import (
18*523fa7a6SAndroid Build Coastguard Worker    build_executorch_binary,
19*523fa7a6SAndroid Build Coastguard Worker    make_output_dir,
20*523fa7a6SAndroid Build Coastguard Worker    parse_skip_delegation_node,
21*523fa7a6SAndroid Build Coastguard Worker    setup_common_args_and_variables,
22*523fa7a6SAndroid Build Coastguard Worker    SimpleADB,
23*523fa7a6SAndroid Build Coastguard Worker)
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Workerfrom PIL import Image
26*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils.data import Dataset
27*523fa7a6SAndroid Build Coastguard Workerfrom torchsr.datasets import B100
28*523fa7a6SAndroid Build Coastguard Workerfrom torchvision.transforms.functional import to_pil_image, to_tensor
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker
class SrDataset(Dataset):
    """Paired high-/low-resolution image dataset for super-resolution evaluation.

    Every file in ``lr_dir`` is resized to ``input_size`` (224x224) and every
    file in ``hr_dir`` to twice that size. Both are stored as ``(1, C, H, W)``
    float tensors produced by ``to_tensor(...).unsqueeze(0)``. Samples are
    paired by position after sorting each directory's file names, so both
    directories must list corresponding images in the same sorted order.

    Args:
        hr_dir: Directory holding the high-resolution reference images.
        lr_dir: Directory holding the low-resolution input images.

    Raises:
        AssertionError: If the two directories contain a different number
            of files.
    """

    def __init__(self, hr_dir: str, lr_dir: str):
        self.input_size = np.asanyarray([224, 224])
        # HR targets are resized to 2x the LR input size.
        self.hr = [
            self._resize_img(os.path.join(hr_dir, file), 2)
            for file in sorted(os.listdir(hr_dir))
        ]
        self.lr = [
            self._resize_img(os.path.join(lr_dir, file), 1)
            for file in sorted(os.listdir(lr_dir))
        ]

        if len(self.hr) != len(self.lr):
            raise AssertionError(
                "The number of high resolution pics is not equal to low "
                "resolution pics"
            )

    def __getitem__(self, idx: int):
        """Return the ``(hr, lr)`` tensor pair at position ``idx``."""
        return self.hr[idx], self.lr[idx]

    def __len__(self):
        return len(self.lr)

    def _resize_img(self, file: str, scale: int):
        """Load ``file``, resize it to ``input_size * scale`` and return it as a (1, C, H, W) tensor."""
        with Image.open(file) as img:
            # Normalise the colour mode so grayscale, palette or RGBA files
            # yield the same 3-channel layout as RGB ones (EDSR-style models
            # consume RGB; NOTE(review): confirm for custom models).
            rgb = img.convert("RGB")
            return to_tensor(rgb.resize(tuple(self.input_size * scale))).unsqueeze(0)

    def get_input_list(self):
        """Return one ``input_<i>_0.raw`` line (newline-terminated) per LR sample."""
        return "".join(f"input_{i}_0.raw\n" for i in range(len(self.lr)))
64*523fa7a6SAndroid Build Coastguard Worker
65*523fa7a6SAndroid Build Coastguard Worker
def get_b100(
    dataset_dir: str,
):
    """Return the B100 (x2) benchmark as an ``SrDataset``.

    The HR and bicubic-x2 LR folders are looked up under
    ``<dataset_dir>/sr_bm_dataset/SRBenchmarks/benchmark/B100``. When either
    folder is missing, torchsr's ``B100`` is invoked with ``download=True``
    and ``root=<dataset_dir>/sr_bm_dataset`` before the dataset is built.
    """
    root = f"{dataset_dir}/sr_bm_dataset"
    benchmark = f"{root}/SRBenchmarks/benchmark/B100"
    hr_dir = f"{benchmark}/HR"
    lr_dir = f"{benchmark}/LR_bicubic/X2"

    # Fetch the data only if at least one of the two image folders is absent.
    if not all(os.path.exists(path) for path in (hr_dir, lr_dir)):
        B100(root=root, scale=2, download=True)

    return SrDataset(hr_dir, lr_dir)
76*523fa7a6SAndroid Build Coastguard Worker
77*523fa7a6SAndroid Build Coastguard Worker
def get_dataset(hr_dir: str, lr_dir: str, default_dataset: bool, dataset_dir: str):
    """Pick the evaluation dataset from the command-line options.

    Exactly one source must be selected:
    - a custom directory pair (both ``hr_dir`` and ``lr_dir`` non-empty), or
    - the default B100 dataset (``default_dataset`` truthy).

    A pair with only one directory set does not count as a custom selection.

    Args:
        hr_dir: Directory of high-resolution reference images ("" if unused).
        lr_dir: Directory of low-resolution input images ("" if unused).
        default_dataset: Truthy to use B100 via ``get_b100``.
        dataset_dir: Root directory passed to ``get_b100``.

    Returns:
        The selected ``SrDataset``.

    Raises:
        RuntimeError: If no source is selected, or if both are selected.
    """
    use_custom = bool(lr_dir and hr_dir)

    if not use_custom and not default_dataset:
        raise RuntimeError(
            "Neither custom dataset is provided nor using default dataset."
        )

    if use_custom and default_dataset:
        raise RuntimeError("Either use custom dataset, or use default dataset.")

    if default_dataset:
        return get_b100(dataset_dir)

    return SrDataset(hr_dir, lr_dir)
91*523fa7a6SAndroid Build Coastguard Worker
92*523fa7a6SAndroid Build Coastguard Worker
def main(args):
    """Build the quantized EDSR program and, unless compile-only, evaluate it on device.

    Steps:
    1. Build an 8a8w-quantized ``.pte`` with ``build_executorch_binary``.
    2. Push the program and the LR inputs to the device via ``SimpleADB``
       and execute them.
    3. Pull the raw outputs back and save each one as a PNG.
    4. Report the average PSNR/SSIM against the HR targets, either as JSON
       over the ``--ip``/``--port`` connection or on stdout.

    Raises:
        RuntimeError: If no device serial is given when not compiling only,
            if the dataset options are inconsistent, or if the device
            produced fewer outputs than there are targets.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # Ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    dataset = get_dataset(
        args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
    )

    inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list()
    pte_filename = "edsr_qnn_q8"
    instance = EdsrModel()

    build_executorch_binary(
        instance.get_eager_model().eval(),
        (inputs[0],),
        args.model,
        f"{args.artifact}/{pte_filename}",
        # Every LR image as a one-element input tuple
        # (presumably the quantization calibration set).
        [(sample,) for sample in inputs],
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # Collect output data.
    output_data_folder = f"{args.artifact}/outputs"
    output_pic_folder = f"{args.artifact}/output_pics"
    make_output_dir(output_data_folder)
    make_output_dir(output_pic_folder)

    # Raw outputs are named output_<inference>_<tensor>.raw. Only tensor 0 is
    # evaluated. This pattern matches every other tensor index (including
    # multi-digit ones) so those files can be deleted. The "." is escaped so it
    # only matches a literal dot.
    extra_output = re.compile(r"^output_[0-9]+_[1-9][0-9]*\.raw$")
    output_raws = []

    def post_process():
        """Load each kept raw output as a clamped tensor and also save it as a PNG."""
        # Outputs are reshaped to the HR target shape.
        output_shape = tuple(targets[0].size())
        files = sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        )
        cnt = 0
        for f in files:
            filename = os.path.join(output_data_folder, f)
            if extra_output.match(f):
                os.remove(filename)
                continue
            output = np.fromfile(filename, dtype=np.float32)
            output = torch.tensor(output).reshape(output_shape).clamp(0, 1)
            output_raws.append(output)
            to_pil_image(output.squeeze(0)).save(
                os.path.join(output_pic_folder, f"{cnt}.png")
            )
            cnt += 1

    adb.pull(output_path=args.artifact, callback=post_process)

    # Fail with a clear message rather than an IndexError in the metric loop.
    if len(output_raws) < len(targets):
        raise RuntimeError(
            f"Expected {len(targets)} outputs from the device but found "
            f"{len(output_raws)} in {output_data_folder}."
        )

    psnr_list = [piq.psnr(hr, out) for hr, out in zip(targets, output_raws)]
    ssim_list = [piq.ssim(hr, out) for hr, out in zip(targets, output_raws)]

    avg_psnr = sum(psnr_list).item() / len(psnr_list)
    avg_ssim = sum(ssim_list).item() / len(ssim_list)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"PSNR": avg_psnr, "SSIM": avg_ssim}))
    else:
        print(f"Average of PSNR is: {avg_psnr}")
        print(f"Average of SSIM is: {avg_ssim}")
183*523fa7a6SAndroid Build Coastguard Worker
184*523fa7a6SAndroid Build Coastguard Worker
if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./edsr",
        default="./edsr",
        type=str,
    )

    parser.add_argument(
        "-r",
        "--hr_ref_dir",
        help="Path to the high resolution images",
        default="",
        type=str,
    )

    parser.add_argument(
        "-l",
        "--lr_dir",
        help="Path to the low resolution image inputs",
        default="",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--default_dataset",
        help="If specified, download and use B100 dataset by torchSR API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # Report the failure to the peer at --ip/--port instead of crashing.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise unchanged so the original exception type and traceback
            # are preserved (wrapping in Exception(e) would lose them).
            raise