# Copyright (c) MediaTek Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Accuracy evaluation for model outputs against golden reference binaries.

Target files are named ``golden_<sampleId>_<outId>.bin`` and the matching
prediction files are named ``output_<sampleId>_<outId>.bin``.
"""

import argparse
import json
import os

import numpy as np
import torch

try:
    import piq
except ImportError:  # piq is only needed by eval_piq; keep other modes usable.
    piq = None


def _to_pred_name(target_name):
    """Map a golden (target) file name to its prediction file name."""
    return target_name.replace("golden", "output")


def _list_targets(target_f):
    """Return the sorted file names in ``target_f``.

    Raises:
        RuntimeError: if the folder is empty; the metrics are undefined for
            zero samples (top-k/PSNR averages would divide by zero).
    """
    target_files = sorted(os.listdir(target_f))
    if not target_files:
        raise RuntimeError(f"No data found in target folder {target_f}")
    return target_files


def check_data(target_f, predict_f):
    """Verify that the target and prediction folders hold matching file sets.

    Every ``golden_<sampleId>_<outId>.bin`` in ``target_f`` must have an
    ``output_<sampleId>_<outId>.bin`` counterpart in ``predict_f`` and the
    two folders must contain the same number of files.

    Raises:
        RuntimeError: if the counts differ or a counterpart is missing.
    """
    target_files = os.listdir(target_f)
    predict_files = os.listdir(predict_f)
    if len(target_files) != len(predict_files):
        raise RuntimeError(
            "Data number in target folder and prediction folder must be same"
        )

    predict_set = set(predict_files)
    for f in target_files:
        pred_name = _to_pred_name(f)
        try:
            predict_set.remove(pred_name)
        except KeyError:
            raise RuntimeError(f"Cannot find {pred_name} in {predict_f}") from None

    if predict_set:
        # Defensive: with equal counts this is normally unreachable. A set is
        # not an iterator, so it must be wrapped in iter() before next().
        target_name = next(iter(predict_set)).replace("output", "golden")
        raise RuntimeError(f"Cannot find {target_name} in {target_f}")


def eval_topk(target_f, predict_f):
    """Print top-10 and top-50 accuracy (in percent) of classification outputs.

    Each prediction file is read as a flat float32 score vector; the first
    int64 value of each target file is the golden class index.

    Raises:
        RuntimeError: if ``target_f`` is empty.
    """

    def solve(prob, label, k):
        # Requesting more classes than exist trivially includes the label;
        # clamp so torch.topk does not raise.
        k = min(k, prob.numel())
        indices = torch.topk(prob, k=k, sorted=True).indices
        return 1 if torch.any(indices == label) else 0

    target_files = _list_targets(target_f)

    cnt10 = 0
    cnt50 = 0
    for target_name in target_files:
        pred_name = _to_pred_name(target_name)
        prob = torch.from_numpy(
            np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32)
        )
        # Convert to a Python int: torch.from_numpy rejects numpy scalars
        # (it only accepts np.ndarray), which is what [0] yields.
        label = int(
            np.fromfile(os.path.join(target_f, target_name), dtype=np.int64)[0]
        )
        cnt10 += solve(prob, label, 10)
        cnt50 += solve(prob, label, 50)

    num_samples = len(target_files)
    print("Top10 acc:", cnt10 * 100.0 / num_samples)
    print("Top50 acc:", cnt50 * 100.0 / num_samples)


def eval_piq(target_f, predict_f):
    """Print the average PSNR and SSIM between golden and predicted images.

    Both files are read as float32, reshaped to (1, 448, 448, 3) and moved to
    channel-first (1, 3, 448, 448). Predictions are clamped to [0, 1] before
    scoring; golden images are used as-is.

    Raises:
        ImportError: if the optional ``piq`` package is not installed.
        RuntimeError: if ``target_f`` is empty.
    """
    if piq is None:
        raise ImportError("piq is required for --eval_type piq (pip install piq)")

    def load_nchw(path):
        img = np.fromfile(path, dtype=np.float32).reshape((1, 448, 448, 3))
        return torch.from_numpy(np.moveaxis(img, 3, 1))

    target_files = _list_targets(target_f)

    psnr_list = []
    ssim_list = []
    for target_name in target_files:
        hr = load_nchw(os.path.join(target_f, target_name))
        sr = load_nchw(os.path.join(predict_f, _to_pred_name(target_name)))
        sr = sr.clamp(0, 1)

        psnr_list.append(piq.psnr(hr, sr))
        ssim_list.append(piq.ssim(hr, sr))

    avg_psnr = sum(psnr_list).item() / len(psnr_list)
    avg_ssim = sum(ssim_list).item() / len(ssim_list)

    print(f"Avg of PSNR is: {avg_psnr}")
    print(f"Avg of SSIM is: {avg_ssim}")


def eval_segmentation(target_f, predict_f):
    """Print pixel accuracy, mean pixel accuracy, mean IoU and per-class IoU.

    Target files are read as uint8 label maps reshaped to (224, 224);
    prediction files as float32 scores reshaped to (224, 224, num_classes)
    and reduced with argmax over the class axis. Golden pixels whose label is
    >= num_classes are excluded from the confusion matrix.

    Raises:
        RuntimeError: if ``target_f`` is empty.
    """
    classes = [
        "Backround",
        "Aeroplane",
        "Bicycle",
        "Bird",
        "Boat",
        "Bottle",
        "Bus",
        "Car",
        "Cat",
        "Chair",
        "Cow",
        "DiningTable",
        "Dog",
        "Horse",
        "MotorBike",
        "Person",
        "PottedPlant",
        "Sheep",
        "Sofa",
        "Train",
        "TvMonitor",
    ]
    num_classes = len(classes)

    def histogram(golden, predict):
        # Row index = golden label, column index = predicted label.
        mask = golden < num_classes
        return np.bincount(
            num_classes * golden[mask].astype(int) + predict[mask],
            minlength=num_classes**2,
        ).reshape(num_classes, num_classes)

    target_files = _list_targets(target_f)

    # Accumulate per sample instead of buffering every map in memory first.
    confusion = np.zeros((num_classes, num_classes))
    for target_name in target_files:
        pred_name = _to_pred_name(target_name)
        golden = np.fromfile(os.path.join(target_f, target_name), dtype=np.uint8)
        golden = golden.reshape((224, 224))

        scores = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32)
        scores = scores.reshape((224, 224, num_classes))
        predict = scores.argmax(2).astype(np.uint8)

        confusion += histogram(golden.flatten(), predict.flatten())

    eps = 1e-6  # avoids division by zero for classes absent from both maps
    diag = np.diag(confusion)

    pa = diag.sum() / (confusion.sum() + eps)
    mpa = np.mean(diag / (confusion.sum(axis=1) + eps))
    iou = diag / (confusion.sum(axis=1) + confusion.sum(axis=0) - diag + eps)
    miou = np.mean(iou)
    cls_iou = dict(zip(classes, iou))

    print(f"PA : {pa}")
    print(f"MPA : {mpa}")
    print(f"MIoU : {miou}")
    print(f"CIoU : \n{json.dumps(cls_iou, indent=2)}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--target_f",
        help="folder of target data",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--out_f",
        help="folder of model prediction data",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--eval_type",
        help="Choose eval type from: topk, piq, segmentation",
        type=str,
        choices=["topk", "piq", "segmentation"],
        required=True,
    )

    args = parser.parse_args()

    check_data(args.target_f, args.out_f)

    # argparse `choices` guarantees the key exists.
    evaluators = {
        "topk": eval_topk,
        "piq": eval_piq,
        "segmentation": eval_segmentation,
    }
    evaluators[args.eval_type](args.target_f, args.out_f)