xref: /aosp_15_r20/external/libopus/dnn/torch/osce/utils/endoscopy.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1""" module for inspecting models during inference """
2
3import os
4
5import yaml
6import matplotlib.pyplot as plt
7import matplotlib.animation as animation
8
9import torch
10import numpy as np
11
# registry of open output streams, one entry per key:
#   {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = {}

# folder the data files are written to; replaced by init()
_folder = 'endoscopy'
15
def get_gru_gates(gru, input, state):
    """ computes the reset, update and new gate activations of a GRU for one step

    Uses the layer-0 weights and biases of gru on the squeezed input and state.
    The weight and bias rows are read as three consecutive blocks of
    hidden_size rows: reset, update, new.

    Returns a dict with keys 'reset_gate', 'update_gate' and 'new_gate'.
    """
    size = gru.hidden_size

    from_input = torch.matmul(gru.weight_ih_l0, input.squeeze())
    from_state = torch.matmul(gru.weight_hh_l0, state.squeeze())

    # row ranges of the three gate blocks
    reset_rows, update_rows, new_rows = (slice(k * size, (k + 1) * size) for k in range(3))

    reset_gate = torch.sigmoid(
        from_input[reset_rows] + gru.bias_ih_l0[reset_rows]
        + from_state[reset_rows] + gru.bias_hh_l0[reset_rows]
    )

    update_gate = torch.sigmoid(
        from_input[update_rows] + gru.bias_ih_l0[update_rows]
        + from_state[update_rows] + gru.bias_hh_l0[update_rows]
    )

    # the recurrent contribution to the new gate is scaled by the reset gate
    recurrent_part = from_state[new_rows] + gru.bias_hh_l0[new_rows]
    new_gate = torch.tanh(from_input[new_rows] + gru.bias_ih_l0[new_rows] + reset_gate * recurrent_part)

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
35
36
def init(folder='endoscopy'):
    """ selects folder as output location for endoscopy data and creates it if missing

    An already existing folder is left untouched; only a warning is printed.
    """

    global _folder
    _folder = folder

    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
        return

    os.makedirs(folder)
47
def write_data(key, data, fs):
    """ appends data to previous data written under key

    The first call for a key writes the metadata file <key>.yml (fs, dim,
    dtype) and opens <key>.bin in the endoscopy folder. Every call appends the
    raw bytes of data to <key>.bin.

    Parameters:
    -----------
    key : str
        name of the stream; used as file name stem
    data : torch.Tensor, numpy.ndarray or array-like
        one frame of data; tensors are detached and copied to the CPU,
        anything else is passed through numpy.asarray
    fs :
        sampling rate of the stream; stored in the metadata as given

    Raises:
    -------
    ValueError
        if fs, dtype or shape differ from the first call for key
    """

    # convert to numpy if torch.Tensor is given; .cpu() is needed for tensors
    # that live on another device and returns CPU tensors unchanged
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    else:
        # lets python scalars and sequences through as well
        data = np.asarray(data)

    dim = tuple(data.shape)
    dtype = str(data.dtype)

    if key not in _state:
        # metadata goes first so that a failing write does not leave an open,
        # registered stream behind
        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : dim, 'dtype' : dtype.split('.')[-1]}))

        _state[key] = {
            'fid'   : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs'    : fs,
            'dim'   : dim,
            'dtype' : dtype
        }
    else:
        entry = _state[key]
        if entry['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {entry['fs']} vs. {fs}")
        if entry['dtype'] != dtype:
            raise ValueError(f"dtype changed for key {key}: {entry['dtype']} vs. {dtype}")
        if entry['dim'] != dim:
            raise ValueError(f"dim changed for key {key}: {entry['dim']} vs. {dim}")

    _state[key]['fid'].write(data.tobytes())
76
def close(folder='endoscopy'):
    """ clean up: closes all open data files and drops them from the registry

    folder is not used; the parameter is kept so existing calls keep working.

    The registry is emptied after closing. Otherwise a later write_data call
    for an already known key would write to a closed file.
    """
    for entry in _state.values():
        entry['fid'].close()

    _state.clear()
81
82
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

    Every <key>.yml file found in folder is loaded together with <key>.bin.
    The returned dict maps each key to its metadata dict, extended by an entry
    'data' holding the samples reshaped to (-1,) + dim.
    """

    suffix = '.yml'
    return_dict = {}

    for name in os.listdir(folder):
        if not name.endswith(suffix):
            continue

        key = name[:-len(suffix)]

        # NOTE(review): yaml.FullLoader can build python objects from tags;
        # only point this at folders with trusted content
        with open(os.path.join(folder, key + suffix), 'r') as meta_file:
            entry = yaml.load(meta_file.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as data_file:
            samples = np.frombuffer(data_file.read(), dtype=entry['dtype'])

        # leading axis counts the frames that were written
        entry['data'] = samples.reshape((-1,) + entry['dim'])

        return_dict[key] = entry

    return return_dict
103
def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)

    Parameters:
    -----------
    shape : sequence of int
        original shape; may be empty, which counts as a single pixel
    target_ratio : positive number
        desired rows/cols ratio, defaults to 1

    Returns:
    --------
    (1,) if shape holds exactly one pixel, otherwise (num_rows, num_columns)
    with num_rows * num_columns equal to the pixel count.

    Raises:
    -------
    ValueError
        if target_ratio is not positive
    """

    if target_ratio <= 0:
        raise ValueError(f"target_ratio must be positive, got {target_ratio}")

    pixel_count = 1
    for s in shape:
        pixel_count *= s

    if pixel_count == 1:
        return (1,)

    # column count closest to the target ratio; never below one, since a large
    # target_ratio or an empty shape would otherwise lead to a modulo by zero
    num_columns = max(1, int((pixel_count / target_ratio)**.5))

    # decrease until the columns divide the pixel count evenly
    while (pixel_count % num_columns):
        num_columns -= 1

    num_rows = pixel_count // num_columns

    return (num_rows, num_columns)
125
def get_type_and_shape(shape):
    """ decides how data of the given shape is displayed

    Returns a tuple (type, display_shape):
      - ('plot', (1,)) if the shape holds exactly one value
      - ('image', shape) if shape is 2-dimensional with both dimensions
        different from the pixel count, or if it holds no pixels at all
      - ('image', get_best_reshape(shape)) in all other cases, including
        (1, N) and (N, 1) vectors
    """

    # an empty shape (zero-dimensional data) yields a pixel count of one
    pixel_count = 1
    for s in shape:
        pixel_count *= s

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already genuinely 2-dimensional; a (1, N) or (N, 1)
    # vector keeps all pixels in one dimension and is re-arranged instead
    if len(shape) == 2:
        is_vector = pixel_count > 0 and pixel_count in (shape[0], shape[1])
        if not is_vector:
            return 'image', shape

    return 'image', get_best_reshape(shape)
149
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """ renders recorded data as an animation with one subplot per key and saves it as mp4

    Parameters:
    -----------
    data : dict
        maps key -> dict with entries 'fs', 'dim' and 'data' (frames along the
        first axis), e.g. as returned by read_data
    filename : str
        output file; '.mp4' is appended if missing
    start_index : int
        first frame index, counted on the key with the most frames; raised to
        half_signal_window_length if smaller
    stop_index : int
        end of the frame range (exclusive); negative values count from the end
    interval : int
        passed to matplotlib's ArtistAnimation as delay between frames
    half_signal_window_length : int
        number of samples shown on each side of the current position in plots

    Raises:
    -------
    ValueError
        if data is empty or a key holds no frames
    """

    if not data:
        raise ValueError("no data given for animation")

    keys = sorted(data.keys())

    for key in keys:
        if data[key]['data'].shape[0] == 0:
            raise ValueError(f"no frames available for key {key}")

    # determine plot setup; at least one row, since fewer than two keys would
    # otherwise yield zero rows
    num_keys = len(keys)
    num_rows = max(1, int((num_keys * 3/4) ** .5))
    num_cols = (num_keys + num_rows - 1) // num_rows

    fs_max = max(val['fs'] for val in data.values())
    num_samples = max(val['data'].shape[0] for val in data.values())

    # squeeze=False keeps axs 2-dimensional for a single row or column as well
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)

    try:
        fig.set_size_inches(num_cols * 5, num_rows * 5)

        # inspect data
        display = dict()
        for i, key in enumerate(keys):
            axis = axs[i // num_cols, i % num_cols]
            axis.title.set_text(key)

            kind, shape = get_type_and_shape(data[key]['dim'])
            values = data[key]['data']

            display[key] = {
                'type'        : kind,
                'shape'       : shape,
                'down_factor' : data[key]['fs'] / fs_max,
                'axis'        : axis,
                # plots hold one value per frame, so they flatten to a 1d signal
                'values'      : values.reshape(-1) if kind == 'plot' else values
            }

        start_index = max(start_index, half_signal_window_length)
        while stop_index < 0:
            stop_index += num_samples

        stop_index = min(stop_index, num_samples - half_signal_window_length)

        # actual plotting
        frames = []
        for index in range(start_index, stop_index):
            ims = []
            for key in keys:
                entry = display[key]
                values = entry['values']
                axis = entry['axis']

                # position in the time base of this key, limited to its last frame
                feature_index = min(int(round(index * entry['down_factor'])), values.shape[0] - 1)

                if entry['type'] == 'plot':
                    first = max(feature_index - half_signal_window_length, 0)
                    window = values[first : feature_index + half_signal_window_length]
                    # marker sits on the current position inside the window
                    ims.append(axis.plot(window, marker='P', markevery=[feature_index - first], animated=True, color='blue')[0])

                elif entry['type'] == 'image':
                    ims.append(axis.imshow(values[feature_index].reshape(entry['shape']), animated=True))

            frames.append(ims)

        ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

        if not filename.endswith('.mp4'):
            filename += '.mp4'

        ani.save(filename)
    finally:
        # pyplot keeps figures alive until they are closed explicitly
        plt.close(fig)