xref: /aosp_15_r20/external/pytorch/torch/utils/tensorboard/_convert_np.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""This module converts objects into numpy array."""
3
4import numpy as np
5
6import torch
7
8
def make_np(x):
    """
    Return a numpy array representation of ``x``.

    Args:
      x: A numpy array, a scalar (anything ``np.isscalar`` accepts), or a
        torch tensor.

    Returns:
        numpy.array: ``x`` itself when it is already a numpy array, a
        one-element array wrapping ``x`` when it is a scalar, or the result
        of ``_prepare_pytorch(x)`` when it is a torch tensor.

    Raises:
        NotImplementedError: If ``x`` is none of the supported kinds.
    """
    # The checks run in this exact order: ndarray, then scalar, then tensor.
    if isinstance(x, np.ndarray):
        converted = x
    elif np.isscalar(x):
        converted = np.array([x])
    elif isinstance(x, torch.Tensor):
        converted = _prepare_pytorch(x)
    else:
        raise NotImplementedError(
            f"Got {type(x)}, but numpy array or torch tensor are expected."
        )
    return converted
28
29
def _prepare_pytorch(x):
    """
    Convert a torch tensor into a numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before ``.numpy()`` is called. ``torch.bfloat16`` tensors are upcast to
    ``torch.float32`` first: float32 represents every bfloat16 value exactly,
    whereas a float16 cast overflows large magnitudes to ``inf`` and flushes
    very small ones to zero.

    Args:
      x: An instance of torch tensor

    Returns:
        numpy.array: Numpy array holding the tensor's values (``float32``
        for ``bfloat16`` inputs).
    """
    x = x.detach()
    if x.dtype == torch.bfloat16:
        # Widen rather than narrow: float16 has a much smaller exponent
        # range than bfloat16, so converting to it can silently corrupt data.
        x = x.to(torch.float32)
    return x.cpu().numpy()
35