# mypy: allow-untyped-defs
"""This module converts objects into numpy array."""

import numpy as np

import torch


def make_np(x):
    """
    Convert an object into numpy array.

    Args:
        x: A numpy array, a scalar (anything ``np.isscalar`` accepts, e.g. a
            Python ``int``/``float`` or a numpy scalar), or a torch tensor.

    Returns:
        numpy.ndarray: ``x`` itself if it is already an ndarray; a
        1-element array ``np.array([x])`` if ``x`` is a scalar; otherwise the
        tensor's data as an ndarray (detached from autograd and moved to CPU).

    Raises:
        NotImplementedError: If ``x`` is none of the supported types.
    """
    if isinstance(x, np.ndarray):
        return x
    if np.isscalar(x):
        return np.array([x])
    if isinstance(x, torch.Tensor):
        return _prepare_pytorch(x)
    raise NotImplementedError(
        f"Got {type(x)}, but numpy array or torch tensor are expected."
    )


def _prepare_pytorch(x):
    """Return the data of torch tensor ``x`` as a numpy array.

    The tensor is detached from autograd and moved to CPU. For a tensor that
    is already on CPU (and not bfloat16) the returned array may share memory
    with the tensor.
    """
    # Detach first so the dtype conversion below is not recorded by autograd.
    x = x.detach().cpu()
    if x.dtype == torch.bfloat16:
        # numpy has no bfloat16 dtype. Upcast to float32, which represents
        # every bfloat16 value exactly; float16 has a far narrower exponent
        # range (max ~6.5e4) and would turn large values into inf and flush
        # tiny ones to zero.
        x = x.to(torch.float32)
    return x.numpy()