1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport io 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport shutil 6*da0073e9SAndroid Build Coastguard Workerimport sys 7*da0073e9SAndroid Build Coastguard Workerimport tempfile 8*da0073e9SAndroid Build Coastguard Workerimport unittest 9*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerimport expecttest 12*da0073e9SAndroid Build Coastguard Workerimport numpy as np 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard WorkerTEST_TENSORBOARD = True 16*da0073e9SAndroid Build Coastguard Workertry: 17*da0073e9SAndroid Build Coastguard Worker import tensorboard.summary.writer.event_file_writer # noqa: F401 18*da0073e9SAndroid Build Coastguard Worker from tensorboard.compat.proto.summary_pb2 import Summary 19*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 20*da0073e9SAndroid Build Coastguard Worker TEST_TENSORBOARD = False 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard WorkerHAS_TORCHVISION = True 23*da0073e9SAndroid Build Coastguard Workertry: 24*da0073e9SAndroid Build Coastguard Worker import torchvision 25*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 26*da0073e9SAndroid Build Coastguard Worker HAS_TORCHVISION = False 27*da0073e9SAndroid Build Coastguard WorkerskipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard WorkerTEST_MATPLOTLIB = True 30*da0073e9SAndroid Build Coastguard Workertry: 31*da0073e9SAndroid Build Coastguard Worker import matplotlib 32*da0073e9SAndroid Build Coastguard Worker if os.environ.get('DISPLAY', '') == '': 33*da0073e9SAndroid Build Coastguard Worker matplotlib.use('Agg') 34*da0073e9SAndroid Build Coastguard Worker import matplotlib.pyplot as plt 35*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 36*da0073e9SAndroid Build Coastguard Worker TEST_MATPLOTLIB = False 37*da0073e9SAndroid Build Coastguard WorkerskipIfNoMatplotlib = unittest.skipIf(not TEST_MATPLOTLIB, "no matplotlib") 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Workerimport torch 40*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 41*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 42*da0073e9SAndroid Build Coastguard Worker IS_MACOS, 43*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, 44*da0073e9SAndroid Build Coastguard Worker parametrize, 45*da0073e9SAndroid Build Coastguard Worker run_tests, 46*da0073e9SAndroid Build Coastguard Worker TEST_WITH_CROSSREF, 47*da0073e9SAndroid Build Coastguard Worker TestCase, 48*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 49*da0073e9SAndroid Build Coastguard Worker) 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Workerdef tensor_N(shape, dtype=float): 53*da0073e9SAndroid Build Coastguard Worker numel = np.prod(shape) 54*da0073e9SAndroid Build Coastguard Worker x = (np.arange(numel, dtype=dtype)).reshape(shape) 55*da0073e9SAndroid Build Coastguard Worker return x 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Workerclass BaseTestCase(TestCase): 58*da0073e9SAndroid Build Coastguard Worker """ Base class used for all TensorBoard tests """ 59*da0073e9SAndroid Build Coastguard Worker def setUp(self): 60*da0073e9SAndroid Build Coastguard Worker super().setUp() 61*da0073e9SAndroid Build Coastguard Worker if not TEST_TENSORBOARD: 62*da0073e9SAndroid Build Coastguard Worker return self.skipTest("Skip the test since TensorBoard is not installed") 63*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_CROSSREF: 64*da0073e9SAndroid Build Coastguard Worker return self.skipTest("Don't run TensorBoard tests with crossref") 65*da0073e9SAndroid Build Coastguard Worker self.temp_dirs = [] 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker def createSummaryWriter(self): 68*da0073e9SAndroid Build Coastguard Worker # Just to get the name of the directory in a writable place. tearDown() 69*da0073e9SAndroid Build Coastguard Worker # is responsible for clean-ups. 70*da0073e9SAndroid Build Coastguard Worker temp_dir = tempfile.TemporaryDirectory(prefix="test_tensorboard").name 71*da0073e9SAndroid Build Coastguard Worker self.temp_dirs.append(temp_dir) 72*da0073e9SAndroid Build Coastguard Worker return SummaryWriter(temp_dir) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 75*da0073e9SAndroid Build Coastguard Worker super().tearDown() 76*da0073e9SAndroid Build Coastguard Worker # Remove directories created by SummaryWriter 77*da0073e9SAndroid Build Coastguard Worker for temp_dir in self.temp_dirs: 78*da0073e9SAndroid Build Coastguard Worker if os.path.exists(temp_dir): 79*da0073e9SAndroid Build Coastguard Worker shutil.rmtree(temp_dir) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Workerif TEST_TENSORBOARD: 83*da0073e9SAndroid Build Coastguard Worker from google.protobuf import text_format 84*da0073e9SAndroid Build Coastguard Worker from PIL import Image 85*da0073e9SAndroid Build Coastguard Worker from tensorboard.compat.proto.graph_pb2 import GraphDef 86*da0073e9SAndroid Build Coastguard Worker from tensorboard.compat.proto.types_pb2 import DataType 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker from torch.utils.tensorboard import summary, SummaryWriter 89*da0073e9SAndroid Build Coastguard Worker from torch.utils.tensorboard._convert_np import make_np 90*da0073e9SAndroid Build Coastguard Worker from torch.utils.tensorboard._pytorch_graph import graph 91*da0073e9SAndroid Build Coastguard Worker from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC 92*da0073e9SAndroid Build Coastguard Worker from torch.utils.tensorboard.summary import int_to_half, tensor_proto 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardPyTorchNumpy(BaseTestCase): 95*da0073e9SAndroid Build Coastguard Worker def test_pytorch_np(self): 96*da0073e9SAndroid Build Coastguard Worker tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)] 97*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 98*da0073e9SAndroid Build Coastguard Worker # regular tensor 99*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(make_np(tensor), np.ndarray) 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker # CUDA tensor 102*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 103*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(make_np(tensor.cuda()), np.ndarray) 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker # regular variable 106*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(make_np(torch.autograd.Variable(tensor)), np.ndarray) 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker # CUDA variable 109*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 110*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray) 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker # python primitive type 113*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(make_np(0), np.ndarray) 114*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(make_np(0.1), np.ndarray) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker def test_pytorch_autograd_np(self): 117*da0073e9SAndroid Build Coastguard Worker x = torch.autograd.Variable(torch.empty(1)) 118*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(make_np(x), np.ndarray) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def test_pytorch_write(self): 121*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 122*da0073e9SAndroid Build Coastguard Worker w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker def test_pytorch_histogram(self): 125*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 126*da0073e9SAndroid Build Coastguard Worker w.add_histogram('float histogram', torch.rand((50,))) 127*da0073e9SAndroid Build Coastguard Worker w.add_histogram('int histogram', torch.randint(0, 100, (50,))) 128*da0073e9SAndroid Build Coastguard Worker w.add_histogram('bfloat16 histogram', torch.rand(50, dtype=torch.bfloat16)) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker def test_pytorch_histogram_raw(self): 131*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 132*da0073e9SAndroid Build Coastguard Worker num = 50 133*da0073e9SAndroid Build Coastguard Worker floats = make_np(torch.rand((num,))) 134*da0073e9SAndroid Build Coastguard Worker bins = [0.0, 0.25, 0.5, 0.75, 1.0] 135*da0073e9SAndroid Build Coastguard Worker counts, limits = np.histogram(floats, bins) 136*da0073e9SAndroid Build Coastguard Worker sum_sq = floats.dot(floats).item() 137*da0073e9SAndroid Build Coastguard Worker w.add_histogram_raw('float histogram raw', 138*da0073e9SAndroid Build Coastguard Worker min=floats.min().item(), 139*da0073e9SAndroid Build Coastguard Worker max=floats.max().item(), 140*da0073e9SAndroid Build Coastguard Worker num=num, 141*da0073e9SAndroid Build Coastguard Worker sum=floats.sum().item(), 142*da0073e9SAndroid Build Coastguard Worker sum_squares=sum_sq, 143*da0073e9SAndroid Build Coastguard Worker bucket_limits=limits[1:].tolist(), 144*da0073e9SAndroid Build Coastguard Worker bucket_counts=counts.tolist()) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker ints = make_np(torch.randint(0, 100, (num,))) 147*da0073e9SAndroid Build Coastguard Worker bins = [0, 25, 50, 75, 100] 148*da0073e9SAndroid Build Coastguard Worker counts, limits = np.histogram(ints, bins) 149*da0073e9SAndroid Build Coastguard Worker sum_sq = ints.dot(ints).item() 150*da0073e9SAndroid Build Coastguard Worker w.add_histogram_raw('int histogram raw', 151*da0073e9SAndroid Build Coastguard Worker min=ints.min().item(), 152*da0073e9SAndroid Build Coastguard Worker max=ints.max().item(), 153*da0073e9SAndroid Build Coastguard Worker num=num, 154*da0073e9SAndroid Build Coastguard Worker sum=ints.sum().item(), 155*da0073e9SAndroid Build Coastguard Worker sum_squares=sum_sq, 156*da0073e9SAndroid Build Coastguard Worker bucket_limits=limits[1:].tolist(), 157*da0073e9SAndroid Build Coastguard Worker bucket_counts=counts.tolist()) 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker ints = torch.tensor(range(0, 100)).float() 160*da0073e9SAndroid Build Coastguard Worker nbins = 100 161*da0073e9SAndroid Build Coastguard Worker counts = torch.histc(ints, bins=nbins, min=0, max=99) 162*da0073e9SAndroid Build Coastguard Worker limits = torch.tensor(range(nbins)) 163*da0073e9SAndroid Build Coastguard Worker sum_sq = ints.dot(ints).item() 164*da0073e9SAndroid Build Coastguard Worker w.add_histogram_raw('int histogram raw', 165*da0073e9SAndroid Build Coastguard Worker min=ints.min().item(), 166*da0073e9SAndroid Build Coastguard Worker max=ints.max().item(), 167*da0073e9SAndroid Build Coastguard Worker num=num, 168*da0073e9SAndroid Build Coastguard Worker sum=ints.sum().item(), 169*da0073e9SAndroid Build Coastguard Worker sum_squares=sum_sq, 170*da0073e9SAndroid Build Coastguard Worker bucket_limits=limits.tolist(), 171*da0073e9SAndroid Build Coastguard Worker bucket_counts=counts.tolist()) 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardUtils(BaseTestCase): 174*da0073e9SAndroid Build Coastguard Worker def test_to_HWC(self): 175*da0073e9SAndroid Build Coastguard Worker test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) 176*da0073e9SAndroid Build Coastguard Worker converted = convert_to_HWC(test_image, 'chw') 177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(converted.shape, (32, 32, 3)) 178*da0073e9SAndroid Build Coastguard Worker test_image = np.random.randint(0, 256, size=(16, 3, 32, 32), dtype=np.uint8) 179*da0073e9SAndroid Build Coastguard Worker converted = convert_to_HWC(test_image, 'nchw') 180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(converted.shape, (64, 256, 3)) 181*da0073e9SAndroid Build Coastguard Worker test_image = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8) 182*da0073e9SAndroid Build Coastguard Worker converted = convert_to_HWC(test_image, 'hw') 183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(converted.shape, (32, 32, 3)) 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker def test_convert_to_HWC_dtype_remains_same(self): 186*da0073e9SAndroid Build Coastguard Worker # test to ensure convert_to_HWC restores the dtype of input np array and 187*da0073e9SAndroid Build Coastguard Worker # thus the scale_factor calculated for the image is 1 188*da0073e9SAndroid Build Coastguard Worker test_image = torch.tensor([[[[1, 2, 3], [4, 5, 6]]]], dtype=torch.uint8) 189*da0073e9SAndroid Build Coastguard Worker tensor = make_np(test_image) 190*da0073e9SAndroid Build Coastguard Worker tensor = convert_to_HWC(tensor, 'NCHW') 191*da0073e9SAndroid Build Coastguard Worker scale_factor = summary._calc_scale_factor(tensor) 192*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1') 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker def test_prepare_video(self): 196*da0073e9SAndroid Build Coastguard Worker # At each timeframe, the sum over all other 197*da0073e9SAndroid Build Coastguard Worker # dimensions of the video should be the same. 198*da0073e9SAndroid Build Coastguard Worker shapes = [ 199*da0073e9SAndroid Build Coastguard Worker (16, 30, 3, 28, 28), 200*da0073e9SAndroid Build Coastguard Worker (36, 30, 3, 28, 28), 201*da0073e9SAndroid Build Coastguard Worker (19, 29, 3, 23, 19), 202*da0073e9SAndroid Build Coastguard Worker (3, 3, 3, 3, 3) 203*da0073e9SAndroid Build Coastguard Worker ] 204*da0073e9SAndroid Build Coastguard Worker for s in shapes: 205*da0073e9SAndroid Build Coastguard Worker V_input = np.random.random(s) 206*da0073e9SAndroid Build Coastguard Worker V_after = _prepare_video(np.copy(V_input)) 207*da0073e9SAndroid Build Coastguard Worker total_frame = s[1] 208*da0073e9SAndroid Build Coastguard Worker V_input = np.swapaxes(V_input, 0, 1) 209*da0073e9SAndroid Build Coastguard Worker for f in range(total_frame): 210*da0073e9SAndroid Build Coastguard Worker x = np.reshape(V_input[f], newshape=(-1)) 211*da0073e9SAndroid Build Coastguard Worker y = np.reshape(V_after[f], newshape=(-1)) 212*da0073e9SAndroid Build Coastguard Worker np.testing.assert_array_almost_equal(np.sum(x), np.sum(y)) 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker def test_numpy_vid_uint8(self): 215*da0073e9SAndroid Build Coastguard Worker V_input = np.random.randint(0, 256, (16, 30, 3, 28, 28)).astype(np.uint8) 216*da0073e9SAndroid Build Coastguard Worker V_after = _prepare_video(np.copy(V_input)) * 255 217*da0073e9SAndroid Build Coastguard Worker total_frame = V_input.shape[1] 218*da0073e9SAndroid Build Coastguard Worker V_input = np.swapaxes(V_input, 0, 1) 219*da0073e9SAndroid Build Coastguard Worker for f in range(total_frame): 220*da0073e9SAndroid Build Coastguard Worker x = np.reshape(V_input[f], newshape=(-1)) 221*da0073e9SAndroid Build Coastguard Worker y = np.reshape(V_after[f], newshape=(-1)) 222*da0073e9SAndroid Build Coastguard Worker np.testing.assert_array_almost_equal(np.sum(x), np.sum(y)) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Workerfreqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440] 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Workertrue_positive_counts = [75, 64, 21, 5, 0] 227*da0073e9SAndroid Build Coastguard Workerfalse_positive_counts = [150, 105, 18, 0, 0] 228*da0073e9SAndroid Build Coastguard Workertrue_negative_counts = [0, 45, 132, 150, 150] 229*da0073e9SAndroid Build Coastguard Workerfalse_negative_counts = [0, 11, 54, 70, 75] 230*da0073e9SAndroid Build Coastguard Workerprecision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0] 231*da0073e9SAndroid Build Coastguard Workerrecall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0] 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardWriter(BaseTestCase): 234*da0073e9SAndroid Build Coastguard Worker def test_writer(self): 235*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as writer: 236*da0073e9SAndroid Build Coastguard Worker sample_rate = 44100 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker n_iter = 0 239*da0073e9SAndroid Build Coastguard Worker writer.add_hparams( 240*da0073e9SAndroid Build Coastguard Worker {'lr': 0.1, 'bsize': 1}, 241*da0073e9SAndroid Build Coastguard Worker {'hparam/accuracy': 10, 'hparam/loss': 10} 242*da0073e9SAndroid Build Coastguard Worker ) 243*da0073e9SAndroid Build Coastguard Worker writer.add_scalar('data/scalar_systemtime', 0.1, n_iter) 244*da0073e9SAndroid Build Coastguard Worker writer.add_scalar('data/scalar_customtime', 0.2, n_iter, walltime=n_iter) 245*da0073e9SAndroid Build Coastguard Worker writer.add_scalar('data/new_style', 0.2, n_iter, new_style=True) 246*da0073e9SAndroid Build Coastguard Worker writer.add_scalars('data/scalar_group', { 247*da0073e9SAndroid Build Coastguard Worker "xsinx": n_iter * np.sin(n_iter), 248*da0073e9SAndroid Build Coastguard Worker "xcosx": n_iter * np.cos(n_iter), 249*da0073e9SAndroid Build Coastguard Worker "arctanx": np.arctan(n_iter) 250*da0073e9SAndroid Build Coastguard Worker }, n_iter) 251*da0073e9SAndroid Build Coastguard Worker x = np.zeros((32, 3, 64, 64)) # output from network 252*da0073e9SAndroid Build Coastguard Worker writer.add_images('Image', x, n_iter) # Tensor 253*da0073e9SAndroid Build Coastguard Worker writer.add_image_with_boxes('imagebox', 254*da0073e9SAndroid Build Coastguard Worker np.zeros((3, 64, 64)), 255*da0073e9SAndroid Build Coastguard Worker np.array([[10, 10, 40, 40], [40, 40, 60, 60]]), 256*da0073e9SAndroid Build Coastguard Worker n_iter) 257*da0073e9SAndroid Build Coastguard Worker x = np.zeros(sample_rate * 2) 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker writer.add_audio('myAudio', x, n_iter) 260*da0073e9SAndroid Build Coastguard Worker writer.add_video('myVideo', np.random.rand(16, 48, 1, 28, 28).astype(np.float32), n_iter) 261*da0073e9SAndroid Build Coastguard Worker writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter) 262*da0073e9SAndroid Build Coastguard Worker writer.add_text('markdown Text', '''a|b\n-|-\nc|d''', n_iter) 263*da0073e9SAndroid Build Coastguard Worker writer.add_histogram('hist', np.random.rand(100, 100), n_iter) 264*da0073e9SAndroid Build Coastguard Worker writer.add_pr_curve('xoxo', np.random.randint(2, size=100), np.random.rand( 265*da0073e9SAndroid Build Coastguard Worker 100), n_iter) # needs tensorboard 0.4RC or later 266*da0073e9SAndroid Build Coastguard Worker writer.add_pr_curve_raw('prcurve with raw data', true_positive_counts, 267*da0073e9SAndroid Build Coastguard Worker false_positive_counts, 268*da0073e9SAndroid Build Coastguard Worker true_negative_counts, 269*da0073e9SAndroid Build Coastguard Worker false_negative_counts, 270*da0073e9SAndroid Build Coastguard Worker precision, 271*da0073e9SAndroid Build Coastguard Worker recall, n_iter) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float) 274*da0073e9SAndroid Build Coastguard Worker c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int) 275*da0073e9SAndroid Build Coastguard Worker f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int) 276*da0073e9SAndroid Build Coastguard Worker writer.add_mesh('my_mesh', vertices=v, colors=c, faces=f) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardSummaryWriter(BaseTestCase): 279*da0073e9SAndroid Build Coastguard Worker def test_summary_writer_ctx(self): 280*da0073e9SAndroid Build Coastguard Worker # after using a SummaryWriter as a ctx it should be closed 281*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as writer: 282*da0073e9SAndroid Build Coastguard Worker writer.add_scalar('test', 1) 283*da0073e9SAndroid Build Coastguard Worker self.assertIs(writer.file_writer, None) 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker def test_summary_writer_close(self): 286*da0073e9SAndroid Build Coastguard Worker # Opening and closing SummaryWriter a lot should not run into 287*da0073e9SAndroid Build Coastguard Worker # OSError: [Errno 24] Too many open files 288*da0073e9SAndroid Build Coastguard Worker passed = True 289*da0073e9SAndroid Build Coastguard Worker try: 290*da0073e9SAndroid Build Coastguard Worker writer = self.createSummaryWriter() 291*da0073e9SAndroid Build Coastguard Worker writer.close() 292*da0073e9SAndroid Build Coastguard Worker except OSError: 293*da0073e9SAndroid Build Coastguard Worker passed = False 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker self.assertTrue(passed) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def test_pathlib(self): 298*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryDirectory(prefix="test_tensorboard_pathlib") as d: 299*da0073e9SAndroid Build Coastguard Worker p = Path(d) 300*da0073e9SAndroid Build Coastguard Worker with SummaryWriter(p) as writer: 301*da0073e9SAndroid Build Coastguard Worker writer.add_scalar('test', 1) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardEmbedding(BaseTestCase): 304*da0073e9SAndroid Build Coastguard Worker def test_embedding(self): 305*da0073e9SAndroid Build Coastguard Worker w = self.createSummaryWriter() 306*da0073e9SAndroid Build Coastguard Worker all_features = torch.tensor([[1., 2., 3.], [5., 4., 1.], [3., 7., 7.]]) 307*da0073e9SAndroid Build Coastguard Worker all_labels = torch.tensor([33., 44., 55.]) 308*da0073e9SAndroid Build Coastguard Worker all_images = torch.zeros(3, 3, 5, 5) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker w.add_embedding(all_features, 311*da0073e9SAndroid Build Coastguard Worker metadata=all_labels, 312*da0073e9SAndroid Build Coastguard Worker label_img=all_images, 313*da0073e9SAndroid Build Coastguard Worker global_step=2) 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker dataset_label = ['test'] * 2 + ['train'] * 2 316*da0073e9SAndroid Build Coastguard Worker all_labels = list(zip(all_labels, dataset_label)) 317*da0073e9SAndroid Build Coastguard Worker w.add_embedding(all_features, 318*da0073e9SAndroid Build Coastguard Worker metadata=all_labels, 319*da0073e9SAndroid Build Coastguard Worker label_img=all_images, 320*da0073e9SAndroid Build Coastguard Worker metadata_header=['digit', 'dataset'], 321*da0073e9SAndroid Build Coastguard Worker global_step=2) 322*da0073e9SAndroid Build Coastguard Worker # assert... 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker def test_embedding_64(self): 325*da0073e9SAndroid Build Coastguard Worker w = self.createSummaryWriter() 326*da0073e9SAndroid Build Coastguard Worker all_features = torch.tensor([[1., 2., 3.], [5., 4., 1.], [3., 7., 7.]]) 327*da0073e9SAndroid Build Coastguard Worker all_labels = torch.tensor([33., 44., 55.]) 328*da0073e9SAndroid Build Coastguard Worker all_images = torch.zeros((3, 3, 5, 5), dtype=torch.float64) 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker w.add_embedding(all_features, 331*da0073e9SAndroid Build Coastguard Worker metadata=all_labels, 332*da0073e9SAndroid Build Coastguard Worker label_img=all_images, 333*da0073e9SAndroid Build Coastguard Worker global_step=2) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker dataset_label = ['test'] * 2 + ['train'] * 2 336*da0073e9SAndroid Build Coastguard Worker all_labels = list(zip(all_labels, dataset_label)) 337*da0073e9SAndroid Build Coastguard Worker w.add_embedding(all_features, 338*da0073e9SAndroid Build Coastguard Worker metadata=all_labels, 339*da0073e9SAndroid Build Coastguard Worker label_img=all_images, 340*da0073e9SAndroid Build Coastguard Worker metadata_header=['digit', 'dataset'], 341*da0073e9SAndroid Build Coastguard Worker global_step=2) 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardSummary(BaseTestCase): 344*da0073e9SAndroid Build Coastguard Worker def test_uint8_image(self): 345*da0073e9SAndroid Build Coastguard Worker ''' 346*da0073e9SAndroid Build Coastguard Worker Tests that uint8 image (pixel values in [0, 255]) is not changed 347*da0073e9SAndroid Build Coastguard Worker ''' 348*da0073e9SAndroid Build Coastguard Worker test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) 349*da0073e9SAndroid Build Coastguard Worker scale_factor = summary._calc_scale_factor(test_image) 350*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1') 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker def test_float32_image(self): 353*da0073e9SAndroid Build Coastguard Worker ''' 354*da0073e9SAndroid Build Coastguard Worker Tests that float32 image (pixel values in [0, 1]) are scaled correctly 355*da0073e9SAndroid Build Coastguard Worker to [0, 255] 356*da0073e9SAndroid Build Coastguard Worker ''' 357*da0073e9SAndroid Build Coastguard Worker test_image = np.random.rand(3, 32, 32).astype(np.float32) 358*da0073e9SAndroid Build Coastguard Worker scale_factor = summary._calc_scale_factor(test_image) 359*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale_factor, 255, msg='Values are in [0, 1], scale factor should be 255') 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker def test_list_input(self): 362*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception) as e_info: 363*da0073e9SAndroid Build Coastguard Worker summary.histogram('dummy', [1, 3, 4, 5, 6], 'tensorflow') 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker def test_empty_input(self): 366*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception) as e_info: 367*da0073e9SAndroid Build Coastguard Worker summary.histogram('dummy', np.ndarray(0), 'tensorflow') 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker def test_image_with_boxes(self): 370*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_image_proto(summary.image_boxes('dummy', 371*da0073e9SAndroid Build Coastguard Worker tensor_N(shape=(3, 32, 32)), 372*da0073e9SAndroid Build Coastguard Worker np.array([[10, 10, 40, 40]])), 373*da0073e9SAndroid Build Coastguard Worker self)) 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker def test_image_with_one_channel(self): 376*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_image_proto( 377*da0073e9SAndroid Build Coastguard Worker summary.image('dummy', 378*da0073e9SAndroid Build Coastguard Worker tensor_N(shape=(1, 8, 8)), 379*da0073e9SAndroid Build Coastguard Worker dataformats='CHW'), 380*da0073e9SAndroid Build Coastguard Worker self)) # noqa: E131 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker def test_image_with_one_channel_batched(self): 383*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_image_proto( 384*da0073e9SAndroid Build Coastguard Worker summary.image('dummy', 385*da0073e9SAndroid Build Coastguard Worker tensor_N(shape=(2, 1, 8, 8)), 386*da0073e9SAndroid Build Coastguard Worker dataformats='NCHW'), 387*da0073e9SAndroid Build Coastguard Worker self)) # noqa: E131 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker def test_image_with_3_channel_batched(self): 390*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_image_proto( 391*da0073e9SAndroid Build Coastguard Worker summary.image('dummy', 392*da0073e9SAndroid Build Coastguard Worker tensor_N(shape=(2, 3, 8, 8)), 393*da0073e9SAndroid Build Coastguard Worker dataformats='NCHW'), 394*da0073e9SAndroid Build Coastguard Worker self)) # noqa: E131 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker def test_image_without_channel(self): 397*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_image_proto( 398*da0073e9SAndroid Build Coastguard Worker summary.image('dummy', 399*da0073e9SAndroid Build Coastguard Worker tensor_N(shape=(8, 8)), 400*da0073e9SAndroid Build Coastguard Worker dataformats='HW'), 401*da0073e9SAndroid Build Coastguard Worker self)) # noqa: E131 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker def test_video(self): 404*da0073e9SAndroid Build Coastguard Worker try: 405*da0073e9SAndroid Build Coastguard Worker import moviepy # noqa: F401 406*da0073e9SAndroid Build Coastguard Worker except ImportError: 407*da0073e9SAndroid Build Coastguard Worker return 408*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_proto(summary.video('dummy', tensor_N(shape=(4, 3, 1, 8, 8))), self)) 409*da0073e9SAndroid Build Coastguard Worker summary.video('dummy', np.random.rand(16, 48, 1, 28, 28)) 410*da0073e9SAndroid Build Coastguard Worker summary.video('dummy', np.random.rand(20, 7, 1, 8, 8)) 411*da0073e9SAndroid Build Coastguard Worker 412*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ") 413*da0073e9SAndroid Build Coastguard Worker def test_audio(self): 414*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_proto(summary.audio('dummy', tensor_N(shape=(42,))), self)) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ") 417*da0073e9SAndroid Build Coastguard Worker def test_text(self): 418*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_proto(summary.text('dummy', 'text 123'), self)) 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ") 421*da0073e9SAndroid Build Coastguard Worker def test_histogram_auto(self): 422*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='auto', max_bins=5), self)) 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ") 425*da0073e9SAndroid Build Coastguard Worker def test_histogram_fd(self): 426*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='fd', max_bins=5), self)) 427*da0073e9SAndroid Build Coastguard Worker 428*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ") 429*da0073e9SAndroid Build Coastguard Worker def test_histogram_doane(self): 430*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='doane', max_bins=5), self)) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker def test_custom_scalars(self): 433*da0073e9SAndroid Build Coastguard Worker layout = { 434*da0073e9SAndroid Build Coastguard Worker 'Taiwan': { 435*da0073e9SAndroid Build Coastguard Worker 'twse': ['Multiline', ['twse/0050', 'twse/2330']] 436*da0073e9SAndroid Build Coastguard Worker }, 437*da0073e9SAndroid Build Coastguard Worker 'USA': { 438*da0073e9SAndroid Build Coastguard Worker 'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']], 439*da0073e9SAndroid Build Coastguard Worker 'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']] 440*da0073e9SAndroid Build Coastguard Worker } 441*da0073e9SAndroid Build Coastguard Worker } 442*da0073e9SAndroid Build Coastguard Worker summary.custom_scalars(layout) # only smoke test. Because protobuf in python2/3 serialize dictionary differently. 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker 445*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ") 446*da0073e9SAndroid Build Coastguard Worker def test_mesh(self): 447*da0073e9SAndroid Build Coastguard Worker v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float) 448*da0073e9SAndroid Build Coastguard Worker c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int) 449*da0073e9SAndroid Build Coastguard Worker f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int) 450*da0073e9SAndroid Build Coastguard Worker mesh = summary.mesh('my_mesh', vertices=v, colors=c, faces=f, config_dict=None) 451*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_proto(mesh, self)) 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ") 454*da0073e9SAndroid Build Coastguard Worker def test_scalar_new_style(self): 455*da0073e9SAndroid Build Coastguard Worker scalar = summary.scalar('test_scalar', 1.0, new_style=True) 456*da0073e9SAndroid Build Coastguard Worker self.assertTrue(compare_proto(scalar, self)) 457*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 458*da0073e9SAndroid Build Coastguard Worker summary.scalar('test_scalar2', torch.Tensor([1, 2, 3]), new_style=True) 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Workerdef remove_whitespace(string): 462*da0073e9SAndroid Build Coastguard Worker return string.replace(' ', '').replace('\t', '').replace('\n', '') 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Workerdef get_expected_file(function_ptr): 465*da0073e9SAndroid Build Coastguard Worker module_id = function_ptr.__class__.__module__ 466*da0073e9SAndroid Build Coastguard Worker test_file = sys.modules[module_id].__file__ 467*da0073e9SAndroid Build Coastguard Worker # Look for the .py file (since __file__ could be pyc). 468*da0073e9SAndroid Build Coastguard Worker test_file = ".".join(test_file.split('.')[:-1]) + '.py' 469*da0073e9SAndroid Build Coastguard Worker 470*da0073e9SAndroid Build Coastguard Worker # Use realpath to follow symlinks appropriately. 471*da0073e9SAndroid Build Coastguard Worker test_dir = os.path.dirname(os.path.realpath(test_file)) 472*da0073e9SAndroid Build Coastguard Worker functionName = function_ptr.id().split('.')[-1] 473*da0073e9SAndroid Build Coastguard Worker return os.path.join(test_dir, 474*da0073e9SAndroid Build Coastguard Worker "expect", 475*da0073e9SAndroid Build Coastguard Worker 'TestTensorBoard.' + functionName + ".expect") 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Workerdef read_expected_content(function_ptr): 478*da0073e9SAndroid Build Coastguard Worker expected_file = get_expected_file(function_ptr) 479*da0073e9SAndroid Build Coastguard Worker assert os.path.exists(expected_file), expected_file 480*da0073e9SAndroid Build Coastguard Worker with open(expected_file) as f: 481*da0073e9SAndroid Build Coastguard Worker return f.read() 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Workerdef compare_image_proto(actual_proto, function_ptr): 484*da0073e9SAndroid Build Coastguard Worker if expecttest.ACCEPT: 485*da0073e9SAndroid Build Coastguard Worker expected_file = get_expected_file(function_ptr) 486*da0073e9SAndroid Build Coastguard Worker with open(expected_file, 'w') as f: 487*da0073e9SAndroid Build Coastguard Worker f.write(text_format.MessageToString(actual_proto)) 488*da0073e9SAndroid Build Coastguard Worker return True 489*da0073e9SAndroid Build Coastguard Worker expected_str = read_expected_content(function_ptr) 490*da0073e9SAndroid Build Coastguard Worker expected_proto = Summary() 491*da0073e9SAndroid Build Coastguard Worker text_format.Parse(expected_str, expected_proto) 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker [actual, expected] = [actual_proto.value[0], expected_proto.value[0]] 494*da0073e9SAndroid Build Coastguard Worker actual_img = Image.open(io.BytesIO(actual.image.encoded_image_string)) 495*da0073e9SAndroid Build Coastguard Worker expected_img = Image.open(io.BytesIO(expected.image.encoded_image_string)) 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker return ( 498*da0073e9SAndroid Build Coastguard Worker actual.tag == expected.tag and 499*da0073e9SAndroid Build Coastguard Worker actual.image.height == expected.image.height and 500*da0073e9SAndroid Build Coastguard Worker actual.image.width == expected.image.width and 501*da0073e9SAndroid Build Coastguard Worker actual.image.colorspace == expected.image.colorspace and 502*da0073e9SAndroid Build Coastguard Worker actual_img == expected_img 503*da0073e9SAndroid Build Coastguard Worker ) 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Workerdef compare_proto(str_to_compare, function_ptr): 506*da0073e9SAndroid Build Coastguard Worker if expecttest.ACCEPT: 507*da0073e9SAndroid Build Coastguard Worker write_proto(str_to_compare, function_ptr) 508*da0073e9SAndroid Build Coastguard Worker return True 509*da0073e9SAndroid Build Coastguard Worker expected = read_expected_content(function_ptr) 510*da0073e9SAndroid Build Coastguard Worker str_to_compare = str(str_to_compare) 511*da0073e9SAndroid Build Coastguard Worker return remove_whitespace(str_to_compare) == remove_whitespace(expected) 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Workerdef write_proto(str_to_compare, function_ptr): 514*da0073e9SAndroid Build Coastguard Worker expected_file = get_expected_file(function_ptr) 515*da0073e9SAndroid Build Coastguard Worker with open(expected_file, 'w') as f: 516*da0073e9SAndroid Build Coastguard Worker f.write(str(str_to_compare)) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardPytorchGraph(BaseTestCase): 519*da0073e9SAndroid Build Coastguard Worker def test_pytorch_graph(self): 520*da0073e9SAndroid Build Coastguard Worker dummy_input = (torch.zeros(1, 3),) 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker class myLinear(torch.nn.Module): 523*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 524*da0073e9SAndroid Build Coastguard Worker super().__init__() 525*da0073e9SAndroid Build Coastguard Worker self.l = torch.nn.Linear(3, 5) 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 528*da0073e9SAndroid Build Coastguard Worker return self.l(x) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 531*da0073e9SAndroid Build Coastguard Worker w.add_graph(myLinear(), dummy_input) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker actual_proto, _ = graph(myLinear(), dummy_input) 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker expected_str = read_expected_content(self) 536*da0073e9SAndroid Build Coastguard Worker expected_proto = GraphDef() 537*da0073e9SAndroid Build Coastguard Worker text_format.Parse(expected_str, expected_proto) 538*da0073e9SAndroid Build Coastguard Worker 539*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(expected_proto.node), len(actual_proto.node)) 540*da0073e9SAndroid Build Coastguard Worker for i in range(len(expected_proto.node)): 541*da0073e9SAndroid Build Coastguard Worker expected_node = expected_proto.node[i] 542*da0073e9SAndroid Build Coastguard Worker actual_node = actual_proto.node[i] 543*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_node.name, actual_node.name) 544*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_node.op, actual_node.op) 545*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_node.input, actual_node.input) 546*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_node.device, actual_node.device) 547*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 548*da0073e9SAndroid Build Coastguard Worker sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys())) 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker def test_nested_nn_squential(self): 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker dummy_input = torch.randn(2, 3) 553*da0073e9SAndroid Build Coastguard Worker 554*da0073e9SAndroid Build Coastguard Worker class InnerNNSquential(torch.nn.Module): 555*da0073e9SAndroid Build Coastguard Worker def __init__(self, dim1, dim2): 556*da0073e9SAndroid Build Coastguard Worker super().__init__() 557*da0073e9SAndroid Build Coastguard Worker self.inner_nn_squential = torch.nn.Sequential( 558*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(dim1, dim2), 559*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(dim2, dim1), 560*da0073e9SAndroid Build Coastguard Worker ) 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 563*da0073e9SAndroid Build Coastguard Worker x = self.inner_nn_squential(x) 564*da0073e9SAndroid Build Coastguard Worker return x 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker class OuterNNSquential(torch.nn.Module): 567*da0073e9SAndroid Build Coastguard Worker def __init__(self, dim1=3, dim2=4, depth=2): 568*da0073e9SAndroid Build Coastguard Worker super().__init__() 569*da0073e9SAndroid Build Coastguard Worker layers = [] 570*da0073e9SAndroid Build Coastguard Worker for _ in range(depth): 571*da0073e9SAndroid Build Coastguard Worker layers.append(InnerNNSquential(dim1, dim2)) 572*da0073e9SAndroid Build Coastguard Worker self.outer_nn_squential = torch.nn.Sequential(*layers) 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 575*da0073e9SAndroid Build Coastguard Worker x = self.outer_nn_squential(x) 576*da0073e9SAndroid Build Coastguard Worker return x 577*da0073e9SAndroid Build Coastguard Worker 578*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 579*da0073e9SAndroid Build Coastguard Worker w.add_graph(OuterNNSquential(), dummy_input) 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Worker actual_proto, _ = graph(OuterNNSquential(), dummy_input) 582*da0073e9SAndroid Build Coastguard Worker 583*da0073e9SAndroid Build Coastguard Worker expected_str = read_expected_content(self) 584*da0073e9SAndroid Build Coastguard Worker expected_proto = GraphDef() 585*da0073e9SAndroid Build Coastguard Worker text_format.Parse(expected_str, expected_proto) 586*da0073e9SAndroid Build Coastguard Worker 587*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(expected_proto.node), len(actual_proto.node)) 588*da0073e9SAndroid Build Coastguard Worker for i in range(len(expected_proto.node)): 589*da0073e9SAndroid Build Coastguard Worker expected_node = expected_proto.node[i] 590*da0073e9SAndroid Build Coastguard Worker actual_node = actual_proto.node[i] 591*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_node.name, actual_node.name) 592*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_node.op, actual_node.op) 593*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_node.input, actual_node.input) 594*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_node.device, actual_node.device) 595*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 596*da0073e9SAndroid Build Coastguard Worker sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys())) 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker def test_pytorch_graph_dict_input(self): 599*da0073e9SAndroid Build Coastguard Worker class Model(torch.nn.Module): 600*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 601*da0073e9SAndroid Build Coastguard Worker super().__init__() 602*da0073e9SAndroid Build Coastguard Worker self.l = torch.nn.Linear(3, 5) 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 605*da0073e9SAndroid Build Coastguard Worker return self.l(x) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker class ModelDict(torch.nn.Module): 608*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 609*da0073e9SAndroid Build Coastguard Worker super().__init__() 610*da0073e9SAndroid Build Coastguard Worker self.l = torch.nn.Linear(3, 5) 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 613*da0073e9SAndroid Build Coastguard Worker return {"out": self.l(x)} 614*da0073e9SAndroid Build Coastguard Worker 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker dummy_input = torch.zeros(1, 3) 617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 619*da0073e9SAndroid Build Coastguard Worker w.add_graph(Model(), dummy_input) 620*da0073e9SAndroid Build Coastguard Worker 621*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 622*da0073e9SAndroid Build Coastguard Worker w.add_graph(Model(), dummy_input, use_strict_trace=True) 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker # expect error: Encountering a dict at the output of the tracer... 625*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 626*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 627*da0073e9SAndroid Build Coastguard Worker w.add_graph(ModelDict(), dummy_input, use_strict_trace=True) 628*da0073e9SAndroid Build Coastguard Worker 629*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 630*da0073e9SAndroid Build Coastguard Worker w.add_graph(ModelDict(), dummy_input, use_strict_trace=False) 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker def test_mlp_graph(self): 634*da0073e9SAndroid Build Coastguard Worker dummy_input = (torch.zeros(2, 1, 28, 28),) 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker # This MLP class with the above input is expected 637*da0073e9SAndroid Build Coastguard Worker # to fail JIT optimizations as seen at 638*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/18903 639*da0073e9SAndroid Build Coastguard Worker # 640*da0073e9SAndroid Build Coastguard Worker # However, it should not raise an error during 641*da0073e9SAndroid Build Coastguard Worker # the add_graph call and still continue. 642*da0073e9SAndroid Build Coastguard Worker class myMLP(torch.nn.Module): 643*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 644*da0073e9SAndroid Build Coastguard Worker super().__init__() 645*da0073e9SAndroid Build Coastguard Worker self.input_len = 1 * 28 * 28 646*da0073e9SAndroid Build Coastguard Worker self.fc1 = torch.nn.Linear(self.input_len, 1200) 647*da0073e9SAndroid Build Coastguard Worker self.fc2 = torch.nn.Linear(1200, 1200) 648*da0073e9SAndroid Build Coastguard Worker self.fc3 = torch.nn.Linear(1200, 10) 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker def forward(self, x, update_batch_stats=True): 651*da0073e9SAndroid Build Coastguard Worker h = torch.nn.functional.relu( 652*da0073e9SAndroid Build Coastguard Worker self.fc1(x.view(-1, self.input_len))) 653*da0073e9SAndroid Build Coastguard Worker h = self.fc2(h) 654*da0073e9SAndroid Build Coastguard Worker h = torch.nn.functional.relu(h) 655*da0073e9SAndroid Build Coastguard Worker h = self.fc3(h) 656*da0073e9SAndroid Build Coastguard Worker return h 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 659*da0073e9SAndroid Build Coastguard Worker w.add_graph(myMLP(), dummy_input) 660*da0073e9SAndroid Build Coastguard Worker 661*da0073e9SAndroid Build Coastguard Worker def test_wrong_input_size(self): 662*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError) as e_info: 663*da0073e9SAndroid Build Coastguard Worker dummy_input = torch.rand(1, 9) 664*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Linear(3, 5) 665*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 666*da0073e9SAndroid Build Coastguard Worker w.add_graph(model, dummy_input) # error 667*da0073e9SAndroid Build Coastguard Worker 668*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 669*da0073e9SAndroid Build Coastguard Worker def test_torchvision_smoke(self): 670*da0073e9SAndroid Build Coastguard Worker model_input_shapes = { 671*da0073e9SAndroid Build Coastguard Worker 'alexnet': (2, 3, 224, 224), 672*da0073e9SAndroid Build Coastguard Worker 'resnet34': (2, 3, 224, 224), 673*da0073e9SAndroid Build Coastguard Worker 'resnet152': (2, 3, 224, 224), 674*da0073e9SAndroid Build Coastguard Worker 'densenet121': (2, 3, 224, 224), 675*da0073e9SAndroid Build Coastguard Worker 'vgg16': (2, 3, 224, 224), 676*da0073e9SAndroid Build Coastguard Worker 'vgg19': (2, 3, 224, 224), 677*da0073e9SAndroid Build Coastguard Worker 'vgg16_bn': (2, 3, 224, 224), 678*da0073e9SAndroid Build Coastguard Worker 'vgg19_bn': (2, 3, 224, 224), 679*da0073e9SAndroid Build Coastguard Worker 'mobilenet_v2': (2, 3, 224, 224), 680*da0073e9SAndroid Build Coastguard Worker } 681*da0073e9SAndroid Build Coastguard Worker for model_name, input_shape in model_input_shapes.items(): 682*da0073e9SAndroid Build Coastguard Worker with self.createSummaryWriter() as w: 683*da0073e9SAndroid Build Coastguard Worker model = getattr(torchvision.models, model_name)() 684*da0073e9SAndroid Build Coastguard Worker w.add_graph(model, torch.zeros(input_shape)) 685*da0073e9SAndroid Build Coastguard Worker 686*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardFigure(BaseTestCase): 687*da0073e9SAndroid Build Coastguard Worker @skipIfNoMatplotlib 688*da0073e9SAndroid Build Coastguard Worker def test_figure(self): 689*da0073e9SAndroid Build Coastguard Worker writer = self.createSummaryWriter() 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Worker figure, axes = plt.figure(), plt.gca() 692*da0073e9SAndroid Build Coastguard Worker circle1 = plt.Circle((0.2, 0.5), 0.2, color='r') 693*da0073e9SAndroid Build Coastguard Worker circle2 = plt.Circle((0.8, 0.5), 0.2, color='g') 694*da0073e9SAndroid Build Coastguard Worker axes.add_patch(circle1) 695*da0073e9SAndroid Build Coastguard Worker axes.add_patch(circle2) 696*da0073e9SAndroid Build Coastguard Worker plt.axis('scaled') 697*da0073e9SAndroid Build Coastguard Worker plt.tight_layout() 698*da0073e9SAndroid Build Coastguard Worker 699*da0073e9SAndroid Build Coastguard Worker writer.add_figure("add_figure/figure", figure, 0, close=False) 700*da0073e9SAndroid Build Coastguard Worker self.assertTrue(plt.fignum_exists(figure.number)) 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker writer.add_figure("add_figure/figure", figure, 1) 703*da0073e9SAndroid Build Coastguard Worker if matplotlib.__version__ != '3.3.0': 704*da0073e9SAndroid Build Coastguard Worker self.assertFalse(plt.fignum_exists(figure.number)) 705*da0073e9SAndroid Build Coastguard Worker else: 706*da0073e9SAndroid Build Coastguard Worker print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163") 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker writer.close() 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker @skipIfNoMatplotlib 711*da0073e9SAndroid Build Coastguard Worker def test_figure_list(self): 712*da0073e9SAndroid Build Coastguard Worker writer = self.createSummaryWriter() 713*da0073e9SAndroid Build Coastguard Worker 714*da0073e9SAndroid Build Coastguard Worker figures = [] 715*da0073e9SAndroid Build Coastguard Worker for i in range(5): 716*da0073e9SAndroid Build Coastguard Worker figure = plt.figure() 717*da0073e9SAndroid Build Coastguard Worker plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i)) 718*da0073e9SAndroid Build Coastguard Worker plt.xlabel("X") 719*da0073e9SAndroid Build Coastguard Worker plt.xlabel("Y") 720*da0073e9SAndroid Build Coastguard Worker plt.legend() 721*da0073e9SAndroid Build Coastguard Worker plt.tight_layout() 722*da0073e9SAndroid Build Coastguard Worker figures.append(figure) 723*da0073e9SAndroid Build Coastguard Worker 724*da0073e9SAndroid Build Coastguard Worker writer.add_figure("add_figure/figure_list", figures, 0, close=False) 725*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(plt.fignum_exists(figure.number) is True for figure in figures)) # noqa: F812 726*da0073e9SAndroid Build Coastguard Worker 727*da0073e9SAndroid Build Coastguard Worker writer.add_figure("add_figure/figure_list", figures, 1) 728*da0073e9SAndroid Build Coastguard Worker if matplotlib.__version__ != '3.3.0': 729*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(plt.fignum_exists(figure.number) is False for figure in figures)) # noqa: F812 730*da0073e9SAndroid Build Coastguard Worker else: 731*da0073e9SAndroid Build Coastguard Worker print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163") 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Worker writer.close() 734*da0073e9SAndroid Build Coastguard Worker 735*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardNumpy(BaseTestCase): 736*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "Skipping on windows, see https://github.com/pytorch/pytorch/pull/109349 ") 737*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ") 738*da0073e9SAndroid Build Coastguard Worker def test_scalar(self): 739*da0073e9SAndroid Build Coastguard Worker res = make_np(1.1) 740*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) 741*da0073e9SAndroid Build Coastguard Worker res = make_np(1 << 64 - 1) # uint64_max 742*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) 743*da0073e9SAndroid Build Coastguard Worker res = make_np(np.float16(1.00000087)) 744*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) 745*da0073e9SAndroid Build Coastguard Worker res = make_np(np.float128(1.00008 + 9)) 746*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) 747*da0073e9SAndroid Build Coastguard Worker res = make_np(np.int64(100000000000)) 748*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker def test_pytorch_np_expect_fail(self): 751*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(NotImplementedError): 752*da0073e9SAndroid Build Coastguard Worker res = make_np({'pytorch': 1.0}) 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker 755*da0073e9SAndroid Build Coastguard Worker 756*da0073e9SAndroid Build Coastguard Workerclass TestTensorProtoSummary(BaseTestCase): 757*da0073e9SAndroid Build Coastguard Worker @parametrize( 758*da0073e9SAndroid Build Coastguard Worker "tensor_type,proto_type", 759*da0073e9SAndroid Build Coastguard Worker [ 760*da0073e9SAndroid Build Coastguard Worker (torch.float16, DataType.DT_HALF), 761*da0073e9SAndroid Build Coastguard Worker (torch.bfloat16, DataType.DT_BFLOAT16), 762*da0073e9SAndroid Build Coastguard Worker ], 763*da0073e9SAndroid Build Coastguard Worker ) 764*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Unsuitable test for Dynamo, behavior changes with version") 765*da0073e9SAndroid Build Coastguard Worker def test_half_tensor_proto(self, tensor_type, proto_type): 766*da0073e9SAndroid Build Coastguard Worker float_values = [1.0, 2.0, 3.0] 767*da0073e9SAndroid Build Coastguard Worker actual_proto = tensor_proto( 768*da0073e9SAndroid Build Coastguard Worker "dummy", 769*da0073e9SAndroid Build Coastguard Worker torch.tensor(float_values, dtype=tensor_type), 770*da0073e9SAndroid Build Coastguard Worker ).value[0].tensor 771*da0073e9SAndroid Build Coastguard Worker self.assertSequenceEqual( 772*da0073e9SAndroid Build Coastguard Worker [int_to_half(x) for x in actual_proto.half_val], 773*da0073e9SAndroid Build Coastguard Worker float_values, 774*da0073e9SAndroid Build Coastguard Worker ) 775*da0073e9SAndroid Build Coastguard Worker self.assertTrue(actual_proto.dtype == proto_type) 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker def test_float_tensor_proto(self): 778*da0073e9SAndroid Build Coastguard Worker float_values = [1.0, 2.0, 3.0] 779*da0073e9SAndroid Build Coastguard Worker actual_proto = ( 780*da0073e9SAndroid Build Coastguard Worker tensor_proto("dummy", torch.tensor(float_values)).value[0].tensor 781*da0073e9SAndroid Build Coastguard Worker ) 782*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_proto.float_val, float_values) 783*da0073e9SAndroid Build Coastguard Worker self.assertTrue(actual_proto.dtype == DataType.DT_FLOAT) 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker def test_int_tensor_proto(self): 786*da0073e9SAndroid Build Coastguard Worker int_values = [1, 2, 3] 787*da0073e9SAndroid Build Coastguard Worker actual_proto = ( 788*da0073e9SAndroid Build Coastguard Worker tensor_proto("dummy", torch.tensor(int_values, dtype=torch.int32)) 789*da0073e9SAndroid Build Coastguard Worker .value[0] 790*da0073e9SAndroid Build Coastguard Worker .tensor 791*da0073e9SAndroid Build Coastguard Worker ) 792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_proto.int_val, int_values) 793*da0073e9SAndroid Build Coastguard Worker self.assertTrue(actual_proto.dtype == DataType.DT_INT32) 794*da0073e9SAndroid Build Coastguard Worker 795*da0073e9SAndroid Build Coastguard Worker def test_scalar_tensor_proto(self): 796*da0073e9SAndroid Build Coastguard Worker scalar_value = 0.1 797*da0073e9SAndroid Build Coastguard Worker actual_proto = ( 798*da0073e9SAndroid Build Coastguard Worker tensor_proto("dummy", torch.tensor(scalar_value)).value[0].tensor 799*da0073e9SAndroid Build Coastguard Worker ) 800*da0073e9SAndroid Build Coastguard Worker self.assertAlmostEqual(actual_proto.float_val[0], scalar_value) 801*da0073e9SAndroid Build Coastguard Worker 802*da0073e9SAndroid Build Coastguard Worker def test_complex_tensor_proto(self): 803*da0073e9SAndroid Build Coastguard Worker real = torch.tensor([1.0, 2.0]) 804*da0073e9SAndroid Build Coastguard Worker imag = torch.tensor([3.0, 4.0]) 805*da0073e9SAndroid Build Coastguard Worker actual_proto = ( 806*da0073e9SAndroid Build Coastguard Worker tensor_proto("dummy", torch.complex(real, imag)).value[0].tensor 807*da0073e9SAndroid Build Coastguard Worker ) 808*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_proto.scomplex_val, [1.0, 3.0, 2.0, 4.0]) 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker def test_empty_tensor_proto(self): 811*da0073e9SAndroid Build Coastguard Worker actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor 812*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_proto.float_val, []) 813*da0073e9SAndroid Build Coastguard Worker 814*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestTensorProtoSummary) 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 817*da0073e9SAndroid Build Coastguard Worker run_tests() 818