# Owner(s): ["module: unknown"]

import io
import os
import shutil
import sys
import tempfile
import unittest
from pathlib import Path

import expecttest
import numpy as np


# Optional dependencies: each flag records whether the package could be
# imported; the skip decorators and BaseTestCase.setUp consult these flags.
TEST_TENSORBOARD = True
try:
    import tensorboard.summary.writer.event_file_writer  # noqa: F401
    from tensorboard.compat.proto.summary_pb2 import Summary
except ImportError:
    TEST_TENSORBOARD = False

HAS_TORCHVISION = True
try:
    import torchvision
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

TEST_MATPLOTLIB = True
try:
    import matplotlib
    # With no display available, switch to the Agg backend before pyplot is
    # imported so figure creation works headlessly.
    if os.environ.get('DISPLAY', '') == '':
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
except ImportError:
    TEST_MATPLOTLIB = False
skipIfNoMatplotlib = unittest.skipIf(not TEST_MATPLOTLIB, "no matplotlib")

import torch
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_MACOS,
    IS_WINDOWS,
    parametrize,
    run_tests,
    TEST_WITH_CROSSREF,
    TestCase,
    skipIfTorchDynamo,
)


def tensor_N(shape, dtype=float):
    """Return a deterministic array of ``shape`` holding 0, 1, 2, ... in C order.

    Args:
        shape: target shape of the array.
        dtype: element type passed to ``np.arange`` (default: Python float).
    """
    numel = np.prod(shape)
    return np.arange(numel, dtype=dtype).reshape(shape)


class BaseTestCase(TestCase):
    """ Base class used for all TensorBoard tests """

    def setUp(self):
        super().setUp()
        if not TEST_TENSORBOARD:
            self.skipTest("Skip the test since TensorBoard is not installed")
        if TEST_WITH_CROSSREF:
            self.skipTest("Don't run TensorBoard tests with crossref")
        # Log directories handed out by createSummaryWriter(); removed in tearDown().
        self.temp_dirs = []

    def createSummaryWriter(self):
        """Return a SummaryWriter that logs into a fresh temporary directory.

        The directory is recorded in ``self.temp_dirs`` so tearDown() can
        delete it.
        """
        # mkdtemp creates the directory and leaves its lifetime to tearDown().
        # Taking the name of a throwaway TemporaryDirectory would instead depend
        # on garbage collection deleting it (with a ResourceWarning) first.
        temp_dir = tempfile.mkdtemp(prefix="test_tensorboard")
        self.temp_dirs.append(temp_dir)
        return SummaryWriter(temp_dir)

    def tearDown(self):
        super().tearDown()
        # Remove directories created by createSummaryWriter().
        for temp_dir in getattr(self, "temp_dirs", []):
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)


# These modules are only importable when TensorBoard is installed; tests using
# them are skipped otherwise (see BaseTestCase.setUp).
if TEST_TENSORBOARD:
    from google.protobuf import text_format
    from PIL import Image
    from tensorboard.compat.proto.graph_pb2 import GraphDef
    from tensorboard.compat.proto.types_pb2 import DataType

    from torch.utils.tensorboard import summary, SummaryWriter
    from torch.utils.tensorboard._convert_np import make_np
    from torch.utils.tensorboard._pytorch_graph import graph
    from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
    from torch.utils.tensorboard.summary import int_to_half, tensor_proto


class TestTensorBoardPyTorchNumpy(BaseTestCase):
    def test_pytorch_np(self):
        tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)]
        for tensor in tensors:
            # regular tensor
            self.assertIsInstance(make_np(tensor), np.ndarray)

            # CUDA tensor
            if torch.cuda.is_available():
                self.assertIsInstance(make_np(tensor.cuda()), np.ndarray)

            # regular variable
            self.assertIsInstance(make_np(torch.autograd.Variable(tensor)), np.ndarray)

            # CUDA variable
            if torch.cuda.is_available():
                self.assertIsInstance(make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray)

        # python primitive type
        self.assertIsInstance(make_np(0), np.ndarray)
        self.assertIsInstance(make_np(0.1), np.ndarray)

    def test_pytorch_autograd_np(self):
        x = torch.autograd.Variable(torch.empty(1))
        self.assertIsInstance(make_np(x), np.ndarray)

    def test_pytorch_write(self):
        with self.createSummaryWriter() as w:
            w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0)

    def test_pytorch_histogram(self):
        with self.createSummaryWriter() as w:
            w.add_histogram('float histogram', torch.rand((50,)))
            w.add_histogram('int histogram', torch.randint(0, 100, (50,)))
            w.add_histogram('bfloat16 histogram', torch.rand(50, dtype=torch.bfloat16))

    def test_pytorch_histogram_raw(self):
        with self.createSummaryWriter() as w:
            num = 50
            floats = make_np(torch.rand((num,)))
            bins = [0.0, 0.25, 0.5, 0.75, 1.0]
            counts, limits = np.histogram(floats, bins)
            sum_sq = floats.dot(floats).item()
            # np.histogram returns len(counts) + 1 edges; TensorBoard expects
            # only the right edge of each bucket, hence limits[1:].
            w.add_histogram_raw('float histogram raw',
                                min=floats.min().item(),
                                max=floats.max().item(),
                                num=num,
                                sum=floats.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits[1:].tolist(),
                                bucket_counts=counts.tolist())

            ints = make_np(torch.randint(0, 100, (num,)))
            bins = [0, 25, 50, 75, 100]
            counts, limits = np.histogram(ints, bins)
            sum_sq = ints.dot(ints).item()
            w.add_histogram_raw('int histogram raw',
                                min=ints.min().item(),
                                max=ints.max().item(),
                                num=num,
                                sum=ints.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits[1:].tolist(),
                                bucket_counts=counts.tolist())

            ints = torch.tensor(range(0, 100)).float()
            nbins = 100
            counts = torch.histc(ints, bins=nbins, min=0, max=99)
            limits = torch.tensor(range(nbins))
            sum_sq = ints.dot(ints).item()
            w.add_histogram_raw('int histogram raw',
                                min=ints.min().item(),
                                max=ints.max().item(),
                                # This dataset has 100 elements; the earlier
                                # `num` (50) belongs to the previous cases.
                                num=ints.numel(),
                                sum=ints.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits.tolist(),
                                bucket_counts=counts.tolist())


class TestTensorBoardUtils(BaseTestCase):
    def test_to_HWC(self):
        test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'chw')
        self.assertEqual(converted.shape, (32, 32, 3))
        test_image = np.random.randint(0, 256, size=(16, 3, 32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'nchw')
        self.assertEqual(converted.shape, (64, 256, 3))
        test_image = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'hw')
        self.assertEqual(converted.shape, (32, 32, 3))

    def test_convert_to_HWC_dtype_remains_same(self):
        # test to ensure convert_to_HWC restores the dtype of input np array and
        # thus the scale_factor calculated for the image is 1
        test_image = torch.tensor([[[[1, 2, 3], [4, 5, 6]]]], dtype=torch.uint8)
        tensor = make_np(test_image)
        tensor = convert_to_HWC(tensor, 'NCHW')
        scale_factor = summary._calc_scale_factor(tensor)
        self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1')

    def test_prepare_video(self):
        # At each timeframe, the sum over all other
        # dimensions of the video should be the same.
        shapes = [
            (16, 30, 3, 28, 28),
            (36, 30, 3, 28, 28),
            (19, 29, 3, 23, 19),
            (3, 3, 3, 3, 3)
        ]
        for s in shapes:
            V_input = np.random.random(s)
            V_after = _prepare_video(np.copy(V_input))
            total_frame = s[1]
            # Bring the time axis (dim 1) to the front so V_input[f] is frame f.
            V_input = np.swapaxes(V_input, 0, 1)
            for f in range(total_frame):
                # `.reshape(-1)` avoids np.reshape's deprecated `newshape=` keyword.
                x = V_input[f].reshape(-1)
                y = V_after[f].reshape(-1)
                np.testing.assert_array_almost_equal(np.sum(x), np.sum(y))

    def test_numpy_vid_uint8(self):
        V_input = np.random.randint(0, 256, (16, 30, 3, 28, 28)).astype(np.uint8)
        V_after = _prepare_video(np.copy(V_input)) * 255
        total_frame = V_input.shape[1]
        V_input = np.swapaxes(V_input, 0, 1)
        for f in range(total_frame):
            x = V_input[f].reshape(-1)
            y = V_after[f].reshape(-1)
            np.testing.assert_array_almost_equal(np.sum(x), np.sum(y))


freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440]

# Raw PR-curve data consumed by TestTensorBoardWriter.test_writer.
true_positive_counts = [75, 64, 21, 5, 0]
false_positive_counts = [150, 105, 18, 0, 0]
true_negative_counts = [0, 45, 132, 150, 150]
false_negative_counts = [0, 11, 54, 70, 75]
precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]
recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0]


class TestTensorBoardWriter(BaseTestCase):
    def test_writer(self):
        # Smoke test: log one item of many summary types through one writer.
        with self.createSummaryWriter() as writer:
            sample_rate = 44100

            n_iter = 0
            writer.add_hparams(
                {'lr': 0.1, 'bsize': 1},
                {'hparam/accuracy': 10, 'hparam/loss': 10}
            )
            writer.add_scalar('data/scalar_systemtime', 0.1, n_iter)
            writer.add_scalar('data/scalar_customtime', 0.2, n_iter, walltime=n_iter)
            writer.add_scalar('data/new_style', 0.2, n_iter, new_style=True)
            writer.add_scalars('data/scalar_group', {
                "xsinx": n_iter * np.sin(n_iter),
                "xcosx": n_iter * np.cos(n_iter),
                "arctanx": np.arctan(n_iter)
            }, n_iter)
            x = np.zeros((32, 3, 64, 64))  # output from network
            writer.add_images('Image', x, n_iter)  # Tensor
            writer.add_image_with_boxes('imagebox',
                                        np.zeros((3, 64, 64)),
                                        np.array([[10, 10, 40, 40], [40, 40, 60, 60]]),
                                        n_iter)
            x = np.zeros(sample_rate * 2)

            writer.add_audio('myAudio', x, n_iter)
            writer.add_video('myVideo', np.random.rand(16, 48, 1, 28, 28).astype(np.float32), n_iter)
            writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter)
            writer.add_text('markdown Text', '''a|b\n-|-\nc|d''', n_iter)
            writer.add_histogram('hist', np.random.rand(100, 100), n_iter)
            writer.add_pr_curve('xoxo', np.random.randint(2, size=100), np.random.rand(
                100), n_iter)  # needs tensorboard 0.4RC or later
            writer.add_pr_curve_raw('prcurve with raw data', true_positive_counts,
                                    false_positive_counts,
                                    true_negative_counts,
                                    false_negative_counts,
                                    precision,
                                    recall, n_iter)

            v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
            c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
            f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int)
            writer.add_mesh('my_mesh', vertices=v, colors=c, faces=f)


class TestTensorBoardSummaryWriter(BaseTestCase):
    def test_summary_writer_ctx(self):
        # after using a SummaryWriter as a ctx it should be closed
        with self.createSummaryWriter() as writer:
            writer.add_scalar('test', 1)
        self.assertIs(writer.file_writer, None)

    def test_summary_writer_close(self):
        # Opening and closing SummaryWriter a lot should not run into
        # OSError: [Errno 24] Too many open files
        passed = True
        try:
            writer = self.createSummaryWriter()
            writer.close()
        except OSError:
            passed = False

        self.assertTrue(passed)

    def test_pathlib(self):
        # SummaryWriter must accept a pathlib.Path as its log directory.
        with tempfile.TemporaryDirectory(prefix="test_tensorboard_pathlib") as d:
            p = Path(d)
            with SummaryWriter(p) as writer:
                writer.add_scalar('test', 1)


class TestTensorBoardEmbedding(BaseTestCase):
    def test_embedding(self):
        # The writer is used as a context manager so its event file is closed
        # before tearDown() deletes the log directory.
        with self.createSummaryWriter() as w:
            all_features = torch.tensor([[1., 2., 3.], [5., 4., 1.], [3., 7., 7.]])
            all_labels = torch.tensor([33., 44., 55.])
            all_images = torch.zeros(3, 3, 5, 5)

            w.add_embedding(all_features,
                            metadata=all_labels,
                            label_img=all_images,
                            global_step=2)

            # One dataset label per feature row (three rows).
            dataset_label = ['test'] * 2 + ['train']
            all_labels = list(zip(all_labels, dataset_label))
            w.add_embedding(all_features,
                            metadata=all_labels,
                            label_img=all_images,
                            metadata_header=['digit', 'dataset'],
                            global_step=2)

    def test_embedding_64(self):
        # Same as test_embedding, but with float64 label images.
        with self.createSummaryWriter() as w:
            all_features = torch.tensor([[1., 2., 3.], [5., 4., 1.], [3., 7., 7.]])
            all_labels = torch.tensor([33., 44., 55.])
            all_images = torch.zeros((3, 3, 5, 5), dtype=torch.float64)

            w.add_embedding(all_features,
                            metadata=all_labels,
                            label_img=all_images,
                            global_step=2)

            dataset_label = ['test'] * 2 + ['train']
            all_labels = list(zip(all_labels, dataset_label))
            w.add_embedding(all_features,
                            metadata=all_labels,
                            label_img=all_images,
                            metadata_header=['digit', 'dataset'],
                            global_step=2)


class TestTensorBoardSummary(BaseTestCase):
    def test_uint8_image(self):
        '''
        Tests that uint8 image (pixel values in [0, 255]) is not changed
        '''
        test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8)
        scale_factor = summary._calc_scale_factor(test_image)
        self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1')

    def test_float32_image(self):
        '''
        Tests that float32 image (pixel values in [0, 1]) are scaled correctly
        to [0, 255]
        '''
        test_image = np.random.rand(3, 32, 32).astype(np.float32)
        scale_factor = summary._calc_scale_factor(test_image)
        self.assertEqual(scale_factor, 255, msg='Values are in [0, 1], scale factor should be 255')

    def test_list_input(self):
        with self.assertRaises(Exception):
            summary.histogram('dummy', [1, 3, 4, 5, 6], 'tensorflow')

    def test_empty_input(self):
        with self.assertRaises(Exception):
            summary.histogram('dummy', np.ndarray(0), 'tensorflow')

    def test_image_with_boxes(self):
        self.assertTrue(compare_image_proto(summary.image_boxes('dummy',
                                                                tensor_N(shape=(3, 32, 32)),
                                                                np.array([[10, 10, 40, 40]])),
                                            self))

    def test_image_with_one_channel(self):
        self.assertTrue(compare_image_proto(
            summary.image('dummy',
                          tensor_N(shape=(1, 8, 8)),
                          dataformats='CHW'),
            self))  # noqa: E131

    def test_image_with_one_channel_batched(self):
        self.assertTrue(compare_image_proto(
            summary.image('dummy',
                          tensor_N(shape=(2, 1, 8, 8)),
                          dataformats='NCHW'),
            self))  # noqa: E131

    def test_image_with_3_channel_batched(self):
        self.assertTrue(compare_image_proto(
            summary.image('dummy',
                          tensor_N(shape=(2, 3, 8, 8)),
                          dataformats='NCHW'),
            self))  # noqa: E131

    def test_image_without_channel(self):
        self.assertTrue(compare_image_proto(
            summary.image('dummy',
                          tensor_N(shape=(8, 8)),
                          dataformats='HW'),
            self))  # noqa: E131

    def test_video(self):
        try:
            import moviepy  # noqa: F401
        except ImportError:
            # Report a skip rather than a silent pass when the optional
            # dependency is missing.
            self.skipTest("moviepy is not installed")
        self.assertTrue(compare_proto(summary.video('dummy', tensor_N(shape=(4, 3, 1, 8, 8))), self))
        summary.video('dummy', np.random.rand(16, 48, 1, 28, 28))
        summary.video('dummy', np.random.rand(20, 7, 1, 8, 8))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_audio(self):
        self.assertTrue(compare_proto(summary.audio('dummy', tensor_N(shape=(42,))), self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_text(self):
        self.assertTrue(compare_proto(summary.text('dummy', 'text 123'), self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_histogram_auto(self):
        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='auto', max_bins=5), self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_histogram_fd(self):
        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='fd', max_bins=5), self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_histogram_doane(self):
        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='doane', max_bins=5), self))

    def test_custom_scalars(self):
        layout = {
            'Taiwan': {
                'twse': ['Multiline', ['twse/0050', 'twse/2330']]
            },
            'USA': {
                'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
                'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]
            }
        }
        # Smoke test only: the serialized dictionary order is not stable
        # across protobuf implementations, so no expect-file comparison.
        summary.custom_scalars(layout)

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_mesh(self):
        v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
        c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
        f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int)
        mesh = summary.mesh('my_mesh', vertices=v, colors=c, faces=f, config_dict=None)
        self.assertTrue(compare_proto(mesh, self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_scalar_new_style(self):
        scalar = summary.scalar('test_scalar', 1.0, new_style=True)
        self.assertTrue(compare_proto(scalar, self))
        with self.assertRaises(AssertionError):
            summary.scalar('test_scalar2', torch.Tensor([1, 2, 3]), new_style=True)


def remove_whitespace(string):
    """Return ``string`` with all spaces, tabs and newlines removed."""
    return string.replace(' ', '').replace('\t', '').replace('\n', '')


def get_expected_file(function_ptr):
    """Return the path of the ``.expect`` file for the running test.

    Args:
        function_ptr: the TestCase instance; its module locates the test
            directory and the last component of ``id()`` names the file,
            giving ``<test_dir>/expect/TestTensorBoard.<test_name>.expect``.
    """
    module_id = function_ptr.__class__.__module__
    test_file = sys.modules[module_id].__file__
    # Look for the .py file (since __file__ could be pyc). splitext only
    # strips the file extension, so dots in directory names are preserved.
    test_file = os.path.splitext(test_file)[0] + '.py'

    # Use realpath to follow symlinks appropriately.
    test_dir = os.path.dirname(os.path.realpath(test_file))
    function_name = function_ptr.id().split('.')[-1]
    return os.path.join(test_dir,
                        "expect",
                        'TestTensorBoard.' + function_name + ".expect")


def read_expected_content(function_ptr):
    """Return the text of the running test's ``.expect`` file (it must exist)."""
    expected_file = get_expected_file(function_ptr)
    assert os.path.exists(expected_file), expected_file
    with open(expected_file) as f:
        return f.read()


def compare_image_proto(actual_proto, function_ptr):
    """Check an image ``Summary`` against the test's expect file.

    With ``expecttest.ACCEPT`` set, the expect file is (re)written instead and
    True is returned. Otherwise the tag, height, width and colorspace of the
    first value must match, and both encoded images must decode to equal
    PIL images.
    """
    if expecttest.ACCEPT:
        expected_file = get_expected_file(function_ptr)
        with open(expected_file, 'w') as f:
            f.write(text_format.MessageToString(actual_proto))
        return True
    expected_str = read_expected_content(function_ptr)
    expected_proto = Summary()
    text_format.Parse(expected_str, expected_proto)

    actual, expected = actual_proto.value[0], expected_proto.value[0]
    actual_img = Image.open(io.BytesIO(actual.image.encoded_image_string))
    expected_img = Image.open(io.BytesIO(expected.image.encoded_image_string))

    return (
        actual.tag == expected.tag and
        actual.image.height == expected.image.height and
        actual.image.width == expected.image.width and
        actual.image.colorspace == expected.image.colorspace and
        actual_img == expected_img
    )


def compare_proto(str_to_compare, function_ptr):
    """Compare ``str(str_to_compare)`` with the expect file, ignoring whitespace.

    With ``expecttest.ACCEPT`` set, the expect file is (re)written instead and
    True is returned.
    """
    if expecttest.ACCEPT:
        write_proto(str_to_compare, function_ptr)
        return True
    expected = read_expected_content(function_ptr)
    str_to_compare = str(str_to_compare)
    return remove_whitespace(str_to_compare) == remove_whitespace(expected)


def write_proto(str_to_compare, function_ptr):
    """Overwrite the running test's expect file with ``str(str_to_compare)``."""
    expected_file = get_expected_file(function_ptr)
    with open(expected_file, 'w') as f:
        f.write(str(str_to_compare))


class TestTensorBoardPytorchGraph(BaseTestCase):
    def _assert_graph_matches_expected(self, actual_proto):
        """Assert ``actual_proto`` matches this test's expected GraphDef node by node.

        Node count, and per node the name, op, inputs, device and set of
        attribute keys must match; attribute values are not compared.
        """
        expected_str = read_expected_content(self)
        expected_proto = GraphDef()
        text_format.Parse(expected_str, expected_proto)

        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
        for expected_node, actual_node in zip(expected_proto.node, actual_proto.node):
            self.assertEqual(expected_node.name, actual_node.name)
            self.assertEqual(expected_node.op, actual_node.op)
            self.assertEqual(expected_node.input, actual_node.input)
            self.assertEqual(expected_node.device, actual_node.device)
            self.assertEqual(
                sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))

    def test_pytorch_graph(self):
        dummy_input = (torch.zeros(1, 3),)

        class myLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.l(x)

        with self.createSummaryWriter() as w:
            w.add_graph(myLinear(), dummy_input)

        actual_proto, _ = graph(myLinear(), dummy_input)
        self._assert_graph_matches_expected(actual_proto)

    def test_nested_nn_squential(self):
        dummy_input = torch.randn(2, 3)

        class InnerNNSquential(torch.nn.Module):
            def __init__(self, dim1, dim2):
                super().__init__()
                self.inner_nn_squential = torch.nn.Sequential(
                    torch.nn.Linear(dim1, dim2),
                    torch.nn.Linear(dim2, dim1),
                )

            def forward(self, x):
                x = self.inner_nn_squential(x)
                return x

        class OuterNNSquential(torch.nn.Module):
            def __init__(self, dim1=3, dim2=4, depth=2):
                super().__init__()
                layers = [InnerNNSquential(dim1, dim2) for _ in range(depth)]
                self.outer_nn_squential = torch.nn.Sequential(*layers)

            def forward(self, x):
                x = self.outer_nn_squential(x)
                return x

        with self.createSummaryWriter() as w:
            w.add_graph(OuterNNSquential(), dummy_input)

        actual_proto, _ = graph(OuterNNSquential(), dummy_input)
        self._assert_graph_matches_expected(actual_proto)

    def test_pytorch_graph_dict_input(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.l(x)

        class ModelDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return {"out": self.l(x)}

        dummy_input = torch.zeros(1, 3)

        with self.createSummaryWriter() as w:
            w.add_graph(Model(), dummy_input)

        with self.createSummaryWriter() as w:
            w.add_graph(Model(), dummy_input, use_strict_trace=True)

        # expect error: Encountering a dict at the output of the tracer...
        with self.assertRaises(RuntimeError):
            with self.createSummaryWriter() as w:
                w.add_graph(ModelDict(), dummy_input, use_strict_trace=True)

        with self.createSummaryWriter() as w:
            w.add_graph(ModelDict(), dummy_input, use_strict_trace=False)

    def test_mlp_graph(self):
        dummy_input = (torch.zeros(2, 1, 28, 28),)

        # This MLP class with the above input is expected
        # to fail JIT optimizations as seen at
        # https://github.com/pytorch/pytorch/issues/18903
        #
        # However, it should not raise an error during
        # the add_graph call and still continue.
        class myMLP(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.input_len = 1 * 28 * 28
                self.fc1 = torch.nn.Linear(self.input_len, 1200)
                self.fc2 = torch.nn.Linear(1200, 1200)
                self.fc3 = torch.nn.Linear(1200, 10)

            def forward(self, x, update_batch_stats=True):
                h = torch.nn.functional.relu(
                    self.fc1(x.view(-1, self.input_len)))
                h = self.fc2(h)
                h = torch.nn.functional.relu(h)
                h = self.fc3(h)
                return h

        with self.createSummaryWriter() as w:
            w.add_graph(myMLP(), dummy_input)

    def test_wrong_input_size(self):
        with self.assertRaises(RuntimeError):
            dummy_input = torch.rand(1, 9)
            model = torch.nn.Linear(3, 5)
            with self.createSummaryWriter() as w:
                w.add_graph(model, dummy_input)  # error: 9 input features, model expects 3

    @skipIfNoTorchVision
    def test_torchvision_smoke(self):
        model_input_shapes = {
            'alexnet': (2, 3, 224, 224),
            'resnet34': (2, 3, 224, 224),
            'resnet152': (2, 3, 224, 224),
            'densenet121': (2, 3, 224, 224),
            'vgg16': (2, 3, 224, 224),
            'vgg19': (2, 3, 224, 224),
            'vgg16_bn': (2, 3, 224, 224),
            'vgg19_bn': (2, 3, 224, 224),
            'mobilenet_v2': (2, 3, 224, 224),
        }
        for model_name, input_shape in model_input_shapes.items():
            with self.createSummaryWriter() as w:
                model = getattr(torchvision.models, model_name)()
                w.add_graph(model, torch.zeros(input_shape))


class TestTensorBoardFigure(BaseTestCase):
    @skipIfNoMatplotlib
    def test_figure(self):
        with self.createSummaryWriter() as writer:
            figure, axes = plt.figure(), plt.gca()
            circle1 = plt.Circle((0.2, 0.5), 0.2, color='r')
            circle2 = plt.Circle((0.8, 0.5), 0.2, color='g')
            axes.add_patch(circle1)
            axes.add_patch(circle2)
            plt.axis('scaled')
            plt.tight_layout()

            # close=False must keep the figure open; the default closes it.
            writer.add_figure("add_figure/figure", figure, 0, close=False)
            self.assertTrue(plt.fignum_exists(figure.number))

            writer.add_figure("add_figure/figure", figure, 1)
            if matplotlib.__version__ != '3.3.0':
                self.assertFalse(plt.fignum_exists(figure.number))
            else:
                print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")

    @skipIfNoMatplotlib
    def test_figure_list(self):
        with self.createSummaryWriter() as writer:
            figures = []
            for i in range(5):
                figure = plt.figure()
                plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i))
                plt.xlabel("X")
                plt.ylabel("Y")
                plt.legend()
                plt.tight_layout()
                figures.append(figure)

            writer.add_figure("add_figure/figure_list", figures, 0, close=False)
            self.assertTrue(all(plt.fignum_exists(figure.number) is True for figure in figures))  # noqa: F812

            writer.add_figure("add_figure/figure_list", figures, 1)
            if matplotlib.__version__ != '3.3.0':
                self.assertTrue(all(plt.fignum_exists(figure.number) is False for figure in figures))  # noqa: F812
            else:
                print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")


class TestTensorBoardNumpy(BaseTestCase):
    @unittest.skipIf(IS_WINDOWS, "Skipping on windows, see https://github.com/pytorch/pytorch/pull/109349 ")
    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_scalar(self):
        # Every Python / NumPy scalar must become a 1-element ndarray. The two
        # checks are separate statements: chaining them with `and` would skip
        # the shape check, because assertIsInstance returns None.
        scalars = [
            1.1,
            (1 << 64) - 1,  # uint64_max (parenthesized: `-` binds tighter than `<<`)
            np.float16(1.00000087),
            np.float128(1.00008 + 9),
            np.int64(100000000000),
        ]
        for value in scalars:
            with self.subTest(value=value):
                res = make_np(value)
                self.assertIsInstance(res, np.ndarray)
                self.assertEqual(res.shape, (1,))

    def test_pytorch_np_expect_fail(self):
        with self.assertRaises(NotImplementedError):
            make_np({'pytorch': 1.0})


# parametrize arguments are evaluated when the class body runs, i.e. at import
# time. DataType is only imported when TensorBoard is available, so it must not
# be dereferenced unconditionally here. Without TensorBoard every test is
# skipped in BaseTestCase.setUp, so placeholder values suffice.
if TEST_TENSORBOARD:
    _HALF_TENSOR_PROTO_CASES = [
        (torch.float16, DataType.DT_HALF),
        (torch.bfloat16, DataType.DT_BFLOAT16),
    ]
else:
    _HALF_TENSOR_PROTO_CASES = [
        (torch.float16, None),
        (torch.bfloat16, None),
    ]


class TestTensorProtoSummary(BaseTestCase):
    @parametrize(
        "tensor_type,proto_type",
        _HALF_TENSOR_PROTO_CASES,
    )
    @skipIfTorchDynamo("Unsuitable test for Dynamo, behavior changes with version")
    def test_half_tensor_proto(self, tensor_type, proto_type):
        float_values = [1.0, 2.0, 3.0]
        actual_proto = tensor_proto(
            "dummy",
            torch.tensor(float_values, dtype=tensor_type),
        ).value[0].tensor
        # Half-precision values are stored in `half_val`; int_to_half decodes them.
        self.assertSequenceEqual(
            [int_to_half(x) for x in actual_proto.half_val],
            float_values,
        )
        self.assertTrue(actual_proto.dtype == proto_type)

    def test_float_tensor_proto(self):
        float_values = [1.0, 2.0, 3.0]
        actual_proto = (
            tensor_proto("dummy", torch.tensor(float_values)).value[0].tensor
        )
        self.assertEqual(actual_proto.float_val, float_values)
        self.assertTrue(actual_proto.dtype == DataType.DT_FLOAT)

    def test_int_tensor_proto(self):
        int_values = [1, 2, 3]
        actual_proto = (
            tensor_proto("dummy", torch.tensor(int_values, dtype=torch.int32))
            .value[0]
            .tensor
        )
        self.assertEqual(actual_proto.int_val, int_values)
        self.assertTrue(actual_proto.dtype == DataType.DT_INT32)

    def test_scalar_tensor_proto(self):
        scalar_value = 0.1
        actual_proto = (
            tensor_proto("dummy", torch.tensor(scalar_value)).value[0].tensor
        )
        self.assertAlmostEqual(actual_proto.float_val[0], scalar_value)

    def test_complex_tensor_proto(self):
        real = torch.tensor([1.0, 2.0])
        imag = torch.tensor([3.0, 4.0])
        actual_proto = (
            tensor_proto("dummy", torch.complex(real, imag)).value[0].tensor
        )
        # Complex values are stored interleaved as (real, imag) pairs.
        self.assertEqual(actual_proto.scomplex_val, [1.0, 3.0, 2.0, 4.0])

    def test_empty_tensor_proto(self):
        actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
        self.assertEqual(actual_proto.float_val, [])


instantiate_parametrized_tests(TestTensorProtoSummary)

if __name__ == '__main__':
    run_tests()