# mypy: allow-untyped-defs
import json
import logging
import os
import struct

from typing import Any, List, Optional

import torch
import numpy as np

from google.protobuf import struct_pb2

from tensorboard.compat.proto.summary_pb2 import (
    HistogramProto,
    Summary,
    SummaryMetadata,
)
from tensorboard.compat.proto.tensor_pb2 import TensorProto
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.plugins.custom_scalar import layout_pb2
from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData

from ._convert_np import make_np
from ._utils import _prepare_video, convert_to_HWC

__all__ = [
    "half_to_int",
    "int_to_half",
    "hparams",
    "scalar",
    "histogram_raw",
    "histogram",
    "make_histogram",
    "image",
    "image_boxes",
    "draw_boxes",
    "make_image",
    "video",
    "make_video",
    "audio",
    "custom_scalars",
    "text",
    "tensor_proto",
    "pr_curve_raw",
    "pr_curve",
    "compute_curve",
    "mesh",
]

logger = logging.getLogger(__name__)


def half_to_int(f: float) -> int:
    """Casts a half-precision float value into an integer.

    Converts a half precision floating point value, such as `torch.half` or
    `torch.bfloat16`, into an integer value which can be written into the
    half_val field of a TensorProto for storage.

    To undo the effects of this conversion, use int_to_half().

    """
    # Reinterpret the 4 bytes of the packed float as a native int.
    buf = struct.pack("f", f)
    return struct.unpack("i", buf)[0]


def int_to_half(i: int) -> float:
    """Casts an integer value to a half-precision float.

    Converts an integer value obtained from half_to_int back into a floating
    point value.

    """
    # Inverse of half_to_int: reinterpret the packed int bytes as a float.
    buf = struct.pack("i", i)
    return struct.unpack("f", buf)[0]


def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
    """Flatten ``t`` and convert each element with half_to_int()."""
    return [half_to_int(x) for x in t.flatten().tolist()]


def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
    """Flatten a complex tensor into an interleaved list of real/imag parts."""
    return torch.view_as_real(t).flatten().tolist()


def _tensor_to_list(t: torch.Tensor) -> List[Any]:
    """Flatten ``t`` into a plain Python list."""
    return t.flatten().tolist()


# type maps: torch.Tensor type -> (protobuf type, protobuf val field, converter)
# Several keys below are aliases of the same dtype object (e.g. torch.half and
# torch.float16); for those the repeated entries carry identical values.
# NOTE(review): torch.uint8 appears twice with *different* value fields; since
# the last entry of a dict literal wins, uint8 tensors are written to
# "uint32_val" and the earlier "int_val" entry is dead. Confirm this is the
# intended field before removing either entry.
_TENSOR_TYPE_MAP = {
    torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
    torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
    torch.short: ("DT_INT16", "int_val", _tensor_to_list),
    torch.int: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
    torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
}


def _calc_scale_factor(tensor):
    """Return 1 for uint8 data and 255 for every other dtype."""
    converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
    return 1 if converted.dtype == np.uint8 else 255


def _draw_single_box(
    image,
    xmin,
    ymin,
    xmax,
    ymax,
    display_str,
    color="black",
    color_text="black",
    thickness=2,
):
    """Draw one rectangle (and an optional label) onto a PIL image.

    Args:
        image: PIL image to draw on; it is modified in place and returned.
        xmin, ymin, xmax, ymax: Box corners in pixel coordinates.
        display_str: Label text; nothing is written when it is falsy.
        color: Outline colour, also used as the label background.
        color_text: Colour of the label text.
        thickness: Outline width in pixels.

    Returns:
        The same ``image`` object.
    """
    from PIL import ImageDraw, ImageFont

    font = ImageFont.load_default()
    draw = ImageDraw.Draw(image)
    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
    draw.line(
        [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
        width=thickness,
        fill=color,
    )
    if display_str:
        text_bottom = bottom
        # The label is anchored to the bottom-left corner of the box.
        _left, _top, _right, _bottom = font.getbbox(display_str)
        text_width, text_height = _right - _left, _bottom - _top
        margin = np.ceil(0.05 * text_height)
        draw.rectangle(
            [
                (left, text_bottom - text_height - 2 * margin),
                (left + text_width, text_bottom),
            ],
            fill=color,
        )
        draw.text(
            (left + margin, text_bottom - text_height - margin),
            display_str,
            fill=color_text,
            font=font,
        )
    return image


def _build_domain_discrete(
    domain_values, value_field
) -> Optional[struct_pb2.ListValue]:
    """Wrap discrete hparam domain values in a protobuf ``ListValue``.

    Args:
        domain_values: List of allowed values, or ``None`` when the
            hyperparameter has no declared discrete domain.
        value_field: Name of the ``struct_pb2.Value`` field to populate
            (``"number_value"``, ``"string_value"`` or ``"bool_value"``).

    Returns:
        A ``ListValue`` holding one ``Value`` per entry, or ``None`` when
        ``domain_values`` is ``None``.
    """
    if domain_values is None:
        return None
    return struct_pb2.ListValue(
        values=[struct_pb2.Value(**{value_field: d}) for d in domain_values]
    )


def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Output three `Summary` protocol buffers needed by hparams plugin.

    `Experiment` keeps the metadata of an experiment, such as the name of the
      hyperparameters and the name of the metrics.
    `SessionStartInfo` keeps key-value pairs of the hyperparameters
    `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS

    Args:
      hparam_dict: A dictionary that contains names of the hyperparameters
        and their values. Values may be bool, int, float, str or
        torch.Tensor; ``None`` values are skipped.
      metric_dict: A dictionary that contains names of the metrics
        and their values.
      hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
        contains names of the hyperparameters and all discrete values they can hold

    Returns:
      The `Summary` protobufs for Experiment, SessionStartInfo and
        SessionEndInfo

    Raises:
      TypeError: If any argument is not a dictionary, or if an entry of
        ``hparam_domain_discrete`` is not a list whose items have the same
        type as the matching ``hparam_dict`` value.
      ValueError: If a hyperparameter value has an unsupported type.
    """
    from tensorboard.plugins.hparams.api_pb2 import (
        DataType,
        Experiment,
        HParamInfo,
        MetricInfo,
        MetricName,
        Status,
    )
    from tensorboard.plugins.hparams.metadata import (
        EXPERIMENT_TAG,
        PLUGIN_DATA_VERSION,
        PLUGIN_NAME,
        SESSION_END_INFO_TAG,
        SESSION_START_INFO_TAG,
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData,
        SessionEndInfo,
        SessionStartInfo,
    )

    # TODO: expose other parameters in the future.
    # hp = HParamInfo(name='lr',display_name='learning rate',
    # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
    # max_value=100))
    # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
    # description='', dataset_type=DatasetType.DATASET_VALIDATION)
    # exp = Experiment(name='123', description='456', time_created_secs=100.0,
    # hparam_infos=[hp], metric_infos=[mt], user='tw')

    if not isinstance(hparam_dict, dict):
        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: hparam_dict should be a dictionary, nothing logged."
        )
    if not isinstance(metric_dict, dict):
        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: metric_dict should be a dictionary, nothing logged."
        )

    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
            )
    hps = []

    ssi = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue

        # bool must be tested before int/float: bool is a subclass of int, so
        # testing the numeric types first would make this branch unreachable
        # and record booleans as floats.
        if isinstance(v, bool):
            ssi.hparams[k].bool_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_BOOL"),
                    domain_discrete=_build_domain_discrete(
                        hparam_domain_discrete.get(k), "bool_value"
                    ),
                )
            )
            continue

        if isinstance(v, (int, float)):
            ssi.hparams[k].number_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_FLOAT64"),
                    domain_discrete=_build_domain_discrete(
                        hparam_domain_discrete.get(k), "number_value"
                    ),
                )
            )
            continue

        if isinstance(v, str):
            ssi.hparams[k].string_value = v
            hps.append(
                HParamInfo(
                    name=k,
                    type=DataType.Value("DATA_TYPE_STRING"),
                    domain_discrete=_build_domain_discrete(
                        hparam_domain_discrete.get(k), "string_value"
                    ),
                )
            )
            continue

        if isinstance(v, torch.Tensor):
            # Only the first element of the converted array is recorded.
            v = make_np(v)[0]
            ssi.hparams[k].number_value = v
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue

        raise ValueError(
            "value should be one of int, float, str, bool, or torch.Tensor"
        )

    content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
        )
    )
    ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)])

    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()]

    exp = Experiment(hparam_infos=hps, metric_infos=mts)

    content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION)
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
        )
    )
    exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)])

    sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
    content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION)
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
        )
    )
    sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)])

    return exp, ssi, sei


def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Output a `Summary` protocol buffer containing a single scalar value.

    The generated Summary has a Tensor.proto containing the input Tensor.
    Args:
      name: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: A real numeric Tensor containing a single value.
      collections: Optional list of graph collections keys. Not used by this
        implementation.
      new_style: Whether to use new style (tensor field) or old style (simple_value
        field). New style could lead to faster data loading.
      double_precision: Only used when ``new_style`` is true; store the value
        as DT_DOUBLE instead of DT_FLOAT.
    Returns:
      A `Summary` protobuf holding the scalar.
    Raises:
      AssertionError: If tensor does not reduce to a single element.
    """
    tensor = make_np(tensor).squeeze()
    assert (
        tensor.ndim == 0
    ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
    # python float is double precision in numpy
    scalar = float(tensor)
    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])

    if double_precision:
        proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE")
    else:
        proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT")

    plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
    smd = SummaryMetadata(plugin_data=plugin_data)
    return Summary(
        value=[
            Summary.Value(
                tag=name,
                tensor=proto,
                metadata=smd,
            )
        ]
    )


def tensor_proto(tag, tensor):
    """Outputs a `Summary` protocol buffer containing the full tensor.

    The generated Summary has a Tensor.proto containing the input Tensor.
    Args:
      tag: A name for the generated node. Will also serve as the series name in
        TensorBoard.
      tensor: Tensor to be converted to protobuf
    Returns:
      A tensor protobuf in a `Summary` protobuf.
    Raises:
      ValueError: If tensor is too big to be converted to protobuf, or
                     tensor data type is not supported
    """
    if tensor.numel() * tensor.itemsize >= (1 << 31):
        raise ValueError(
            "tensor is bigger than protocol buffer's hard limit of 2GB in size"
        )

    if tensor.dtype not in _TENSOR_TYPE_MAP:
        raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")

    dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
    proto = TensorProto(
        **{
            "dtype": dtype,
            "tensor_shape": TensorShapeProto(
                dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape]
            ),
            field_name: conversion_fn(tensor),
        },
    )

    plugin_data = SummaryMetadata.PluginData(plugin_name="tensor")
    smd = SummaryMetadata(plugin_data=plugin_data)
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=proto)])


def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram.

    The generated
    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
    has one summary value containing a histogram for `values`.
    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      min: A float or int min value
      max: A float or int max value
      num: Int number of values
      sum: Float or int sum of all values
      sum_squares: Float or int sum of squares for all values
      bucket_limits: A numeric `Tensor` with upper value per bucket
      bucket_counts: A numeric `Tensor` with number of values per bucket
    Returns:
      A `Summary` protobuf wrapping the histogram.
    """
    hist = HistogramProto(
        min=min,
        max=max,
        num=num,
        sum=sum,
        sum_squares=sum_squares,
        bucket_limit=bucket_limits,
        bucket=bucket_counts,
    )
    return Summary(value=[Summary.Value(tag=name, histo=hist)])


def histogram(name, values, bins, max_bins=None):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram.

    The generated
    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
    has one summary value containing a histogram for `values`.
    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      values: A real numeric `Tensor`. Any shape. Values to use to
        build the histogram.
      bins: Bin specification forwarded to ``numpy.histogram``.
      max_bins: Optional upper bound on the number of bins; see
        make_histogram().
    Returns:
      A `Summary` protobuf wrapping the histogram.
    """
    values = make_np(values)
    hist = make_histogram(values.astype(float), bins, max_bins)
    return Summary(value=[Summary.Value(tag=name, histo=hist)])


def make_histogram(values, bins, max_bins=None):
    """Convert values into a histogram proto using logic from histogram.cc.

    Args:
      values: Non-empty numpy array; it is flattened before binning.
      bins: Bin specification forwarded to ``numpy.histogram``.
      max_bins: When set and exceeded, adjacent bins are merged in groups of
        ``num_bins // max_bins``.
    Returns:
      A `HistogramProto`.
    Raises:
      ValueError: If ``values`` is empty or the resulting histogram is empty.
    """
    if values.size == 0:
        raise ValueError("The input has no element.")
    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    num_bins = len(counts)
    if max_bins is not None and num_bins > max_bins:
        subsampling = num_bins // max_bins
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            # Pad with empty bins so the count array splits evenly into groups.
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],
                mode="constant",
                constant_values=0,
            )
        counts = counts.reshape(-1, subsampling).sum(axis=-1)
        new_limits = np.empty((counts.size + 1,), limits.dtype)
        new_limits[:-1] = limits[:-1:subsampling]
        new_limits[-1] = limits[-1]
        limits = new_limits

    # Find the first and the last bin defining the support of the histogram:

    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    end = int(end) + 1
    del cum_counts

    # TensorBoard only includes the right bin limits. To still have the leftmost limit
    # included, we include an empty bin left.
    # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the
    # first nonzero-count bin:
    counts = (
        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    )
    limits = limits[start : end + 1]

    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")

    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )


def image(tag, tensor, rescale=1, dataformats="NCHW"):
    """Output a `Summary` protocol buffer with an image.

    Args:
      tag: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      tensor: Image data laid out as described by ``dataformats``.
        'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8).
        The image() function will scale the image values to [0, 255] by applying
        a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values
        will be clipped.
      rescale: Factor applied to the image height and width before encoding.
      dataformats: Layout string passed to ``convert_to_HWC``.
    Returns:
      A `Summary` protobuf holding the PNG-encoded image.
    """
    tensor = make_np(tensor)
    tensor = convert_to_HWC(tensor, dataformats)
    # Do not assume that user passes in values in [0, 255], use data type to detect
    scale_factor = _calc_scale_factor(tensor)
    tensor = tensor.astype(np.float32)
    tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
    encoded = make_image(tensor, rescale=rescale)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])


def image_boxes(
    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
    """Output a `Summary` protocol buffer with an image and bounding boxes.

    Args:
      tag: Tag of the summary value.
      tensor_image: Image data laid out as described by ``dataformats``.
      tensor_boxes: Boxes to draw, one row per box in (xmin, ymin, xmax, ymax)
        order.
      rescale: Factor applied to the image height and width before encoding.
      dataformats: Layout string passed to ``convert_to_HWC``.
      labels: Optional sequence of label strings, one per box.
    Returns:
      A `Summary` protobuf holding the PNG-encoded image.
    """
    tensor_image = make_np(tensor_image)
    tensor_image = convert_to_HWC(tensor_image, dataformats)
    tensor_boxes = make_np(tensor_boxes)
    # The scale factor must be derived from the dtype *before* the float cast.
    scale_factor = _calc_scale_factor(tensor_image)
    tensor_image = tensor_image.astype(np.float32) * scale_factor
    encoded = make_image(
        tensor_image.clip(0, 255).astype(np.uint8),
        rescale=rescale,
        rois=tensor_boxes,
        labels=labels,
    )
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])


def draw_boxes(disp_image, boxes, labels=None):
    """Draw every box in ``boxes`` (xyxy format) onto ``disp_image`` in red.

    Args:
      disp_image: PIL image to draw on.
      boxes: 2-D array with one (xmin, ymin, xmax, ymax) row per box.
      labels: Optional sequence of label strings, indexed like ``boxes``.
    Returns:
      The image with all boxes drawn.
    """
    # xyxy format
    for i in range(boxes.shape[0]):
        disp_image = _draw_single_box(
            disp_image,
            boxes[i, 0],
            boxes[i, 1],
            boxes[i, 2],
            boxes[i, 3],
            display_str=None if labels is None else labels[i],
            color="Red",
        )
    return disp_image


def make_image(tensor, rescale=1, rois=None, labels=None):
    """Convert a numpy representation of an image to Image protobuf.

    Args:
      tensor: Array of shape (height, width, channel).
      rescale: Factor applied to height and width before PNG encoding.
      rois: Optional boxes to draw on the image before resizing.
      labels: Optional labels for ``rois``.
    Returns:
      A `Summary.Image` whose height/width fields are the *unscaled* input
      dimensions and whose payload is the resized PNG.
    """
    import io

    from PIL import Image

    height, width, channel = tensor.shape
    scaled_height = int(height * rescale)
    scaled_width = int(width * rescale)
    image = Image.fromarray(tensor)
    if rois is not None:
        image = draw_boxes(image, rois, labels=labels)
    ANTIALIAS = Image.Resampling.LANCZOS
    image = image.resize((scaled_width, scaled_height), ANTIALIAS)

    output = io.BytesIO()
    image.save(output, format="PNG")
    image_string = output.getvalue()
    output.close()
    return Summary.Image(
        height=height,
        width=width,
        colorspace=channel,
        encoded_image_string=image_string,
    )


def video(tag, tensor, fps=4):
    """Output a `Summary` protocol buffer with a video encoded as a GIF.

    Args:
      tag: Tag of the summary value.
      tensor: Video data accepted by ``_prepare_video``; uint8 data is used
        as-is, any other dtype is multiplied by 255 and clipped to [0, 255].
      fps: Frames per second of the encoded clip.
    Returns:
      A `Summary` protobuf holding the encoded video.
    """
    tensor = make_np(tensor)
    tensor = _prepare_video(tensor)
    # If user passes in uint8, then we don't need to rescale by 255
    scale_factor = _calc_scale_factor(tensor)
    tensor = tensor.astype(np.float32)
    tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
    encoded = make_video(tensor, fps)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])


def make_video(tensor, fps):
    """Encode a (t, h, w, c) uint8 array as a GIF inside a `Summary.Image`.

    Args:
      tensor: Array of shape (frames, height, width, channels).
      fps: Frames per second of the encoded clip.
    Returns:
      A `Summary.Image` holding the GIF bytes, or ``None`` when moviepy (or
      ``moviepy.editor``) cannot be imported.
    """
    try:
        import moviepy  # noqa: F401
    except ImportError:
        print("add_video needs package moviepy")
        return
    try:
        from moviepy import editor as mpy
    except ImportError:
        print(
            "moviepy is installed, but can't import moviepy.editor.",
            "Some packages could be missing [imageio, requests]",
        )
        return
    import tempfile

    _t, h, w, c = tensor.shape

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(tensor), fps=fps)

    # Only a unique path is needed; close the handle right away so it is not
    # leaked and so moviepy can reopen the file on every platform.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        filename = tmp.name

    try:
        try:  # newer version of moviepy use logger instead of progress_bar argument.
            clip.write_gif(filename, verbose=False, logger=None)
        except TypeError:
            try:  # older version of moviepy does not support progress_bar argument.
                clip.write_gif(filename, verbose=False, progress_bar=False)
            except TypeError:
                clip.write_gif(filename, verbose=False)

        with open(filename, "rb") as f:
            tensor_string = f.read()
    finally:
        # Remove the temporary file even if encoding or reading failed.
        try:
            os.remove(filename)
        except OSError:
            logger.warning("The temporary file used by moviepy cannot be deleted.")

    return Summary.Image(
        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
    )


def audio(tag, tensor, sample_rate=44100):
    """Output a `Summary` protocol buffer with mono 16-bit WAV audio.

    Args:
      tag: Tag of the summary value.
      tensor: Audio samples; must be 1-dimensional after squeezing. Values
        outside [-1, 1] are clipped.
      sample_rate: Sample rate in Hz.
    Returns:
      A `Summary` protobuf holding the WAV-encoded audio.
    Raises:
      AssertionError: If the squeezed input is not 1-dimensional.
    """
    import io
    import wave

    array = make_np(tensor)
    array = array.squeeze()
    if abs(array).max() > 1:
        print("warning: audio amplitude out of range, auto clipped.")
        array = array.clip(-1, 1)
    assert array.ndim == 1, "input tensor should be 1 dimensional."
    # Scale to the int16 range and force little-endian 16-bit samples.
    array = (array * np.iinfo(np.int16).max).astype("<i2")

    fio = io.BytesIO()
    with wave.open(fio, "wb") as wave_write:
        wave_write.setnchannels(1)
        wave_write.setsampwidth(2)
        wave_write.setframerate(sample_rate)
        wave_write.writeframes(array.data)
    audio_string = fio.getvalue()
    fio.close()
    audio_proto = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=array.shape[-1],
        encoded_audio_string=audio_string,
        content_type="audio/wav",
    )
    return Summary(value=[Summary.Value(tag=tag, audio=audio_proto)])


def custom_scalars(layout):
    """Output a `Summary` protocol buffer describing a custom-scalars layout.

    Args:
      layout: Mapping ``{category: {chart_name: (chart_type, tags)}}``. A
        ``chart_type`` of ``"Margin"`` requires exactly three tags (value,
        lower, upper); any other type produces a multiline chart.
    Returns:
      A `Summary` protobuf tagged ``custom_scalars__config__``.
    Raises:
      AssertionError: If a "Margin" chart does not have exactly three tags.
    """
    categories = []
    for category_title, chart_specs in layout.items():
        charts = []
        for chart_name, chart_metadata in chart_specs.items():
            tags = chart_metadata[1]
            if chart_metadata[0] == "Margin":
                assert len(tags) == 3
                mgcc = layout_pb2.MarginChartContent(
                    series=[
                        layout_pb2.MarginChartContent.Series(
                            value=tags[0], lower=tags[1], upper=tags[2]
                        )
                    ]
                )
                chart = layout_pb2.Chart(title=chart_name, margin=mgcc)
            else:
                mlcc = layout_pb2.MultilineChartContent(tag=tags)
                chart = layout_pb2.Chart(title=chart_name, multiline=mlcc)
            charts.append(chart)
        categories.append(layout_pb2.Category(title=category_title, chart=charts))

    layout = layout_pb2.Layout(category=categories)
    plugin_data = SummaryMetadata.PluginData(plugin_name="custom_scalars")
    smd = SummaryMetadata(plugin_data=plugin_data)
    tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[layout.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    return Summary(
        value=[
            Summary.Value(tag="custom_scalars__config__", tensor=tensor, metadata=smd)
        ]
    )


def text(tag, text):
    """Output a `Summary` protocol buffer with a UTF-8 encoded text string.

    Args:
      tag: Tag prefix; the summary value is tagged ``tag + "/text_summary"``.
      text: The string to store.
    Returns:
      A `Summary` protobuf for the text plugin.
    """
    plugin_data = SummaryMetadata.PluginData(
        plugin_name="text", content=TextPluginData(version=0).SerializeToString()
    )
    smd = SummaryMetadata(plugin_data=plugin_data)
    tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[text.encode(encoding="utf_8")],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
    )
    return Summary(
        value=[Summary.Value(tag=tag + "/text_summary", metadata=smd, tensor=tensor)]
    )


def _pr_curve_summary(tag, data, num_thresholds):
    """Wrap a 2-D PR-curve data array in a `Summary` for the pr_curves plugin."""
    pr_curve_plugin_data = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    plugin_data = SummaryMetadata.PluginData(
        plugin_name="pr_curves", content=pr_curve_plugin_data
    )
    smd = SummaryMetadata(plugin_data=plugin_data)
    tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=data.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[
                TensorShapeProto.Dim(size=data.shape[0]),
                TensorShapeProto.Dim(size=data.shape[1]),
            ]
        ),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])


def pr_curve_raw(
    tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
):
    """Output a PR-curve `Summary` from precomputed per-threshold statistics.

    Args:
      tag: Tag of the summary value.
      tp, fp, tn, fn, precision, recall: Equal-length 1-D sequences, one
        entry per threshold; they are stacked into a (6, N) array.
      num_thresholds: Number of thresholds, capped at 127.
      weights: Unused.
    Returns:
      A `Summary` protobuf for the pr_curves plugin.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)
    data = np.stack((tp, fp, tn, fn, precision, recall))
    return _pr_curve_summary(tag, data, num_thresholds)


def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Output a PR-curve `Summary` computed from labels and predictions.

    Args:
      tag: Tag of the summary value.
      labels: Numpy array of ground-truth labels.
      predictions: Numpy array of predicted scores, same shape as ``labels``.
      num_thresholds: Number of thresholds, capped at 127.
      weights: Optional weights forwarded to compute_curve().
    Returns:
      A `Summary` protobuf for the pr_curves plugin.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)
    data = compute_curve(
        labels, predictions, num_thresholds=num_thresholds, weights=weights
    )
    return _pr_curve_summary(tag, data, num_thresholds)


# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute per-threshold TP/FP/TN/FN counts plus precision and recall.

    Args:
      labels: Numpy array of ground-truth labels.
      predictions: Numpy array of predicted scores, same shape as ``labels``.
      num_thresholds: Number of thresholds (and histogram buckets). Despite
        the ``None`` default, a number is required by the arithmetic below.
      weights: Optional per-sample weights; defaults to 1.0.
    Returns:
      A numpy array of shape (6, num_thresholds) stacking
      (tp, fp, tn, fn, precision, recall).
    """
    _MINIMUM_COUNT = 1e-7

    if weights is None:
        weights = 1.0

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    # _MINIMUM_COUNT guards against division by zero.
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))


def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Create a tensor summary with summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.
      tensor: Tensor to display in summary. Its first three dimensions are
        recorded as the tensor shape.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.) that
        belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      Tensor summary with metadata.
    """
    from tensorboard.plugins.mesh import metadata

    tensor = torch.as_tensor(tensor)

    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        tensor.shape,
        description,
        json_config=json_config,
    )

    tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=tensor.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[
                TensorShapeProto.Dim(size=tensor.shape[0]),
                TensorShapeProto.Dim(size=tensor.shape[1]),
                TensorShapeProto.Dim(size=tensor.shape[2]),
            ]
        ),
    )

    tensor_summary = Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=tensor,
        metadata=tensor_metadata,
    )

    return tensor_summary


def _get_json_config(config_dict):
    """Return ``config_dict`` as a JSON string, or ``"{}"`` when it is None."""
    json_config = "{}"
    if config_dict is not None:
        json_config = json.dumps(config_dict, sort_keys=True)
    return json_config


# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
def mesh(
    tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
    """Output a merged `Summary` protocol buffer with a mesh/point cloud.

    Args:
      tag: A name for this summary operation.
      vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
        coordinates of vertices.
      faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
        vertices within each triangle.
      colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
        vertex.
      display_name: If set, will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.
      config_dict: Dictionary with ThreeJS classes names and configuration.

    Returns:
      Merged summary for mesh/point cloud representation.
    """
    from tensorboard.plugins.mesh import metadata
    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData

    json_config = _get_json_config(config_dict)

    # Components that were not supplied (None) are left out of the summary.
    tensors = [
        (vertices, MeshPluginData.VERTEX),
        (faces, MeshPluginData.FACE),
        (colors, MeshPluginData.COLOR),
    ]
    tensors = [tensor for tensor in tensors if tensor[0] is not None]
    components = metadata.get_components_bitmask(
        [content_type for (tensor, content_type) in tensors]
    )

    summaries = [
        _get_tensor_summary(
            tag,
            display_name,
            description,
            tensor,
            content_type,
            components,
            json_config,
        )
        for tensor, content_type in tensors
    ]

    return Summary(value=summaries)