""" module for inspecting models during inference """

import os

import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import torch
import numpy as np

# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
_folder = 'endoscopy'


def get_gru_gates(gru, input, state):
    """ computes the gate activations of a single-layer torch.nn.GRU for one time step

    Parameters:
        gru (torch.nn.GRU): the layer; only the layer-0 weights (and biases, if
            the layer has them) are used
        input (torch.Tensor): input for a single time step; it is flattened, so
            e.g. (1, 1, input_size) and (input_size,) are both accepted
        state (torch.Tensor): hidden state before the step, flattened the same way

    Returns:
        dict with tensors 'reset_gate', 'update_gate' and 'new_gate', each of
        length gru.hidden_size
    """
    # reshape(-1) instead of squeeze(): squeeze() turns a size-1 feature axis
    # (input_size == 1 or hidden_size == 1) into a 0-d tensor and breaks matmul
    direct = torch.matmul(gru.weight_ih_l0, input.reshape(-1))
    recurrent = torch.matmul(gru.weight_hh_l0, state.reshape(-1))

    # layers created with bias=False do not have the bias attributes
    bias_ih = getattr(gru, 'bias_ih_l0', None)
    bias_hh = getattr(gru, 'bias_hh_l0', None)
    if bias_ih is not None:
        direct = direct + bias_ih
    if bias_hh is not None:
        recurrent = recurrent + bias_hh

    # the stacked weights are laid out as reset | update | new, hidden_size rows each
    direct_r, direct_z, direct_n = direct.chunk(3)
    recurrent_r, recurrent_z, recurrent_n = recurrent.chunk(3)

    reset_gate = torch.sigmoid(direct_r + recurrent_r)
    update_gate = torch.sigmoid(direct_z + recurrent_z)
    # the reset gate scales the full recurrent term, including its bias
    new_gate = torch.tanh(direct_n + reset_gate * recurrent_n)

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}


def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data

    Subsequent calls to write_data store their files in folder. A warning is
    printed if folder already exists, because write_data truncates files of
    keys that are written again.
    """
    global _folder
    _folder = folder

    if not os.path.exists(folder):
        os.makedirs(folder)
    else:
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")


def write_data(key, data, fs):
    """ appends data to previous data written under key

    The raw bytes of every call are appended to <folder>/<key>.bin. On the
    first call for key the metadata (fs, dim, dtype) is written to
    <folder>/<key>.yml. All later calls must match it.

    Parameters:
        key (str): stream name, also used as file name stem
        data (torch.Tensor or array-like): one frame of data
        fs: rate of the stream; make_animation uses it to align streams

    Raises:
        ValueError: if fs, dtype or shape differ from the first call for key
    """
    # convert to numpy; .cpu() makes tensors living on an accelerator work too
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    else:
        data = np.asarray(data)

    if key not in _state:
        # write_data may be used without a preceding init()
        os.makedirs(_folder, exist_ok=True)

        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : tuple(data.shape),
            'dtype' : str(data.dtype)
        }

        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : tuple(data.shape), 'dtype' : str(data.dtype).split('.')[-1]}))
    else:
        entry = _state[key]
        if entry['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if entry['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if entry['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    _state[key]['fid'].write(data.tobytes())


def close(folder='endoscopy'):
    """ clean up: closes all open data files and forgets their state

    Afterwards write_data starts new files again, e.g. after init() with a
    different folder. The folder argument is unused and only kept for
    compatibility.
    """
    for entry in _state.values():
        entry['fid'].close()
    # drop the closed handles so later write_data calls do not write to them
    _state.clear()


def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

    Returns:
        dict mapping every key with a .yml file in folder to its metadata
        ('fs', 'dim', 'dtype'), extended by 'data', an array of shape
        (num_frames,) + dim
    """
    keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]

    return_dict = dict()

    for key in keys:
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            # FullLoader is needed for the !!python/tuple tag yaml.dump emits for
            # 'dim'; it is not meant for folders from untrusted sources
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            data = np.frombuffer(f.read(), dtype=value['dtype'])

        # tuple() also accepts a dim stored as a plain yaml list
        value['dim'] = tuple(value['dim'])
        value['data'] = data.reshape((-1,) + value['dim'])

        return_dict[key] = value

    return return_dict


def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)

    Returns (1,) if shape holds a single element. Otherwise returns
    (num_rows, num_columns), where num_columns is the largest divisor of the
    element count that does not exceed sqrt(element_count / target_ratio).
    """
    pixel_count = int(np.prod(shape))

    if pixel_count == 1:
        return (1,)

    # ideal column count for rows / cols == target_ratio; at least 1 so the
    # divisor search below never takes a modulo by zero
    num_columns = max(1, int((pixel_count / target_ratio) ** .5))

    while pixel_count % num_columns:
        num_columns -= 1

    num_rows = pixel_count // num_columns

    return (num_rows, num_columns)


def get_type_and_shape(shape):
    """ decides how frames with the given shape are displayed

    Returns:
        ('plot', (1,)) for single-element frames. Otherwise ('image', shape2d):
        shape2d is shape itself if it is 2-d and no axis holds all elements,
        and a reshape from get_best_reshape if not.
    """
    # np.prod(()) == 1, so zero-dimensional data is treated as a single element
    pixel_count = int(np.prod(shape))

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already 2-dimensional; (1, n) and (n, 1) are really
    # 1-d and get a more square layout instead
    if len(shape) == 2 and shape[0] != pixel_count and shape[1] != pixel_count:
        return 'image', tuple(shape)

    return 'image', get_best_reshape(shape)


def _subplot_grid(num_keys):
    """ returns (num_rows, num_cols) of a roughly 4:3 (cols:rows) grid with at least num_keys cells """
    # max(1, ...) avoids a zero-row grid (and a division by zero) for a single key
    num_rows = max(1, int((num_keys * 3 / 4) ** .5))
    num_cols = (num_keys + num_rows - 1) // num_rows
    return num_rows, num_cols


def _signal_window(signal, center, half_length):
    """ returns signal[center - half_length : center + half_length], NaN-padded where it runs past either end """
    start, stop = center - half_length, center + half_length
    if start >= 0 and stop <= signal.shape[0]:
        return signal[start : stop]

    window = np.full((2 * half_length,) + signal.shape[1:], np.nan)
    lo, hi = max(start, 0), min(stop, signal.shape[0])
    if hi > lo:
        window[lo - start : hi - start] = signal[lo : hi]
    return window


def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """ renders streams as returned by read_data into an mp4 animation

    Every key gets its own subplot. Single-element streams are drawn as a
    sliding signal window with a marker at the current position. All other
    streams are drawn as an image of the current frame. Frames advance at the
    highest fs among the keys, and streams with a lower fs are indexed
    proportionally.

    Parameters:
        data (dict): as returned by read_data
        filename (str): output path; '.mp4' is appended if missing
        start_index, stop_index (int): frame range in units of the highest fs.
            A negative stop_index counts from the end. The range is kept at
            least half_signal_window_length away from both ends of the longest
            stream.
        interval (int): delay between frames, passed to ArtistAnimation
        half_signal_window_length (int): half width of the signal windows

    Raises:
        ValueError: if data is empty or the resulting frame range is empty
    """
    if not data:
        raise ValueError("no data to animate")

    keys = sorted(data.keys())
    num_keys = len(keys)

    # determine plot setup; squeeze=False keeps axs 2-d for one row or column
    num_rows, num_cols = _subplot_grid(num_keys)
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 5, num_rows * 5)
    # row-major flattening: axes[i] is axs[i // num_cols, i % num_cols]
    axes = axs.flatten()

    fs_max = max(val['fs'] for val in data.values())
    num_samples = max(val['data'].shape[0] for val in data.values())

    # inspect data
    display = dict()
    for i, key in enumerate(keys):
        axes[i].title.set_text(key)

        display[key] = dict()
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    # hide grid cells that have no key
    for ax in axes[num_keys:]:
        ax.axis('off')

    start_index = max(start_index, half_signal_window_length)
    # guard: with no samples the loop below would never terminate
    if num_samples > 0:
        while stop_index < 0:
            stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    if start_index >= stop_index:
        plt.close(fig)
        raise ValueError(f"empty frame range [{start_index}, {stop_index}) for {num_samples} samples")

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            signal = data[key]['data']
            if signal.shape[0] == 0:
                continue

            # position of the current frame in this stream's own rate
            feature_index = int(round(index * display[key]['down_factor']))

            if display[key]['type'] == 'plot':
                window = _signal_window(signal, feature_index, half_signal_window_length)
                ims.append(axes[i].plot(window, marker='P', markevery=[half_signal_window_length], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                # streams shorter than the frame range keep showing their last frame
                frame = signal[min(feature_index, signal.shape[0] - 1)]
                ims.append(axes[i].imshow(frame.reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    try:
        ani.save(filename)
    finally:
        # release the figure so repeated calls do not accumulate open figures
        plt.close(fig)