xref: /aosp_15_r20/external/executorch/examples/mediatek/eval_utils/eval_oss_result.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) MediaTek Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import json
9import os
10
11import numpy as np
12import piq
13import torch
14
15
def check_data(target_f, predict_f):
    """Verify that the target and prediction folders hold matching files.

    Target files are named ``golden_<sampleId>_<outId>.bin`` and prediction
    files ``output_<sampleId>_<outId>.bin``. Every target file must have a
    prediction counterpart and vice versa.

    Args:
        target_f: Folder containing the golden (reference) files.
        predict_f: Folder containing the model prediction files.

    Raises:
        RuntimeError: If the folders hold a different number of files, or a
            file in one folder has no counterpart in the other.
    """
    target_files = os.listdir(target_f)
    predict_files = os.listdir(predict_f)
    if len(target_files) != len(predict_files):
        raise RuntimeError(
            "Data number in target folder and prediction folder must be same"
        )

    predict_set = set(predict_files)
    for f in target_files:
        # target file naming rule is golden_sampleId_outId.bin
        # predict file naming rule is output_sampleId_outId.bin
        pred_name = f.replace("golden", "output")
        try:
            predict_set.remove(pred_name)
        except KeyError:
            # The KeyError adds nothing beyond this message; drop its context.
            raise RuntimeError(f"Cannot find {pred_name} in {predict_f}") from None

    if predict_set:
        # A set is not an iterator, so next(predict_set) would raise TypeError.
        # min() picks a leftover name deterministically instead.
        target_name = min(predict_set).replace("output", "golden")
        raise RuntimeError(f"Cannot find {target_name} in {target_f}")
37
38
def eval_topk(target_f, predict_f):
    """Print top-10 and top-50 accuracy of predictions against golden labels.

    Each ``golden_*.bin`` file in ``target_f`` is read as int64, and its first
    value is the ground-truth class index. The matching ``output_*.bin`` file
    in ``predict_f`` is read as a flat float32 vector of class scores.

    Args:
        target_f: Folder containing the golden label files.
        predict_f: Folder containing the prediction score files.

    Returns:
        Tuple ``(top10_acc, top50_acc)`` in percent (also printed).

    Raises:
        RuntimeError: If ``target_f`` contains no files.
    """

    def solve(prob, target, k):
        # 1 if the golden label is among the k highest scores, otherwise 0.
        _, indices = torch.topk(prob, k=k, sorted=True)
        golden = torch.reshape(target, [-1, 1])
        return int(torch.any(golden == indices))

    target_files = os.listdir(target_f)
    if not target_files:
        # Avoid a ZeroDivisionError when computing the accuracies below.
        raise RuntimeError(f"No target files found in {target_f}")

    cnt10 = 0
    cnt50 = 0
    for target_name in target_files:
        pred_name = target_name.replace("golden", "output")

        pred = torch.from_numpy(
            np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32)
        )
        label = np.fromfile(os.path.join(target_f, target_name), dtype=np.int64)[0]
        # `label` is a NumPy scalar, which torch.from_numpy rejects (it accepts
        # only ndarrays); build the tensor explicitly instead.
        target = torch.tensor(int(label))
        cnt10 += solve(pred, target, 10)
        cnt50 += solve(pred, target, 50)

    top10 = cnt10 * 100.0 / len(target_files)
    top50 = cnt50 * 100.0 / len(target_files)
    print("Top10 acc:", top10)
    print("Top50 acc:", top50)
    return top10, top50
63
64
def eval_piq(target_f, predict_f, height=448, width=448, channels=3):
    """Print the average PSNR and SSIM of predicted images against goldens.

    Golden and prediction files are raw float32 images laid out as
    ``(1, height, width, channels)``. They are moved to channel-first layout
    before being scored with ``piq``. Predictions are clamped to ``[0, 1]``;
    golden images are used as-is.

    Args:
        target_f: Folder containing the golden images.
        predict_f: Folder containing the model output images.
        height: Image height in pixels (default 448).
        width: Image width in pixels (default 448).
        channels: Number of channels per pixel (default 3).

    Returns:
        Tuple ``(avg_psnr, avg_ssim)`` (also printed).

    Raises:
        RuntimeError: If ``target_f`` contains no files.
    """
    shape = (1, height, width, channels)

    def load_chw(path):
        # Read a raw channel-last float32 image; return it channel-first.
        img = np.fromfile(path, dtype=np.float32).reshape(shape)
        return torch.from_numpy(np.moveaxis(img, 3, 1))

    target_files = os.listdir(target_f)
    if not target_files:
        # sum([]) is the int 0, which has no .item(): fail with a clear message.
        raise RuntimeError(f"No target files found in {target_f}")

    psnr_list = []
    ssim_list = []
    for target_name in target_files:
        pred_name = target_name.replace("golden", "output")
        hr = load_chw(os.path.join(target_f, target_name))
        sr = load_chw(os.path.join(predict_f, pred_name)).clamp(0, 1)

        psnr_list.append(piq.psnr(hr, sr))
        ssim_list.append(piq.ssim(hr, sr))

    avg_psnr = sum(psnr_list).item() / len(psnr_list)
    avg_ssim = sum(ssim_list).item() / len(ssim_list)

    print(f"Avg of PSNR is: {avg_psnr}")
    print(f"Avg of SSIM is: {avg_ssim}")
    return avg_psnr, avg_ssim
90
91
def eval_segmentation(target_f, predict_f, height=224, width=224):
    """Print pixel accuracy, mean pixel accuracy and IoU metrics.

    Each golden file holds a uint8 label map read as ``(height, width)``.
    Pixels whose label is ``>= len(classes)`` (e.g. 255) are excluded. Each
    prediction file holds float32 per-class scores read as
    ``(height, width, len(classes))``, and the argmax over the class axis is
    the predicted label.

    Args:
        target_f: Folder containing the golden label maps.
        predict_f: Folder containing the prediction score maps.
        height: Map height in pixels (default 224).
        width: Map width in pixels (default 224).

    Returns:
        Dict with keys ``"pa"``, ``"mpa"``, ``"miou"`` and ``"cls_iou"``
        (also printed).
    """
    classes = [
        "Background",
        "Aeroplane",
        "Bicycle",
        "Bird",
        "Boat",
        "Bottle",
        "Bus",
        "Car",
        "Cat",
        "Chair",
        "Cow",
        "DiningTable",
        "Dog",
        "Horse",
        "MotorBike",
        "Person",
        "PottedPlant",
        "Sheep",
        "Sofa",
        "Train",
        "TvMonitor",
    ]
    num_classes = len(classes)

    def histogram(golden, predict):
        # Per-image confusion counts: row = golden class, column = predicted.
        mask = golden < num_classes
        return np.bincount(
            num_classes * golden[mask].astype(int) + predict[mask],
            minlength=num_classes**2,
        ).reshape(num_classes, num_classes)

    # Accumulate image by image rather than holding every map in memory.
    confusion = np.zeros((num_classes, num_classes))
    for target_name in os.listdir(target_f):
        pred_name = target_name.replace("golden", "output")
        target_npy = np.fromfile(os.path.join(target_f, target_name), dtype=np.uint8)
        target_npy = target_npy.reshape((height, width))

        pred_npy = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32)
        pred_npy = pred_npy.reshape((height, width, num_classes))
        pred_npy = pred_npy.argmax(2).astype(np.uint8)

        confusion += histogram(target_npy.flatten(), pred_npy.flatten())

    # eps avoids 0/0. A class with no pixels therefore scores 0 and still
    # counts toward the MPA/MIoU means.
    eps = 1e-6
    diag = np.diag(confusion)
    pa = diag.sum() / (confusion.sum() + eps)
    mpa = np.mean(diag / (confusion.sum(axis=1) + eps))
    iou = diag / (confusion.sum(axis=1) + confusion.sum(axis=0) - diag + eps)
    miou = np.mean(iou)
    cls_iou = dict(zip(classes, iou))

    print(f"PA   : {pa}")
    print(f"MPA  : {mpa}")
    print(f"MIoU : {miou}")
    print(f"CIoU : \n{json.dumps(cls_iou, indent=2)}")
    return {"pa": pa, "mpa": mpa, "miou": miou, "cls_iou": cls_iou}
162
163
if __name__ == "__main__":
    # Dispatch table: each --eval_type value maps to the function that runs it.
    evaluators = {
        "topk": eval_topk,
        "piq": eval_piq,
        "segmentation": eval_segmentation,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target_f",
        help="folder of target data",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--out_f",
        help="folder of model prediction data",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--eval_type",
        help="Choose eval type from: topk, piq, segmentation",
        type=str,
        choices=list(evaluators),
        required=True,
    )
    cli = parser.parse_args()

    # Fail fast if the two folders do not pair up file-for-file.
    check_data(cli.target_f, cli.out_f)
    evaluators[cli.eval_type](cli.target_f, cli.out_f)
199