# mypy: allow-untyped-defs
import math
import numpy as np
from ._convert_np import make_np
from ._utils import make_grid
from tensorboard.compat import tf
from tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo


_HAS_GFILE_JOIN = hasattr(tf.io.gfile, "join")


def _gfile_join(a, b):
    """Join path components ``a`` and ``b`` with whichever gfile API is available.

    Returns:
        The joined path as produced by ``tf.io.gfile.join`` when it exists,
        otherwise by the filesystem object returned for ``a``.
    """
    # The join API is different between tensorboard's TF stub and TF:
    # https://github.com/tensorflow/tensorboard/issues/6080
    # We need to try both because `tf` may point to either the stub or the real TF.
    if _HAS_GFILE_JOIN:
        return tf.io.gfile.join(a, b)
    fs = tf.io.gfile.get_filesystem(a)
    return fs.join(a, b)


def make_tsv(metadata, save_path, metadata_header=None):
    """Write embedding metadata to ``<save_path>/metadata.tsv``.

    Args:
        metadata: Iterable of labels. Without a header, each item is written on
            its own line via ``str()``. With a header, each item is a row whose
            elements are written tab-separated.
        save_path: Directory the file is written into.
        metadata_header: Optional sequence of column names. When truthy it is
            written as the first tab-separated line and must have as many
            entries as the first metadata row.

    Raises:
        AssertionError: If the header length differs from the number of
            columns in the first metadata row.
    """
    if not metadata_header:
        lines = [str(x) for x in metadata]
    else:
        # Materialise so tuples, generators, etc. work, not just lists.
        rows = list(metadata)
        # With no rows there is nothing to check; just the header is written.
        assert not rows or len(metadata_header) == len(
            rows[0]
        ), "len of header must be equal to the number of columns in metadata"
        lines = ["\t".join(str(e) for e in row) for row in [metadata_header, *rows]]

    metadata_bytes = tf.compat.as_bytes("\n".join(lines) + "\n")
    with tf.io.gfile.GFile(_gfile_join(save_path, "metadata.tsv"), "wb") as f:
        f.write(metadata_bytes)


# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared
def make_sprite(label_img, save_path):
    """Tile ``label_img`` into one sprite image and write ``<save_path>/sprite.png``.

    Args:
        label_img: Batch of label images; ``size(0)`` is the image count. It is
            converted with ``make_np`` and tiled by ``make_grid`` (presumably
            NCHW -- confirm against ``make_grid``).
        save_path: Directory the PNG is written into.
    """
    from PIL import Image
    from io import BytesIO

    # this ensures the sprite image has correct dimension as described in
    # https://www.tensorflow.org/get_started/embedding_viz
    nrow = int(math.ceil((label_img.size(0)) ** 0.5))
    arranged_img_CHW = make_grid(make_np(label_img), ncols=nrow)
    arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0)  # chw -> hwc
    grid_h, grid_w, grid_c = arranged_img_HWC.shape

    # Augment images so that #images equals nrow*nrow by padding the bottom with
    # zero rows up to the grid width. The canvas is never shorter than the grid
    # (tall label images previously overflowed a width-sized square). It also
    # keeps at least 3 channels: a 1-channel grid still broadcasts to RGB as
    # before, and a wider grid (e.g. RGBA) no longer hits a hard-coded 3.
    canvas_HWC = np.zeros((max(grid_h, grid_w), grid_w, max(3, grid_c)))
    canvas_HWC[:grid_h, :, :] = arranged_img_HWC
    im = Image.fromarray(np.uint8((canvas_HWC * 255).clip(0, 255)))

    with BytesIO() as buf:
        im.save(buf, format="PNG")
        im_bytes = buf.getvalue()

    with tf.io.gfile.GFile(_gfile_join(save_path, "sprite.png"), "wb") as f:
        f.write(im_bytes)


def get_embedding_info(metadata, label_img, subdir, global_step, tag):
    """Build the projector ``EmbeddingInfo`` entry for one embedding.

    Args:
        metadata: When not ``None``, ``metadata_path`` is set to
            ``<subdir>/metadata.tsv``.
        label_img: When not ``None``, the sprite path is set to
            ``<subdir>/sprite.png`` and ``single_image_dim`` to
            ``[size(3), size(2)]`` (width, height if NCHW -- assumed).
        subdir: Directory, relative to the projector config, holding the files.
        global_step: Step number, zero-padded to 5 digits in the tensor name.
        tag: Embedding tag used as the tensor-name prefix.

    Returns:
        A populated ``EmbeddingInfo`` message.
    """
    info = EmbeddingInfo()
    info.tensor_name = f"{tag}:{str(global_step).zfill(5)}"
    info.tensor_path = _gfile_join(subdir, "tensors.tsv")
    if metadata is not None:
        info.metadata_path = _gfile_join(subdir, "metadata.tsv")
    if label_img is not None:
        info.sprite.image_path = _gfile_join(subdir, "sprite.png")
        info.sprite.single_image_dim.extend([label_img.size(3), label_img.size(2)])
    return info


def write_pbtxt(save_path, contents):
    """Write ``contents`` to ``<save_path>/projector_config.pbtxt`` as bytes."""
    config_path = _gfile_join(save_path, "projector_config.pbtxt")
    with tf.io.gfile.GFile(config_path, "wb") as f:
        f.write(tf.compat.as_bytes(contents))


def make_mat(matlist, save_path):
    """Write embedding vectors to ``<save_path>/tensors.tsv``.

    Each element of ``matlist`` becomes one tab-separated line. Each value is
    rendered via ``str(value.item())``, so values must support ``.item()``.
    Rows are written one at a time, so the whole matrix is never held as text.
    """
    with tf.io.gfile.GFile(_gfile_join(save_path, "tensors.tsv"), "wb") as f:
        for row in matlist:
            line = "\t".join(str(value.item()) for value in row)
            f.write(tf.compat.as_bytes(line + "\n"))