# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import re
from multiprocessing.connection import Client

import numpy as np
import piq
import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.models.edsr import EdsrModel
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)

from PIL import Image
from torch.utils.data import Dataset
from torchsr.datasets import B100
from torchvision.transforms.functional import to_pil_image, to_tensor

# Pulled files named ``output_<n>_<k>.raw`` with k in 1..9 are deleted rather
# than decoded; only the remaining files are treated as model outputs.
# NOTE(review): presumably k is the output-tensor index and only tensor 0 is
# the super-resolved image -- confirm against the runner's naming scheme.
_EXTRA_OUTPUT_PATTERN = re.compile(r"^output_[0-9]+_[1-9]\.raw$")


class SrDataset(Dataset):
    """Paired high/low resolution images, fully loaded into memory.

    Every file in ``lr_dir`` is resized to ``input_size`` (224x224) and every
    file in ``hr_dir`` to twice that size. Files are paired by their position
    in the sorted directory listings.

    Args:
        hr_dir: Directory holding the high resolution reference images.
        lr_dir: Directory holding the low resolution input images.

    Raises:
        AssertionError: If the two directories hold a different number of
            files.
    """

    def __init__(self, hr_dir: str, lr_dir: str):
        self.input_size = np.asanyarray([224, 224])
        # High resolution references are twice the size of the inputs.
        self.hr = [
            self._resize_img(os.path.join(hr_dir, file), 2)
            for file in sorted(os.listdir(hr_dir))
        ]
        self.lr = [
            self._resize_img(os.path.join(lr_dir, file), 1)
            for file in sorted(os.listdir(lr_dir))
        ]

        if len(self.hr) != len(self.lr):
            raise AssertionError(
                "The number of high resolution pics is not equal to low "
                "resolution pics"
            )

    def __getitem__(self, idx: int):
        """Return the ``(high_resolution, low_resolution)`` pair at ``idx``."""
        return self.hr[idx], self.lr[idx]

    def __len__(self):
        """Return the number of low resolution images."""
        return len(self.lr)

    def _resize_img(self, file: str, scale: int):
        """Load ``file``, resize it to ``input_size * scale`` and return it
        as a tensor with a leading dimension of size 1."""
        with Image.open(file) as img:
            return to_tensor(img.resize(tuple(self.input_size * scale))).unsqueeze(0)

    def get_input_list(self) -> str:
        """Return one ``input_<i>_0.raw`` line per low resolution image."""
        return "".join(f"input_{i}_0.raw\n" for i in range(len(self.lr)))


def get_b100(
    dataset_dir: str,
):
    """Return the B100 benchmark (scale 2) as an ``SrDataset``.

    The dataset is requested through ``torchsr`` with ``download=True`` when
    either expected directory under ``dataset_dir`` is missing.

    Args:
        dataset_dir: Directory under which ``sr_bm_dataset`` is stored.
    """
    hr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/HR"
    lr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/LR_bicubic/X2"

    if not os.path.exists(hr_dir) or not os.path.exists(lr_dir):
        B100(root=f"{dataset_dir}/sr_bm_dataset", scale=2, download=True)

    return SrDataset(hr_dir, lr_dir)


def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str):
    """Select between a custom dataset and the default B100 dataset.

    Exactly one source must be chosen: either both ``hr_dir`` and ``lr_dir``
    are given, or ``default_dataset`` is truthy.

    Args:
        hr_dir: Directory of high resolution images (custom dataset).
        lr_dir: Directory of low resolution images (custom dataset).
        default_dataset: Truthy to use B100 instead of a custom dataset.
        dataset_dir: Where the default dataset is stored or downloaded to.

    Returns:
        The selected ``SrDataset``.

    Raises:
        RuntimeError: If neither source or both sources are selected.
    """
    has_custom_dataset = bool(lr_dir and hr_dir)

    if not has_custom_dataset and not default_dataset:
        raise RuntimeError(
            "Neither custom dataset is provided nor using default dataset."
        )

    if has_custom_dataset and default_dataset:
        raise RuntimeError("Either use custom dataset, or use default dataset.")

    if default_dataset:
        return get_b100(dataset_dir)

    return SrDataset(hr_dir, lr_dir)


def main(args):
    """Compile EDSR with 8-bit quantization, run it on device and score it.

    Builds the ``.pte`` binary into ``args.artifact``. Unless
    ``args.compile_only`` is set, it then pushes the low resolution inputs to
    the device through ``SimpleADB``, executes, pulls the raw outputs, saves
    them as PNG files and reports the average PSNR and SSIM against the high
    resolution references. Results are sent as JSON to ``args.ip`` /
    ``args.port`` when both are set, otherwise printed.

    Args:
        args: Parsed command line arguments.

    Raises:
        RuntimeError: If a device run is requested without ``args.device``,
            or if fewer outputs than reference images were pulled back.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exist.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    dataset = get_dataset(
        args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact
    )

    inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list()
    pte_filename = "edsr_qnn_q8"
    instance = EdsrModel()

    build_executorch_binary(
        instance.get_eager_model().eval(),
        (inputs[0],),
        args.model,
        f"{args.artifact}/{pte_filename}",
        # One single-element tuple per low resolution image.
        [(lr_image,) for lr_image in inputs],
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    output_pic_folder = f"{args.artifact}/output_pics"
    make_output_dir(output_data_folder)
    make_output_dir(output_pic_folder)

    output_raws = []

    def post_process():
        """Decode pulled raw outputs into tensors and save them as PNGs."""
        cnt = 0
        output_shape = tuple(targets[0].size())
        # Order files by the number after the first underscore so that
        # outputs line up with the order of ``targets``.
        for name in sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        ):
            filename = os.path.join(output_data_folder, name)
            if _EXTRA_OUTPUT_PATTERN.match(name):
                os.remove(filename)
                continue

            output = np.fromfile(filename, dtype=np.float32)
            # Clamp to [0, 1] before scoring and saving as an image.
            output = torch.tensor(output).reshape(output_shape).clamp(0, 1)
            output_raws.append(output)
            to_pil_image(output.squeeze(0)).save(
                os.path.join(output_pic_folder, str(cnt) + ".png")
            )
            cnt += 1

    adb.pull(output_path=args.artifact, callback=post_process)

    # Fail with a clear message instead of an IndexError during scoring.
    if len(output_raws) < len(targets):
        raise RuntimeError(
            f"Expected {len(targets)} outputs from device but only got "
            f"{len(output_raws)}."
        )

    psnr_list = []
    ssim_list = []
    for hr, output in zip(targets, output_raws):
        psnr_list.append(piq.psnr(hr, output))
        ssim_list.append(piq.ssim(hr, output))

    avg_PSNR = sum(psnr_list).item() / len(psnr_list)
    avg_SSIM = sum(ssim_list).item() / len(ssim_list)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"PSNR": avg_PSNR, "SSIM": avg_SSIM}))
    else:
        print(f"Average of PSNR is: {avg_PSNR}")
        print(f"Average of SSIM is: {avg_SSIM}")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./edsr",
        default="./edsr",
        type=str,
    )

    parser.add_argument(
        "-r",
        "--hr_ref_dir",
        help="Path to the high resolution images",
        default="",
        type=str,
    )

    parser.add_argument(
        "-l",
        "--lr_dir",
        help="Path to the low resolution image inputs",
        default="",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--default_dataset",
        help="If specified, download and use B100 dataset by torchSR API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # Report the failure to the listening process instead of raising.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Bare raise keeps the original exception type and traceback.
            raise