xref: /aosp_15_r20/external/pytorch/torch/utils/tensorboard/summary.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import json
3import logging
4import os
5import struct
6
7from typing import Any, List, Optional
8
9import torch
10import numpy as np
11
12from google.protobuf import struct_pb2
13
14from tensorboard.compat.proto.summary_pb2 import (
15    HistogramProto,
16    Summary,
17    SummaryMetadata,
18)
19from tensorboard.compat.proto.tensor_pb2 import TensorProto
20from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
21from tensorboard.plugins.custom_scalar import layout_pb2
22from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
23from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
24
25from ._convert_np import make_np
26from ._utils import _prepare_video, convert_to_HWC
27
# Public names of this module, i.e. what a star-import exposes. Helpers whose
# names start with an underscore are private and intentionally not listed.
__all__ = [
    "half_to_int",
    "int_to_half",
    "hparams",
    "scalar",
    "histogram_raw",
    "histogram",
    "make_histogram",
    "image",
    "image_boxes",
    "draw_boxes",
    "make_image",
    "video",
    "make_video",
    "audio",
    "custom_scalars",
    "text",
    "tensor_proto",
    "pr_curve_raw",
    "pr_curve",
    "compute_curve",
    "mesh",
]

# Module-level logger; used in this file for non-fatal warnings (invalid
# hparams arguments, failure to delete a temporary GIF file).
logger = logging.getLogger(__name__)
53
def half_to_int(f: float) -> int:
    """Encode a float as an integer for the ``half_val`` field of a TensorProto.

    Intended for values taken from ``torch.half`` / ``torch.bfloat16`` tensors.
    The value is packed as a native 32-bit C ``float``. Those four bytes are then
    reinterpreted as a native signed 32-bit ``int``. Note that the result is
    therefore the float32 bit pattern of ``f``, not a 16-bit half pattern.

    ``int_to_half`` performs the inverse transformation.
    """
    (bits,) = struct.unpack("i", struct.pack("f", f))
    return bits
66
def int_to_half(i: int) -> float:
    """Decode an integer produced by ``half_to_int`` back into a float.

    The integer is packed as a native signed 32-bit ``int``. Those four bytes
    are then reinterpreted as a native 32-bit C ``float``.
    """
    (value,) = struct.unpack("f", struct.pack("i", i))
    return value
76
def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
    # Flatten in row-major order and encode every element with half_to_int.
    return list(map(half_to_int, t.flatten().tolist()))
79
def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
    # view_as_real exposes every complex element as a trailing (real, imag)
    # pair, so flattening yields [re0, im0, re1, im1, ...].
    interleaved = torch.view_as_real(t)
    return interleaved.reshape(-1).tolist()
82
def _tensor_to_list(t: torch.Tensor) -> List[Any]:
    # Row-major flattening into plain Python scalars.
    return t.reshape(-1).tolist()
85
# type maps: torch dtype -> (protobuf dtype name, TensorProto value field,
# function converting the tensor into that field's flat list of values).
#
# TensorProto stores DT_INT8 / DT_UINT8 / DT_INT16 / DT_INT32 values in
# `int_val`; `uint32_val` is reserved for DT_UINT32. Several torch dtypes are
# aliases of one another (e.g. torch.half is torch.float16), so those entries
# collapse onto the same key. Every distinct dtype must appear only once:
# a repeated key silently overrides the earlier entry.
_TENSOR_TYPE_MAP = {
    torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
    torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
    # qint8 holds signed 8-bit values.
    torch.qint8: ("DT_INT8", "int_val", _tensor_to_list),
    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
    torch.short: ("DT_INT16", "int_val", _tensor_to_list),
    torch.int: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    # Unsigned quantized types are labelled DT_UINT8, whose values live in int_val.
    torch.quint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.quint4x2: ("DT_UINT8", "int_val", _tensor_to_list),
}
115
116
def _calc_scale_factor(tensor):
    """Return the multiplier that brings ``tensor``'s values into [0, 255].

    ``uint8`` data is taken to already use the 0-255 range (factor 1). Any
    other dtype is taken to be in [0, 1] (factor 255). Inputs that are not
    numpy arrays are converted with ``.numpy()`` first.
    """
    array = tensor if isinstance(tensor, np.ndarray) else tensor.numpy()
    if array.dtype == np.uint8:
        return 1
    return 255
120
121
def _draw_single_box(
    image,
    xmin,
    ymin,
    xmax,
    ymax,
    display_str,
    color="black",
    color_text="black",
    thickness=2,
):
    """Draw one rectangle, plus an optional label, onto a PIL image.

    Args:
      image: A ``PIL.Image.Image``. It is drawn on in place and also returned.
      xmin, ymin, xmax, ymax: Box corners in pixel coordinates.
      display_str: Optional label. When truthy, it is written on a filled
        background anchored at the box's bottom-left corner.
      color: Colour of the box outline and of the label background.
      color_text: Colour of the label text.
      thickness: Width of the outline in pixels.

    Returns:
      The same ``image`` object.
    """
    from PIL import ImageDraw, ImageFont

    font = ImageFont.load_default()
    draw = ImageDraw.Draw(image)
    left, right, top, bottom = xmin, xmax, ymin, ymax
    draw.line(
        [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
        width=thickness,
        fill=color,
    )
    if display_str:
        text_bottom = bottom
        bbox_left, bbox_top, bbox_right, bbox_bottom = font.getbbox(display_str)
        text_width = bbox_right - bbox_left
        text_height = bbox_bottom - bbox_top
        margin = np.ceil(0.05 * text_height)
        # The text starts `margin` pixels right of `left`. The background
        # therefore needs a margin on both sides to cover the whole label;
        # with only `text_width` the text would overflow to the right.
        draw.rectangle(
            [
                (left, text_bottom - text_height - 2 * margin),
                (left + text_width + 2 * margin, text_bottom),
            ],
            fill=color,
        )
        draw.text(
            (left + margin, text_bottom - text_height - margin),
            display_str,
            fill=color_text,
            font=font,
        )
    return image
163
164
def _hparam_domain(hparam_domain_discrete, key, value_field):
    """Return the ``domain_discrete`` ListValue declared for one hyperparameter.

    Args:
      hparam_domain_discrete: Mapping from hyperparameter name to the list of
        discrete values it may take.
      key: Hyperparameter name to look up.
      value_field: Name of the ``struct_pb2.Value`` field each entry is stored
        in (``"number_value"``, ``"string_value"`` or ``"bool_value"``).

    Returns:
      ``None`` when ``key`` has no declared domain. Otherwise, a
      ``struct_pb2.ListValue`` with one ``Value`` per declared entry.
    """
    if key not in hparam_domain_discrete:
        return None
    return struct_pb2.ListValue(
        values=[
            struct_pb2.Value(**{value_field: d}) for d in hparam_domain_discrete[key]
        ]
    )


def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Output three `Summary` protocol buffers needed by hparams plugin.

    `Experiment` keeps the metadata of an experiment, such as the name of the
      hyperparameters and the name of the metrics.
    `SessionStartInfo` keeps key-value pairs of the hyperparameters
    `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS

    Each hyperparameter value is recorded according to its type:
    - ``bool`` values as DATA_TYPE_BOOL.
    - ``int``/``float`` values as DATA_TYPE_FLOAT64.
    - ``str`` values as DATA_TYPE_STRING.
    - ``torch.Tensor`` values as the number in their first element
      (DATA_TYPE_FLOAT64).
    Entries whose value is ``None`` are skipped.

    Args:
      hparam_dict: A dictionary that contains names of the hyperparameters
        and their values.
      metric_dict: A dictionary that contains names of the metrics
        and their values.
      hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
        contains names of the hyperparameters and all discrete values they can hold

    Returns:
      The `Summary` protobufs for Experiment, SessionStartInfo and
        SessionEndInfo

    Raises:
      TypeError: If ``hparam_dict``, ``metric_dict`` or
        ``hparam_domain_discrete`` is not a dict, or if a discrete domain is not
        a list of values of the same type as its hyperparameter.
      ValueError: If a hyperparameter value has an unsupported type.
    """
    from tensorboard.plugins.hparams.api_pb2 import (
        DataType,
        Experiment,
        HParamInfo,
        MetricInfo,
        MetricName,
        Status,
    )
    from tensorboard.plugins.hparams.metadata import (
        EXPERIMENT_TAG,
        PLUGIN_DATA_VERSION,
        PLUGIN_NAME,
        SESSION_END_INFO_TAG,
        SESSION_START_INFO_TAG,
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData,
        SessionEndInfo,
        SessionStartInfo,
    )

    # TODO: expose other parameters in the future.
    # hp = HParamInfo(name='lr',display_name='learning rate',
    # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
    # max_value=100))
    # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
    # description='', dataset_type=DatasetType.DATASET_VALIDATION)
    # exp = Experiment(name='123', description='456', time_created_secs=100.0,
    # hparam_infos=[hp], metric_infos=[mt], user='tw')

    if not isinstance(hparam_dict, dict):
        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: hparam_dict should be a dictionary, nothing logged."
        )
    if not isinstance(metric_dict, dict):
        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: metric_dict should be a dictionary, nothing logged."
        )

    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
            )

    hps = []
    ssi = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue
        if isinstance(v, torch.Tensor):
            # Tensors contribute the number in their first element. Using
            # reshape(-1) also accepts 0-d tensors, which plain [0] indexing
            # of a 0-d array rejects.
            ssi.hparams[k].number_value = make_np(v).reshape(-1)[0]
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue
        # ``bool`` must be tested before ``int``. bool is a subclass of int,
        # so checking (int, float) first would record booleans as numbers.
        if isinstance(v, bool):
            ssi.hparams[k].bool_value = v
            value_field, data_type = "bool_value", "DATA_TYPE_BOOL"
        elif isinstance(v, (int, float)):
            ssi.hparams[k].number_value = v
            value_field, data_type = "number_value", "DATA_TYPE_FLOAT64"
        elif isinstance(v, str):
            ssi.hparams[k].string_value = v
            value_field, data_type = "string_value", "DATA_TYPE_STRING"
        else:
            raise ValueError(
                "value should be one of int, float, str, bool, or torch.Tensor"
            )
        hps.append(
            HParamInfo(
                name=k,
                type=DataType.Value(data_type),
                domain_discrete=_hparam_domain(hparam_domain_discrete, k, value_field),
            )
        )

    def _plugin_summary(tag, **plugin_content):
        # Wrap one HParamsPluginData payload into a Summary under `tag`.
        content = HParamsPluginData(version=PLUGIN_DATA_VERSION, **plugin_content)
        smd = SummaryMetadata(
            plugin_data=SummaryMetadata.PluginData(
                plugin_name=PLUGIN_NAME, content=content.SerializeToString()
            )
        )
        return Summary(value=[Summary.Value(tag=tag, metadata=smd)])

    ssi = _plugin_summary(SESSION_START_INFO_TAG, session_start_info=ssi)

    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
    exp = _plugin_summary(
        EXPERIMENT_TAG, experiment=Experiment(hparam_infos=hps, metric_infos=mts)
    )

    sei = _plugin_summary(
        SESSION_END_INFO_TAG,
        session_end_info=SessionEndInfo(status=Status.Value("STATUS_SUCCESS")),
    )

    return exp, ssi, sei
352
353
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Output a `Summary` protocol buffer containing a single scalar value.

    Args:
      name: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: A real numeric tensor/array holding exactly one value (size-1
        dimensions are squeezed away before the check).
      collections: Unused; kept for API compatibility.
      new_style: If True, store the value in the tensor field for the
        "scalars" plugin. This may load faster. If False, use the legacy
        ``simple_value`` field.
      double_precision: With ``new_style``, store the value as DT_DOUBLE
        instead of DT_FLOAT.

    Returns:
      A `Summary` protobuf with one value tagged ``name``.

    Raises:
      AssertionError: If the squeezed tensor is not zero-dimensional.
    """
    tensor = make_np(tensor).squeeze()
    assert (
        tensor.ndim == 0
    ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
    # A Python float is a C double, so this conversion keeps full precision.
    value = float(tensor)

    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=value)])

    if double_precision:
        payload = TensorProto(double_val=[value], dtype="DT_DOUBLE")
    else:
        payload = TensorProto(float_val=[value], dtype="DT_FLOAT")
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="scalars")
    )
    entry = Summary.Value(tag=name, tensor=payload, metadata=metadata)
    return Summary(value=[entry])
395
396
def tensor_proto(tag, tensor):
    """Output a `Summary` protocol buffer containing the full tensor.

    Every element of ``tensor`` and its shape are written into a TensorProto.
    The TensorProto is attached to a Summary value tagged for the "tensor"
    plugin.

    Args:
      tag: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: ``torch.Tensor`` to be converted to protobuf.

    Returns:
      A tensor protobuf in a `Summary` protobuf.

    Raises:
      ValueError: If the tensor's raw data is 2GB or larger (too big for a
        protocol buffer), or if its dtype is not supported.
    """
    # Protocol buffers cannot exceed 2GB, so refuse payloads that would.
    if tensor.numel() * tensor.itemsize >= (1 << 31):
        raise ValueError(
            "tensor is bigger than protocol buffer's hard limit of 2GB in size"
        )
    if tensor.dtype not in _TENSOR_TYPE_MAP:
        raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")

    dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
    shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape])
    proto_fields = {
        "dtype": dtype,
        "tensor_shape": shape,
        field_name: conversion_fn(tensor),
    }
    payload = TensorProto(**proto_fields)

    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="tensor")
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=payload)])
432
433
def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a precomputed histogram.

    No statistics are computed here. Every argument is copied as-is into a
    ``HistogramProto``.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      min: A float or int min value
      max: A float or int max value
      num: Int number of values
      sum: Float or int sum of all values
      sum_squares: Float or int sum of squares for all values
      bucket_limits: Numeric sequence with the upper limit of each bucket.
      bucket_counts: Numeric sequence with the number of values per bucket.

    Returns:
      A `Summary` protobuf holding one histogram value tagged ``name``.
    """
    histo = HistogramProto(
        bucket_limit=bucket_limits,
        bucket=bucket_counts,
        num=num,
        min=min,
        max=max,
        sum=sum,
        sum_squares=sum_squares,
    )
    summary_value = Summary.Value(tag=name, histo=histo)
    return Summary(value=[summary_value])
465
466
def histogram(name, values, bins, max_bins=None):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram of ``values``.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      values: A real numeric tensor or array of any shape. It is converted to
        a float array before binning; values should be finite.
      bins: Binning specification forwarded to ``make_histogram`` (and from
        there to ``numpy.histogram``).
      max_bins: Optional cap that ``make_histogram`` uses to merge adjacent
        buckets.

    Returns:
      A `Summary` protobuf holding one histogram value tagged ``name``.
    """
    as_float = make_np(values).astype(float)
    histo = make_histogram(as_float, bins, max_bins)
    return Summary(value=[Summary.Value(tag=name, histo=histo)])
487
488
def make_histogram(values, bins, max_bins=None):
    """Convert values into a histogram proto using logic from histogram.cc.

    Args:
      values: Non-empty numpy array of any shape; it is flattened first.
      bins: Forwarded unchanged as ``numpy.histogram``'s ``bins`` argument.
      max_bins: Optional maximum number of buckets. If ``numpy.histogram``
        produces more, runs of adjacent buckets are merged so that at most
        ``max_bins`` buckets remain.

    Returns:
      A ``HistogramProto``. Its buckets span only the range that actually
      contains data, plus one leading empty bucket whose limit is the left
      edge of that range (TensorBoard stores right-hand limits only).

    Raises:
      ValueError: If ``values`` is empty, or if the trimmed histogram ends up
        empty (which would indicate a bug).
    """
    if values.size == 0:
        raise ValueError("The input has no element.")
    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    num_bins = len(counts)
    if max_bins is not None and num_bins > max_bins:
        # Merge groups of `subsampling` adjacent buckets. Ceiling division keeps
        # the group count, ceil(num_bins / subsampling), at or below max_bins.
        # Floor division could leave too many; e.g. 9 bins with max_bins=5
        # gives 9 // 5 == 1, i.e. no merging at all.
        subsampling = -(-num_bins // max_bins)
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            # Pad with empty buckets so the counts split into whole groups.
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],
                mode="constant",
                constant_values=0,
            )
        counts = counts.reshape(-1, subsampling).sum(axis=-1)
        # Each merged bucket keeps the left edge of its first original bucket;
        # the overall right edge is kept as the final limit.
        new_limits = np.empty((counts.size + 1,), limits.dtype)
        new_limits[:-1] = limits[:-1:subsampling]
        new_limits[-1] = limits[-1]
        limits = new_limits

    # Find the first and the last bin defining the support of the histogram.
    # cum_counts[i] counts the non-empty bins among 0..i, so searching it gives
    # the index of the first and of the last non-empty bin.
    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    end = int(end) + 1
    del cum_counts

    # TensorBoard only includes the right bin limits. To still have the leftmost
    # limit included, we include an empty bin on the left.
    # If start == 0, we need to add an empty one on the left. Otherwise we can
    # just include the (empty) bin to the left of the first nonzero-count bin.
    counts = (
        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    )
    limits = limits[start : end + 1]

    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")

    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )
542
543
def image(tag, tensor, rescale=1, dataformats="NCHW"):
    """Output a `Summary` protocol buffer with a PNG-encoded image.

    ``tensor`` is rearranged to (height, width, channels) according to
    ``dataformats``. It is then scaled to 8-bit and encoded by ``make_image``.
    The channel count is stored as the colorspace (1 = grayscale, 3 = RGB,
    4 = RGBA).

    Args:
      tag: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      tensor: Image data as a tensor or array, laid out as ``dataformats``
        describes. Values may be in [0, 1] (float types) or [0, 255] (uint8).
        The dtype decides the scale factor applied (255 or 1); out-of-range
        values are clipped to [0, 255].
      rescale: Resize factor applied to both image dimensions.
      dataformats: Layout string of ``tensor`` (e.g. "NCHW", "CHW", "HWC").

    Returns:
      A `Summary` protobuf holding one image value tagged ``tag``.
    """
    hwc = convert_to_HWC(make_np(tensor), dataformats)
    # Decide the scaling from the dtype rather than the observed value range.
    factor = _calc_scale_factor(hwc)
    pixels = np.clip(hwc.astype(np.float32) * factor, 0, 255).astype(np.uint8)
    encoded = make_image(pixels, rescale=rescale)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
579
580
def image_boxes(
    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
    """Output a `Summary` protocol buffer with an image that has boxes drawn on it.

    Args:
      tag: Summary tag.
      tensor_image: Image data laid out as ``dataformats`` describes; it is
        scaled to 8-bit the same way ``image`` does.
      tensor_boxes: Boxes in xyxy format, one row per box, passed to
        ``make_image`` as ``rois``.
      rescale: Resize factor applied to both image dimensions.
      dataformats: Layout string of ``tensor_image``.
      labels: Optional per-box label strings.

    Returns:
      A `Summary` protobuf holding one image value tagged ``tag``.
    """
    hwc = convert_to_HWC(make_np(tensor_image), dataformats)
    boxes = make_np(tensor_boxes)
    scaled = hwc.astype(np.float32) * _calc_scale_factor(hwc)
    pixels = scaled.clip(0, 255).astype(np.uint8)
    encoded = make_image(pixels, rescale=rescale, rois=boxes, labels=labels)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
596
597
def draw_boxes(disp_image, boxes, labels=None):
    """Draw every box in ``boxes`` onto ``disp_image`` with a red outline.

    Args:
      disp_image: Image handed to ``_draw_single_box`` (a PIL image).
      boxes: Array with one row per box. Row ``i`` is read as
        ``boxes[i, 0:4]`` = xmin, ymin, xmax, ymax (xyxy format).
      labels: Optional sequence with one label per box.

    Returns:
      The image after all boxes have been drawn.
    """
    for idx in range(boxes.shape[0]):
        xmin, ymin, xmax, ymax = boxes[idx, 0], boxes[idx, 1], boxes[idx, 2], boxes[idx, 3]
        label = None if labels is None else labels[idx]
        disp_image = _draw_single_box(
            disp_image,
            xmin,
            ymin,
            xmax,
            ymax,
            display_str=label,
            color="Red",
        )
    return disp_image
613
614
def make_image(tensor, rescale=1, rois=None, labels=None):
    """Convert a numpy representation of an image to an Image protobuf.

    Args:
      tensor: ``uint8`` array of shape (height, width, channels).
      rescale: Factor applied to both height and width before encoding.
      rois: Optional boxes drawn onto the image via ``draw_boxes``.
      labels: Optional per-box labels for ``rois``.

    Returns:
      A ``Summary.Image`` with the PNG-encoded image. Its ``height`` and
      ``width`` describe the encoded (rescaled) image, and ``colorspace`` is
      the channel count.
    """
    import io

    from PIL import Image

    height, width, channel = tensor.shape
    scaled_height = int(height * rescale)
    scaled_width = int(width * rescale)
    image = Image.fromarray(tensor)
    if rois is not None:
        image = draw_boxes(image, rois, labels=labels)
    image = image.resize((scaled_width, scaled_height), Image.Resampling.LANCZOS)

    with io.BytesIO() as output:
        image.save(output, format="PNG")
        image_string = output.getvalue()

    return Summary.Image(
        # Report the dimensions of the image actually encoded. These differ
        # from the input shape whenever rescale != 1.
        height=scaled_height,
        width=scaled_width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
639
640
def video(tag, tensor, fps=4):
    """Output a `Summary` protocol buffer holding a video encoded as a GIF.

    Args:
      tag: Summary tag.
      tensor: Video data. ``_prepare_video`` turns it into frames of shape
        (T, H, W, C), the layout ``make_video`` unpacks. uint8 data is used
        as-is; other dtypes are multiplied by 255. Both are clipped to [0, 255].
      fps: Frames per second of the encoded GIF.

    Returns:
      A `Summary` whose image field is the result of ``make_video``. That
      helper returns None when moviepy cannot be imported, and the None is
      passed through unchanged.
    """
    frames = _prepare_video(make_np(tensor))
    factor = _calc_scale_factor(frames)
    frames = np.clip(frames.astype(np.float32) * factor, 0, 255).astype(np.uint8)
    encoded = make_video(frames, fps)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
650
651
def make_video(tensor, fps):
    """Encode a (T, H, W, C) uint8 frame array as an animated GIF image proto.

    Requires the optional ``moviepy`` package. Two package layouts are
    accepted: ``moviepy.editor`` (moviepy < 2.0), and clip classes exported
    from the top-level ``moviepy`` package (moviepy >= 2.0).

    Args:
      tensor: Frames with shape (T, H, W, C).
      fps: Frames per second of the GIF.

    Returns:
      A ``Summary.Image`` holding the GIF bytes. Returns ``None`` (after
      printing a hint) when moviepy or one of its dependencies cannot be
      imported.
    """
    try:
        import moviepy
    except ImportError:
        print("add_video needs package moviepy")
        return
    try:
        from moviepy import editor as mpy  # moviepy < 2.0
    except ImportError:
        # moviepy >= 2.0 removed ``moviepy.editor`` and exports the clip
        # classes from the top-level package instead.
        if hasattr(moviepy, "ImageSequenceClip"):
            mpy = moviepy
        else:
            print(
                "moviepy is installed, but can't import moviepy.editor.",
                "Some packages could be missing [imageio, requests]",
            )
            return
    import tempfile

    _, h, w, c = tensor.shape

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(tensor), fps=fps)

    # Only reserve a unique path here. The handle is closed right away so
    # moviepy can reopen the path; an open handle blocks that on some
    # platforms (e.g. Windows).
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        filename = tmp.name
    try:
        # Keyword arguments accepted by different moviepy releases. Each
        # release rejects the keywords it does not know with TypeError.
        attempts = (
            {"verbose": False, "logger": None},  # moviepy 1.x
            {"logger": None},  # moviepy >= 2.0 dropped ``verbose``
            {"verbose": False, "progress_bar": False},  # older releases
        )
        for kwargs in attempts:
            try:
                clip.write_gif(filename, **kwargs)
                break
            except TypeError:
                continue
        else:
            clip.write_gif(filename, verbose=False)

        with open(filename, "rb") as f:
            tensor_string = f.read()
    finally:
        # Remove the temporary file even when encoding fails.
        try:
            os.remove(filename)
        except OSError:
            logger.warning("The temporary file used by moviepy cannot be deleted.")

    return Summary.Image(
        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
    )
693
694
def audio(tag, tensor, sample_rate=44100):
    """Output a `Summary` protocol buffer with a mono 16-bit PCM WAV clip.

    Args:
      tag: Summary tag.
      tensor: Audio samples. After squeezing, the data must be 1-D. Samples
        are expected in [-1, 1]; if any magnitude exceeds 1, a warning is
        printed and the data is clipped.
      sample_rate: Sample rate in Hz written to the WAV header and the proto.

    Returns:
      A `Summary` protobuf holding one audio value tagged ``tag``.
    """
    import io
    import wave

    samples = make_np(tensor).squeeze()
    if abs(samples).max() > 1:
        print("warning: audio amplitude out of range, auto clipped.")
        samples = samples.clip(-1, 1)
    assert samples.ndim == 1, "input tensor should be 1 dimensional."
    # Scale to the int16 range and force little-endian, as WAV requires.
    pcm = (samples * np.iinfo(np.int16).max).astype("<i2")

    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(sample_rate)
            writer.writeframes(pcm.data)
        audio_string = buffer.getvalue()

    clip = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=pcm.shape[-1],
        encoded_audio_string=audio_string,
        content_type="audio/wav",
    )
    return Summary(value=[Summary.Value(tag=tag, audio=clip)])
723
724
def custom_scalars(layout):
    """Build the `Summary` that configures the custom-scalars dashboard.

    Args:
      layout: Mapping ``{category_title: {chart_title: spec}}``. For each spec,
        ``spec[1]`` is the list of tags. If ``spec[0] == "Margin"``, exactly
        three tags (value, lower, upper) form a margin chart; any other
        ``spec[0]`` produces a multiline chart over the tags.

    Returns:
      A `Summary` tagged "custom_scalars__config__" whose tensor holds the
      serialized ``layout_pb2.Layout``.
    """

    def _build_chart(title, spec):
        tags = spec[1]
        if spec[0] != "Margin":
            return layout_pb2.Chart(
                title=title, multiline=layout_pb2.MultilineChartContent(tag=tags)
            )
        assert len(tags) == 3
        series = layout_pb2.MarginChartContent.Series(
            value=tags[0], lower=tags[1], upper=tags[2]
        )
        return layout_pb2.Chart(
            title=title, margin=layout_pb2.MarginChartContent(series=[series])
        )

    categories = [
        layout_pb2.Category(
            title=category_title,
            chart=[_build_chart(chart_title, spec) for chart_title, spec in charts.items()],
        )
        for category_title, charts in layout.items()
    ]

    layout_proto = layout_pb2.Layout(category=categories)
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars")
    )
    payload = TensorProto(
        dtype="DT_STRING",
        string_val=[layout_proto.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    entry = Summary.Value(
        tag="custom_scalars__config__", tensor=payload, metadata=metadata
    )
    return Summary(value=[entry])
760
761
def text(tag, text):
    """Output a `Summary` protocol buffer with a string for the text plugin.

    The string is UTF-8 encoded into a one-element DT_STRING tensor. The
    summary tag is ``tag`` with "/text_summary" appended.
    """
    plugin_content = TextPluginData(version=0).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="text", content=plugin_content
        )
    )
    payload = TensorProto(
        dtype="DT_STRING",
        string_val=[text.encode(encoding="utf_8")],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
    )
    entry = Summary.Value(tag=tag + "/text_summary", metadata=metadata, tensor=payload)
    return Summary(value=[entry])
775
776
def pr_curve_raw(
    tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
):
    """Output a PR-curve `Summary` from precomputed per-threshold statistics.

    Args:
      tag: Summary tag.
      tp, fp, tn, fn: True/false positive/negative counts per threshold.
      precision, recall: Precision and recall per threshold.
      num_thresholds: Threshold count recorded in the plugin metadata. Values
        above 127 are capped at 127.
      weights: Unused; accepted for signature parity with ``pr_curve``.

    Returns:
      A `Summary` whose DT_FLOAT tensor holds the six inputs stacked along a
      new leading axis.
    """
    # Values above 127 are reported to break the protobuf, so cap them.
    num_thresholds = min(num_thresholds, 127)
    data = np.stack((tp, fp, tn, fn, precision, recall))

    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    shape = TensorShapeProto(
        dim=[
            TensorShapeProto.Dim(size=data.shape[0]),
            TensorShapeProto.Dim(size=data.shape[1]),
        ]
    )
    payload = TensorProto(
        dtype="DT_FLOAT", float_val=data.reshape(-1).tolist(), tensor_shape=shape
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=payload)])
801
802
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Output a PR-curve `Summary` computed from labels and predictions.

    Args:
      tag: Summary tag.
      labels: numpy array of ground-truth labels, forwarded to ``compute_curve``.
      predictions: numpy array of scores, forwarded to ``compute_curve``.
      num_thresholds: Number of thresholds. Values above 127 are capped at 127.
      weights: Optional example weights, forwarded to ``compute_curve``.

    Returns:
      A `Summary` whose DT_FLOAT tensor holds the stacked curve statistics.
    """
    # Values above 127 are reported to break the protobuf, so cap them.
    num_thresholds = min(num_thresholds, 127)
    data = compute_curve(
        labels, predictions, num_thresholds=num_thresholds, weights=weights
    )

    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    shape = TensorShapeProto(
        dim=[
            TensorShapeProto.Dim(size=data.shape[0]),
            TensorShapeProto.Dim(size=data.shape[1]),
        ]
    )
    payload = TensorProto(
        dtype="DT_FLOAT", float_val=data.reshape(-1).tolist(), tensor_shape=shape
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=payload)])
827
828
829# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute cumulative precision-recall statistics over threshold buckets.

    Each prediction is assigned to bucket ``floor(prediction * (num_thresholds - 1))``.
    Buckets are then histogrammed into ``num_thresholds`` bins over
    ``[0, num_thresholds - 1]``. ``np.histogram`` ignores values outside that
    range, so predictions outside ``[0, 1]`` are dropped. For threshold index
    ``i``, the positive/negative counts cover every prediction in bucket ``i``
    or above.

    Args:
      labels: Array-like of ground-truth labels, same shape as
        ``predictions``. Each entry contributes ``label * weight`` to the
        positive counts and ``(1 - label) * weight`` to the negative counts
        (presumably 0/1 or boolean).
      predictions: Array-like of prediction scores (presumably in ``[0, 1]``).
      num_thresholds: Number of threshold buckets. ``None`` selects 127, the
        same default (and cap) that ``pr_curve`` uses.
      weights: Optional scalar or array of per-example weights that
        broadcasts against ``labels``; ``None`` means 1.0.

    Returns:
      ``np.ndarray`` of shape ``(6, num_thresholds)``. Its rows are true
      positives, false positives, true negatives, false negatives, precision
      and recall.
    """
    _MINIMUM_COUNT = 1e-7

    if num_thresholds is None:
        num_thresholds = 127
    if weights is None:
        weights = 1.0

    # Accept any array-like. A plain Python list would otherwise be
    # *repeated* by ``predictions * k`` and would have no ``astype``.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Reverse cumulative sum: counts at threshold i include all buckets >= i.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    # Index 0 holds the in-range totals, so the remainder lies below threshold i.
    tn = fp[0] - fp
    fn = tp[0] - tp
    # Floor the denominators to avoid division by zero for empty buckets.
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))
861
862
def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Create a tensor summary with summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Display name forwarded to the mesh plugin metadata.
      description: A longform readable description of the summary data. Markdown
        is supported.
      tensor: Data to display in the summary; anything accepted by
        `torch.as_tensor`, of any rank.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.) that
        belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      `Summary.Value` tagged with the mesh instance name. It carries a DT_FLOAT
      tensor whose recorded shape matches the input tensor's shape.
    """
    from tensorboard.plugins.mesh import metadata

    tensor = torch.as_tensor(tensor)

    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        tensor.shape,
        description,
        json_config=json_config,
    )

    # Emit one dim per actual tensor dimension so the recorded shape always
    # agrees with the flattened ``float_val`` length. Mesh inputs are
    # documented as ``[dim_1, ..., dim_n, 3]``, so the rank is not fixed at 3.
    tensor_proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=tensor.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=size) for size in tensor.shape]
        ),
    )

    return Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=tensor_proto,
        metadata=tensor_metadata,
    )
920
921
922def _get_json_config(config_dict):
923    """Parse and returns JSON string from python dictionary."""
924    json_config = "{}"
925    if config_dict is not None:
926        json_config = json.dumps(config_dict, sort_keys=True)
927    return json_config
928
929
930# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
def mesh(
    tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
    """Build a `Summary` protobuf describing a mesh or point cloud.

    One tensor summary is produced for each component that is not ``None``,
    in the fixed order vertices, faces, colors. All of them share one
    components bitmask and one JSON config string.

    Args:
      tag: A name for this summary operation.
      vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` with the 3D
        coordinates of vertices, or None.
      colors: Tensor of shape `[dim_1, ..., dim_n, 3]` with a color for each
        vertex, or None.
      faces: Tensor of shape `[dim_1, ..., dim_n, 3]` with indices of the
        vertices in each triangle, or None.
      config_dict: Dictionary with ThreeJS class names and configuration, or
        None (serialized as "{}").
      display_name: Forwarded to the mesh plugin metadata (presumably falls
        back to `tag` when None — confirm in tensorboard's mesh metadata).
      description: A longform readable description of the summary data.
        Markdown is supported.

    Returns:
      `Summary` whose `value` list holds one entry per supplied component.
    """
    from tensorboard.plugins.mesh import metadata
    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData

    json_config = _get_json_config(config_dict)

    candidates = (
        (vertices, MeshPluginData.VERTEX),
        (faces, MeshPluginData.FACE),
        (colors, MeshPluginData.COLOR),
    )
    present = [(data, kind) for data, kind in candidates if data is not None]

    bitmask = metadata.get_components_bitmask([kind for _, kind in present])

    return Summary(
        value=[
            _get_tensor_summary(
                tag, display_name, description, data, kind, bitmask, json_config
            )
            for data, kind in present
        ]
    )
983