1# mypy: allow-untyped-defs 2import numpy as np 3 4 5# Functions for converting 6def figure_to_image(figures, close=True): 7 """Render matplotlib figure to numpy format. 8 9 Note that this requires the ``matplotlib`` package. 10 11 Args: 12 figures (matplotlib.pyplot.figure or list of figures): figure or a list of figures 13 close (bool): Flag to automatically close the figure 14 15 Returns: 16 numpy.array: image in [CHW] order 17 """ 18 import matplotlib.pyplot as plt 19 import matplotlib.backends.backend_agg as plt_backend_agg 20 21 def render_to_rgb(figure): 22 canvas = plt_backend_agg.FigureCanvasAgg(figure) 23 canvas.draw() 24 data: np.ndarray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) 25 w, h = figure.canvas.get_width_height() 26 image_hwc = data.reshape([h, w, 4])[:, :, 0:3] 27 image_chw = np.moveaxis(image_hwc, source=2, destination=0) 28 if close: 29 plt.close(figure) 30 return image_chw 31 32 if isinstance(figures, list): 33 images = [render_to_rgb(figure) for figure in figures] 34 return np.stack(images) 35 else: 36 image = render_to_rgb(figures) 37 return image 38 39 40def _prepare_video(V): 41 """ 42 Convert a 5D tensor into 4D tensor. 43 44 Convesrion is done from [batchsize, time(frame), channel(color), height, width] (5D tensor) 45 to [time(frame), new_width, new_height, channel] (4D tensor). 46 47 A batch of images are spreaded to a grid, which forms a frame. 48 e.g. Video with batchsize 16 will have a 4x4 grid. 49 """ 50 b, t, c, h, w = V.shape 51 52 if V.dtype == np.uint8: 53 V = np.float32(V) / 255.0 54 55 def is_power2(num): 56 return num != 0 and ((num & (num - 1)) == 0) 57 58 # pad to nearest power of 2, all at once 59 if not is_power2(V.shape[0]): 60 len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0]) 61 V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0) 62 63 n_rows = 2 ** ((b.bit_length() - 1) // 2) 64 n_cols = V.shape[0] // n_rows 65 66 V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w)) 67 V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3)) 68 V = np.reshape(V, newshape=(t, n_rows * h, n_cols * w, c)) 69 70 return V 71 72 73def make_grid(I, ncols=8): 74 # I: N1HW or N3HW 75 assert isinstance(I, np.ndarray), "plugin error, should pass numpy array here" 76 if I.shape[1] == 1: 77 I = np.concatenate([I, I, I], 1) 78 assert I.ndim == 4 and I.shape[1] == 3 79 nimg = I.shape[0] 80 H = I.shape[2] 81 W = I.shape[3] 82 ncols = min(nimg, ncols) 83 nrows = int(np.ceil(float(nimg) / ncols)) 84 canvas = np.zeros((3, H * nrows, W * ncols), dtype=I.dtype) 85 i = 0 86 for y in range(nrows): 87 for x in range(ncols): 88 if i >= nimg: 89 break 90 canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = I[i] 91 i = i + 1 92 return canvas 93 94 # if modality == 'IMG': 95 # if x.dtype == np.uint8: 96 # x = x.astype(np.float32) / 255.0 97 98 99def convert_to_HWC(tensor, input_format): # tensor: numpy array 100 assert len(set(input_format)) == len( 101 input_format 102 ), f"You can not use the same dimension shordhand twice. input_format: {input_format}" 103 assert len(tensor.shape) == len( 104 input_format 105 ), f"size of input tensor and input format are different. \ 106 tensor shape: {tensor.shape}, input_format: {input_format}" 107 input_format = input_format.upper() 108 109 if len(input_format) == 4: 110 index = [input_format.find(c) for c in "NCHW"] 111 tensor_NCHW = tensor.transpose(index) 112 tensor_CHW = make_grid(tensor_NCHW) 113 return tensor_CHW.transpose(1, 2, 0) 114 115 if len(input_format) == 3: 116 index = [input_format.find(c) for c in "HWC"] 117 tensor_HWC = tensor.transpose(index) 118 if tensor_HWC.shape[2] == 1: 119 tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2) 120 return tensor_HWC 121 122 if len(input_format) == 2: 123 index = [input_format.find(c) for c in "HW"] 124 tensor = tensor.transpose(index) 125 tensor = np.stack([tensor, tensor, tensor], 2) 126 return tensor 127