xref: /aosp_15_r20/external/pytorch/torch/utils/tensorboard/_embedding.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import math
3import numpy as np
4from ._convert_np import make_np
5from ._utils import make_grid
6from tensorboard.compat import tf
7from tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
8
9
# Evaluated once at import time: whether the ``tf.io.gfile`` in use exposes a
# module-level ``join``. The flavor without it only offers ``join`` on the
# filesystem object returned by ``get_filesystem``.
_HAS_GFILE_JOIN = hasattr(tf.io.gfile, "join")


def _gfile_join(a, b):
    """Join path components ``a`` and ``b`` through gfile.

    TensorBoard's TF stub and real TensorFlow expose different join APIs
    (https://github.com/tensorflow/tensorboard/issues/6080), and ``tf`` here
    may refer to either one, so dispatch on ``_HAS_GFILE_JOIN``.
    """
    if not _HAS_GFILE_JOIN:
        # No module-level join: ask the filesystem that owns ``a`` to join.
        return tf.io.gfile.get_filesystem(a).join(a, b)
    return tf.io.gfile.join(a, b)
22
23
def make_tsv(metadata, save_path, metadata_header=None):
    """Write embedding labels to ``metadata.tsv`` inside ``save_path``.

    Args:
        metadata: Iterable of labels, one per data point. Without a header,
            each item is written as ``str(item)`` on its own line. With a
            header, each item must be a sequence of column values.
        save_path: Directory in which ``metadata.tsv`` is created.
        metadata_header: Optional sequence of column names. When truthy, it
            is written as the first line, and every row must have exactly
            ``len(metadata_header)`` columns.

    Raises:
        AssertionError: if a header is given and any row's column count
            differs from the header's.
    """
    if not metadata_header:
        lines = [str(label) for label in metadata]
    else:
        # Materialize once so tuples, generators and arrays work, not just lists.
        rows = list(metadata)
        num_columns = len(metadata_header)
        # Validate every row (not only the first) so a ragged row cannot
        # produce a malformed TSV. Raised explicitly instead of via ``assert``
        # so the check survives ``python -O``. AssertionError is kept for
        # backward compatibility with the previous behavior.
        if any(len(row) != num_columns for row in rows):
            raise AssertionError(
                "len of header must be equal to the number of columns in metadata"
            )
        lines = [
            "\t".join(str(cell) for cell in row)
            for row in [metadata_header, *rows]
        ]

    metadata_bytes = tf.compat.as_bytes("\n".join(lines) + "\n")
    with tf.io.gfile.GFile(_gfile_join(save_path, "metadata.tsv"), "wb") as f:
        f.write(metadata_bytes)
36
37
# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared
def make_sprite(label_img, save_path):
    """Tile ``label_img`` into one sprite sheet and write ``sprite.png`` in ``save_path``.

    Args:
        label_img: Batch of images, one per data point. Its leading dimension
            is the image count. It is converted with ``make_np`` and tiled
            with ``make_grid``; presumably laid out N x C x H x W, with values
            in [0, 1] since they are scaled by 255 and clipped before the
            8-bit conversion (TODO confirm against callers).
        save_path: Directory in which ``sprite.png`` is created.
    """
    from PIL import Image
    from io import BytesIO

    label_np = make_np(label_img)

    # This ensures the sprite image has correct dimension as described in
    # https://www.tensorflow.org/get_started/embedding_viz : a grid with
    # ceil(sqrt(N)) thumbnails per row.
    nrow = int(math.ceil(label_np.shape[0] ** 0.5))
    grid_chw = make_grid(label_np, ncols=nrow)
    grid_hwc = grid_chw.transpose(1, 2, 0)  # chw -> hwc
    grid_h, grid_w = grid_hwc.shape[0], grid_hwc.shape[1]

    # Pad the bottom with zeros so the sheet is grid_w x grid_w (square), as
    # before. If the thumbnails are taller than wide, the grid can be taller
    # than grid_w. In that case grow the height instead of failing to
    # broadcast, and never change the width, so the column layout produced by
    # make_grid is preserved.
    canvas_h = max(grid_w, grid_h)
    canvas_hwc = np.zeros((canvas_h, grid_w, 3))
    canvas_hwc[:grid_h, :, :] = grid_hwc
    im = Image.fromarray(np.uint8((canvas_hwc * 255).clip(0, 255)))

    with BytesIO() as buf:
        im.save(buf, format="PNG")
        im_bytes = buf.getvalue()

    with tf.io.gfile.GFile(_gfile_join(save_path, "sprite.png"), "wb") as f:
        f.write(im_bytes)
62
63
def get_embedding_info(metadata, label_img, subdir, global_step, tag):
    """Build the projector ``EmbeddingInfo`` record for one embedding.

    The tensor is named ``"<tag>:<step>"``, with the step zero-padded to five
    digits, and points at ``tensors.tsv`` in ``subdir``. A ``metadata.tsv``
    path is recorded only when ``metadata`` is not None. A ``sprite.png``
    path plus the per-image dimensions are recorded only when ``label_img``
    is not None.
    """
    padded_step = str(global_step).zfill(5)

    info = EmbeddingInfo()
    info.tensor_name = f"{tag}:{padded_step}"
    info.tensor_path = _gfile_join(subdir, "tensors.tsv")

    if metadata is not None:
        info.metadata_path = _gfile_join(subdir, "metadata.tsv")

    if label_img is not None:
        sprite = info.sprite
        sprite.image_path = _gfile_join(subdir, "sprite.png")
        # (size of dim 3, size of dim 2): (W, H) if label_img is N x C x H x W
        # (assumed; TODO confirm).
        dim3 = label_img.size(3)
        dim2 = label_img.size(2)
        sprite.single_image_dim.extend([dim3, dim2])

    return info
74
75
def write_pbtxt(save_path, contents):
    """Write ``contents`` to ``projector_config.pbtxt`` inside ``save_path``.

    ``contents`` is converted with ``tf.compat.as_bytes`` and the file is
    opened in binary write mode, replacing any previous config.
    """
    target = _gfile_join(save_path, "projector_config.pbtxt")
    with tf.io.gfile.GFile(target, "wb") as handle:
        handle.write(tf.compat.as_bytes(contents))
80
81
def make_mat(matlist, save_path):
    """Write an embedding matrix to ``tensors.tsv`` inside ``save_path``.

    Each row of ``matlist`` becomes one line of tab-separated values. Every
    element is rendered as ``str(element.item())``, so elements must provide
    ``.item()`` (as NumPy scalars and 0-d tensors do).
    """
    out_path = _gfile_join(save_path, "tensors.tsv")
    with tf.io.gfile.GFile(out_path, "wb") as handle:
        for row in matlist:
            line = "\t".join(str(value.item()) for value in row) + "\n"
            handle.write(tf.compat.as_bytes(line))
87