xref: /aosp_15_r20/external/pytorch/test/test_tensorboard.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport shutil
6*da0073e9SAndroid Build Coastguard Workerimport sys
7*da0073e9SAndroid Build Coastguard Workerimport tempfile
8*da0073e9SAndroid Build Coastguard Workerimport unittest
9*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerimport expecttest
12*da0073e9SAndroid Build Coastguard Workerimport numpy as np
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
# Feature probe: TensorBoard tests only run when the tensorboard package (and
# its bundled protos) can be imported. ``Summary`` is imported here for use
# elsewhere in the module.
try:
    import tensorboard.summary.writer.event_file_writer  # noqa: F401
    from tensorboard.compat.proto.summary_pb2 import Summary
except ImportError:
    TEST_TENSORBOARD = False
else:
    TEST_TENSORBOARD = True
21*da0073e9SAndroid Build Coastguard Worker
# Feature probe for the optional torchvision dependency.
try:
    import torchvision
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True
# Decorator that skips a test when torchvision could not be imported.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
28*da0073e9SAndroid Build Coastguard Worker
# Feature probe for the optional matplotlib dependency.
try:
    import matplotlib
    # With no DISPLAY set, select the non-interactive Agg backend before
    # pyplot is imported.
    if not os.environ.get('DISPLAY', ''):
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
except ImportError:
    TEST_MATPLOTLIB = False
else:
    TEST_MATPLOTLIB = True
# Decorator that skips a test when matplotlib could not be imported.
skipIfNoMatplotlib = unittest.skipIf(not TEST_MATPLOTLIB, "no matplotlib")
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Workerimport torch
40*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
41*da0073e9SAndroid Build Coastguard Worker    instantiate_parametrized_tests,
42*da0073e9SAndroid Build Coastguard Worker    IS_MACOS,
43*da0073e9SAndroid Build Coastguard Worker    IS_WINDOWS,
44*da0073e9SAndroid Build Coastguard Worker    parametrize,
45*da0073e9SAndroid Build Coastguard Worker    run_tests,
46*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_CROSSREF,
47*da0073e9SAndroid Build Coastguard Worker    TestCase,
48*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
49*da0073e9SAndroid Build Coastguard Worker)
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker
def tensor_N(shape, dtype=float):
    """Return an array of ``shape`` holding 0, 1, 2, ... in row-major order.

    Args:
        shape: target shape of the result.
        dtype: element dtype passed to ``np.arange`` (defaults to ``float``).
    """
    return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
56*da0073e9SAndroid Build Coastguard Worker
class BaseTestCase(TestCase):
    """Base class used for all TensorBoard tests.

    Skips every test when TensorBoard is not installed or when running under
    crossref. Log directories created through ``createSummaryWriter`` are
    removed again in ``tearDown``.
    """

    def setUp(self):
        super().setUp()
        # skipTest raises unittest.SkipTest, so nothing below runs when a skip
        # condition holds; unittest also skips tearDown if setUp raises.
        if not TEST_TENSORBOARD:
            self.skipTest("Skip the test since TensorBoard is not installed")
        if TEST_WITH_CROSSREF:
            self.skipTest("Don't run TensorBoard tests with crossref")
        self.temp_dirs = []

    def createSummaryWriter(self):
        """Return a SummaryWriter that logs into a fresh temporary directory.

        The directory is recorded in ``self.temp_dirs``; ``tearDown`` is
        responsible for deleting it.
        """
        # mkdtemp creates the directory and leaves its lifetime entirely to
        # us. ``TemporaryDirectory(...).name`` would drop the owning object
        # right away, and its finalizer would delete the directory (with a
        # ResourceWarning) at an interpreter-dependent time, possibly while
        # the writer is already using it.
        temp_dir = tempfile.mkdtemp(prefix="test_tensorboard")
        self.temp_dirs.append(temp_dir)
        return SummaryWriter(temp_dir)

    def tearDown(self):
        super().tearDown()
        # Remove directories handed out by createSummaryWriter.
        for temp_dir in self.temp_dirs:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Workerif TEST_TENSORBOARD:
83*da0073e9SAndroid Build Coastguard Worker    from google.protobuf import text_format
84*da0073e9SAndroid Build Coastguard Worker    from PIL import Image
85*da0073e9SAndroid Build Coastguard Worker    from tensorboard.compat.proto.graph_pb2 import GraphDef
86*da0073e9SAndroid Build Coastguard Worker    from tensorboard.compat.proto.types_pb2 import DataType
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker    from torch.utils.tensorboard import summary, SummaryWriter
89*da0073e9SAndroid Build Coastguard Worker    from torch.utils.tensorboard._convert_np import make_np
90*da0073e9SAndroid Build Coastguard Worker    from torch.utils.tensorboard._pytorch_graph import graph
91*da0073e9SAndroid Build Coastguard Worker    from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
92*da0073e9SAndroid Build Coastguard Worker    from torch.utils.tensorboard.summary import int_to_half, tensor_proto
93*da0073e9SAndroid Build Coastguard Worker
class TestTensorBoardPyTorchNumpy(BaseTestCase):
    """Tensor -> NumPy conversion (make_np) and tensor-valued writer calls."""

    def test_pytorch_np(self):
        # make_np must return an np.ndarray for CPU tensors, CUDA tensors (when
        # available), tensors wrapped in the legacy torch.autograd.Variable
        # API, and plain Python scalars.
        tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)]
        for tensor in tensors:
            # regular tensor
            self.assertIsInstance(make_np(tensor), np.ndarray)

            # CUDA tensor
            if torch.cuda.is_available():
                self.assertIsInstance(make_np(tensor.cuda()), np.ndarray)

            # regular variable
            self.assertIsInstance(make_np(torch.autograd.Variable(tensor)), np.ndarray)

            # CUDA variable
            if torch.cuda.is_available():
                self.assertIsInstance(make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray)

        # python primitive type
        self.assertIsInstance(make_np(0), np.ndarray)
        self.assertIsInstance(make_np(0.1), np.ndarray)

    def test_pytorch_autograd_np(self):
        x = torch.autograd.Variable(torch.empty(1))
        self.assertIsInstance(make_np(x), np.ndarray)

    def test_pytorch_write(self):
        with self.createSummaryWriter() as w:
            w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0)

    def test_pytorch_histogram(self):
        # add_histogram should accept float, integer and bfloat16 tensors.
        with self.createSummaryWriter() as w:
            w.add_histogram('float histogram', torch.rand((50,)))
            w.add_histogram('int histogram', torch.randint(0, 100, (50,)))
            w.add_histogram('bfloat16 histogram', torch.rand(50, dtype=torch.bfloat16))

    def test_pytorch_histogram_raw(self):
        # Every add_histogram_raw call below passes statistics computed from
        # the very sample it describes.
        with self.createSummaryWriter() as w:
            num = 50
            floats = make_np(torch.rand((num,)))
            bins = [0.0, 0.25, 0.5, 0.75, 1.0]
            counts, limits = np.histogram(floats, bins)
            sum_sq = floats.dot(floats).item()
            # np.histogram returns len(counts) + 1 edges; add_histogram_raw
            # takes one upper limit per bucket, hence limits[1:].
            w.add_histogram_raw('float histogram raw',
                                min=floats.min().item(),
                                max=floats.max().item(),
                                num=num,
                                sum=floats.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits[1:].tolist(),
                                bucket_counts=counts.tolist())

            ints = make_np(torch.randint(0, 100, (num,)))
            bins = [0, 25, 50, 75, 100]
            counts, limits = np.histogram(ints, bins)
            sum_sq = ints.dot(ints).item()
            w.add_histogram_raw('int histogram raw',
                                min=ints.min().item(),
                                max=ints.max().item(),
                                num=num,
                                sum=ints.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits[1:].tolist(),
                                bucket_counts=counts.tolist())

            # 100 values 0..99 binned by torch.histc into 100 buckets.
            ints = torch.tensor(range(0, 100)).float()
            nbins = 100
            counts = torch.histc(ints, bins=nbins, min=0, max=99)
            limits = torch.tensor(range(nbins))
            sum_sq = ints.dot(ints).item()
            w.add_histogram_raw('int histogram raw',
                                min=ints.min().item(),
                                max=ints.max().item(),
                                # This sample holds 100 values; reusing the
                                # 50 of the earlier samples would record a
                                # wrong element count.
                                num=ints.numel(),
                                sum=ints.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits.tolist(),
                                bucket_counts=counts.tolist())
172*da0073e9SAndroid Build Coastguard Worker
class TestTensorBoardUtils(BaseTestCase):
    """Unit tests for the image/video layout helpers in tensorboard._utils."""

    def test_to_HWC(self):
        test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'chw')
        self.assertEqual(converted.shape, (32, 32, 3))
        test_image = np.random.randint(0, 256, size=(16, 3, 32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'nchw')
        self.assertEqual(converted.shape, (64, 256, 3))
        test_image = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'hw')
        self.assertEqual(converted.shape, (32, 32, 3))

    def test_convert_to_HWC_dtype_remains_same(self):
        # test to ensure convert_to_HWC restores the dtype of input np array and
        # thus the scale_factor calculated for the image is 1
        test_image = torch.tensor([[[[1, 2, 3], [4, 5, 6]]]], dtype=torch.uint8)
        tensor = make_np(test_image)
        tensor = convert_to_HWC(tensor, 'NCHW')
        scale_factor = summary._calc_scale_factor(tensor)
        self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1')

    def test_prepare_video(self):
        # At each timeframe, the sum over all other
        # dimensions of the video should be the same.
        shapes = [
            (16, 30, 3, 28, 28),
            (36, 30, 3, 28, 28),
            (19, 29, 3, 23, 19),
            (3, 3, 3, 3, 3)
        ]
        for s in shapes:
            V_input = np.random.random(s)
            V_after = _prepare_video(np.copy(V_input))
            total_frame = s[1]
            V_input = np.swapaxes(V_input, 0, 1)
            for f in range(total_frame):
                # Flatten positionally: the ``newshape=`` keyword of
                # np.reshape is deprecated since NumPy 2.1.
                x = np.reshape(V_input[f], -1)
                y = np.reshape(V_after[f], -1)
                np.testing.assert_array_almost_equal(np.sum(x), np.sum(y))

    def test_numpy_vid_uint8(self):
        # uint8 input is rescaled to [0, 1] by _prepare_video; multiplying by
        # 255 should recover per-frame sums of the original pixel values.
        V_input = np.random.randint(0, 256, (16, 30, 3, 28, 28)).astype(np.uint8)
        V_after = _prepare_video(np.copy(V_input)) * 255
        total_frame = V_input.shape[1]
        V_input = np.swapaxes(V_input, 0, 1)
        for f in range(total_frame):
            x = np.reshape(V_input[f], -1)
            y = np.reshape(V_after[f], -1)
            np.testing.assert_array_almost_equal(np.sum(x), np.sum(y))
223*da0073e9SAndroid Build Coastguard Worker
# NOTE(review): not referenced in the visible part of this file; the values
# look like musical pitches in Hz (C4..A4) — confirm before relying on that.
freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440]

# Raw precision/recall-curve fixture for SummaryWriter.add_pr_curve_raw, one
# entry per threshold. The lists are mutually consistent:
#   precision = TP / (TP + FP)   (0.0 where TP + FP == 0)
#   recall    = TP / (TP + FN)
#   TN + FP   = 150 at every threshold
true_positive_counts = [
    75, 64, 21, 5, 0,
]
false_positive_counts = [
    150, 105, 18, 0, 0,
]
true_negative_counts = [
    0, 45, 132, 150, 150,
]
false_negative_counts = [
    0, 11, 54, 70, 75,
]
precision = [
    0.3333333, 0.3786982, 0.5384616, 1.0, 0.0,
]
recall = [
    1.0, 0.8533334, 0.28, 0.0666667, 0.0,
]
232*da0073e9SAndroid Build Coastguard Worker
class TestTensorBoardWriter(BaseTestCase):
    def test_writer(self):
        """Smoke-test most SummaryWriter ``add_*`` methods on one writer."""
        with self.createSummaryWriter() as writer:
            audio_sample_rate = 44100
            step = 0

            writer.add_hparams(
                {'lr': 0.1, 'bsize': 1},
                {'hparam/accuracy': 10, 'hparam/loss': 10},
            )

            # Scalars: system walltime, explicit walltime, new-style summary,
            # and a group of related scalars.
            writer.add_scalar('data/scalar_systemtime', 0.1, step)
            writer.add_scalar('data/scalar_customtime', 0.2, step, walltime=step)
            writer.add_scalar('data/new_style', 0.2, step, new_style=True)
            scalar_group = {
                "xsinx": step * np.sin(step),
                "xcosx": step * np.cos(step),
                "arctanx": np.arctan(step),
            }
            writer.add_scalars('data/scalar_group', scalar_group, step)

            # Images: an all-zero batch, then one image with two boxes.
            image_batch = np.zeros((32, 3, 64, 64))  # output from network
            writer.add_images('Image', image_batch, step)  # Tensor
            box_image = np.zeros((3, 64, 64))
            boxes = np.array([[10, 10, 40, 40], [40, 40, 60, 60]])
            writer.add_image_with_boxes('imagebox', box_image, boxes, step)

            # 2 * audio_sample_rate zero-valued audio samples.
            audio = np.zeros(audio_sample_rate * 2)
            writer.add_audio('myAudio', audio, step)

            video = np.random.rand(16, 48, 1, 28, 28).astype(np.float32)
            writer.add_video('myVideo', video, step)

            writer.add_text('Text', 'text logged at step:' + str(step), step)
            writer.add_text('markdown Text', '''a|b\n-|-\nc|d''', step)

            writer.add_histogram('hist', np.random.rand(100, 100), step)

            # add_pr_curve needs tensorboard 0.4RC or later.
            pr_labels = np.random.randint(2, size=100)
            pr_predictions = np.random.rand(100)
            writer.add_pr_curve('xoxo', pr_labels, pr_predictions, step)
            writer.add_pr_curve_raw(
                'prcurve with raw data',
                true_positive_counts,
                false_positive_counts,
                true_negative_counts,
                false_negative_counts,
                precision,
                recall,
                step,
            )

            # One tetrahedron: 4 vertices, a color per vertex, 4 triangles.
            vertices = np.array(
                [[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
            colors = np.array(
                [[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
            faces = np.array(
                [[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int)
            writer.add_mesh('my_mesh', vertices=vertices, colors=colors, faces=faces)
277*da0073e9SAndroid Build Coastguard Worker
class TestTensorBoardSummaryWriter(BaseTestCase):
    """Lifecycle tests for SummaryWriter itself."""

    def test_summary_writer_ctx(self):
        # after using a SummaryWriter as a ctx it should be closed
        with self.createSummaryWriter() as writer:
            writer.add_scalar('test', 1)
        self.assertIs(writer.file_writer, None)

    def test_summary_writer_close(self):
        # Creating and closing a SummaryWriter must not raise OSError. This
        # guards against writers leaking file handles, which would eventually
        # surface as "OSError: [Errno 24] Too many open files".
        try:
            writer = self.createSummaryWriter()
            writer.close()
        except OSError as err:
            # Report the actual error instead of a bare "False is not true".
            self.fail(f"creating/closing a SummaryWriter raised {err!r}")

    def test_pathlib(self):
        # SummaryWriter should accept a pathlib.Path as its log directory.
        with tempfile.TemporaryDirectory(prefix="test_tensorboard_pathlib") as d:
            p = Path(d)
            with SummaryWriter(p) as writer:
                writer.add_scalar('test', 1)
302*da0073e9SAndroid Build Coastguard Worker
class TestTensorBoardEmbedding(BaseTestCase):
    """add_embedding with float32 and float64 label images."""

    def _add_embeddings(self, label_img):
        """Log the same 3-point embedding twice with ``label_img``.

        The first call uses single-column metadata. The second uses
        two-column metadata with a header.
        """
        all_features = torch.tensor([[1., 2., 3.], [5., 4., 1.], [3., 7., 7.]])
        all_labels = torch.tensor([33., 44., 55.])
        # The ``with`` closes the writer so its file handles are released
        # before tearDown removes the log directory.
        with self.createSummaryWriter() as w:
            w.add_embedding(all_features,
                            metadata=all_labels,
                            label_img=label_img,
                            global_step=2)

            # One dataset tag per embedded point (3 points).
            dataset_label = ['test'] * 2 + ['train']
            all_labels = list(zip(all_labels, dataset_label))
            w.add_embedding(all_features,
                            metadata=all_labels,
                            label_img=label_img,
                            metadata_header=['digit', 'dataset'],
                            global_step=2)

    def test_embedding(self):
        self._add_embeddings(torch.zeros(3, 3, 5, 5))

    def test_embedding_64(self):
        self._add_embeddings(torch.zeros((3, 3, 5, 5), dtype=torch.float64))
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardSummary(BaseTestCase):
344*da0073e9SAndroid Build Coastguard Worker    def test_uint8_image(self):
345*da0073e9SAndroid Build Coastguard Worker        '''
346*da0073e9SAndroid Build Coastguard Worker        Tests that uint8 image (pixel values in [0, 255]) is not changed
347*da0073e9SAndroid Build Coastguard Worker        '''
348*da0073e9SAndroid Build Coastguard Worker        test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8)
349*da0073e9SAndroid Build Coastguard Worker        scale_factor = summary._calc_scale_factor(test_image)
350*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1')
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker    def test_float32_image(self):
353*da0073e9SAndroid Build Coastguard Worker        '''
354*da0073e9SAndroid Build Coastguard Worker        Tests that float32 image (pixel values in [0, 1]) are scaled correctly
355*da0073e9SAndroid Build Coastguard Worker        to [0, 255]
356*da0073e9SAndroid Build Coastguard Worker        '''
357*da0073e9SAndroid Build Coastguard Worker        test_image = np.random.rand(3, 32, 32).astype(np.float32)
358*da0073e9SAndroid Build Coastguard Worker        scale_factor = summary._calc_scale_factor(test_image)
359*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale_factor, 255, msg='Values are in [0, 1], scale factor should be 255')
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker    def test_list_input(self):
362*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception) as e_info:
363*da0073e9SAndroid Build Coastguard Worker            summary.histogram('dummy', [1, 3, 4, 5, 6], 'tensorflow')
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker    def test_empty_input(self):
366*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception) as e_info:
367*da0073e9SAndroid Build Coastguard Worker            summary.histogram('dummy', np.ndarray(0), 'tensorflow')
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker    def test_image_with_boxes(self):
370*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_image_proto(summary.image_boxes('dummy',
371*da0073e9SAndroid Build Coastguard Worker                                            tensor_N(shape=(3, 32, 32)),
372*da0073e9SAndroid Build Coastguard Worker                                            np.array([[10, 10, 40, 40]])),
373*da0073e9SAndroid Build Coastguard Worker                                            self))
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker    def test_image_with_one_channel(self):
376*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_image_proto(
377*da0073e9SAndroid Build Coastguard Worker            summary.image('dummy',
378*da0073e9SAndroid Build Coastguard Worker                          tensor_N(shape=(1, 8, 8)),
379*da0073e9SAndroid Build Coastguard Worker                          dataformats='CHW'),
380*da0073e9SAndroid Build Coastguard Worker                          self))  # noqa: E131
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker    def test_image_with_one_channel_batched(self):
383*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_image_proto(
384*da0073e9SAndroid Build Coastguard Worker            summary.image('dummy',
385*da0073e9SAndroid Build Coastguard Worker                          tensor_N(shape=(2, 1, 8, 8)),
386*da0073e9SAndroid Build Coastguard Worker                          dataformats='NCHW'),
387*da0073e9SAndroid Build Coastguard Worker                          self))  # noqa: E131
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker    def test_image_with_3_channel_batched(self):
390*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_image_proto(
391*da0073e9SAndroid Build Coastguard Worker            summary.image('dummy',
392*da0073e9SAndroid Build Coastguard Worker                          tensor_N(shape=(2, 3, 8, 8)),
393*da0073e9SAndroid Build Coastguard Worker                          dataformats='NCHW'),
394*da0073e9SAndroid Build Coastguard Worker                          self))  # noqa: E131
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker    def test_image_without_channel(self):
397*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_image_proto(
398*da0073e9SAndroid Build Coastguard Worker            summary.image('dummy',
399*da0073e9SAndroid Build Coastguard Worker                          tensor_N(shape=(8, 8)),
400*da0073e9SAndroid Build Coastguard Worker                          dataformats='HW'),
401*da0073e9SAndroid Build Coastguard Worker                          self))  # noqa: E131
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker    def test_video(self):
404*da0073e9SAndroid Build Coastguard Worker        try:
405*da0073e9SAndroid Build Coastguard Worker            import moviepy  # noqa: F401
406*da0073e9SAndroid Build Coastguard Worker        except ImportError:
407*da0073e9SAndroid Build Coastguard Worker            return
408*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_proto(summary.video('dummy', tensor_N(shape=(4, 3, 1, 8, 8))), self))
409*da0073e9SAndroid Build Coastguard Worker        summary.video('dummy', np.random.rand(16, 48, 1, 28, 28))
410*da0073e9SAndroid Build Coastguard Worker        summary.video('dummy', np.random.rand(20, 7, 1, 8, 8))
411*da0073e9SAndroid Build Coastguard Worker
412*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
413*da0073e9SAndroid Build Coastguard Worker    def test_audio(self):
414*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_proto(summary.audio('dummy', tensor_N(shape=(42,))), self))
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
417*da0073e9SAndroid Build Coastguard Worker    def test_text(self):
418*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_proto(summary.text('dummy', 'text 123'), self))
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
421*da0073e9SAndroid Build Coastguard Worker    def test_histogram_auto(self):
422*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='auto', max_bins=5), self))
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
425*da0073e9SAndroid Build Coastguard Worker    def test_histogram_fd(self):
426*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='fd', max_bins=5), self))
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
429*da0073e9SAndroid Build Coastguard Worker    def test_histogram_doane(self):
430*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='doane', max_bins=5), self))
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker    def test_custom_scalars(self):
433*da0073e9SAndroid Build Coastguard Worker        layout = {
434*da0073e9SAndroid Build Coastguard Worker            'Taiwan': {
435*da0073e9SAndroid Build Coastguard Worker                'twse': ['Multiline', ['twse/0050', 'twse/2330']]
436*da0073e9SAndroid Build Coastguard Worker            },
437*da0073e9SAndroid Build Coastguard Worker            'USA': {
438*da0073e9SAndroid Build Coastguard Worker                'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
439*da0073e9SAndroid Build Coastguard Worker                'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]
440*da0073e9SAndroid Build Coastguard Worker            }
441*da0073e9SAndroid Build Coastguard Worker        }
442*da0073e9SAndroid Build Coastguard Worker        summary.custom_scalars(layout)  # only smoke test. Because protobuf in python2/3 serialize dictionary differently.
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
446*da0073e9SAndroid Build Coastguard Worker    def test_mesh(self):
447*da0073e9SAndroid Build Coastguard Worker        v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
448*da0073e9SAndroid Build Coastguard Worker        c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
449*da0073e9SAndroid Build Coastguard Worker        f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int)
450*da0073e9SAndroid Build Coastguard Worker        mesh = summary.mesh('my_mesh', vertices=v, colors=c, faces=f, config_dict=None)
451*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_proto(mesh, self))
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
454*da0073e9SAndroid Build Coastguard Worker    def test_scalar_new_style(self):
455*da0073e9SAndroid Build Coastguard Worker        scalar = summary.scalar('test_scalar', 1.0, new_style=True)
456*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(compare_proto(scalar, self))
457*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
458*da0073e9SAndroid Build Coastguard Worker            summary.scalar('test_scalar2', torch.Tensor([1, 2, 3]), new_style=True)
459*da0073e9SAndroid Build Coastguard Worker
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Workerdef remove_whitespace(string):
462*da0073e9SAndroid Build Coastguard Worker    return string.replace(' ', '').replace('\t', '').replace('\n', '')
463*da0073e9SAndroid Build Coastguard Worker
464*da0073e9SAndroid Build Coastguard Workerdef get_expected_file(function_ptr):
465*da0073e9SAndroid Build Coastguard Worker    module_id = function_ptr.__class__.__module__
466*da0073e9SAndroid Build Coastguard Worker    test_file = sys.modules[module_id].__file__
467*da0073e9SAndroid Build Coastguard Worker    # Look for the .py file (since __file__ could be pyc).
468*da0073e9SAndroid Build Coastguard Worker    test_file = ".".join(test_file.split('.')[:-1]) + '.py'
469*da0073e9SAndroid Build Coastguard Worker
470*da0073e9SAndroid Build Coastguard Worker    # Use realpath to follow symlinks appropriately.
471*da0073e9SAndroid Build Coastguard Worker    test_dir = os.path.dirname(os.path.realpath(test_file))
472*da0073e9SAndroid Build Coastguard Worker    functionName = function_ptr.id().split('.')[-1]
473*da0073e9SAndroid Build Coastguard Worker    return os.path.join(test_dir,
474*da0073e9SAndroid Build Coastguard Worker                        "expect",
475*da0073e9SAndroid Build Coastguard Worker                        'TestTensorBoard.' + functionName + ".expect")
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Workerdef read_expected_content(function_ptr):
478*da0073e9SAndroid Build Coastguard Worker    expected_file = get_expected_file(function_ptr)
479*da0073e9SAndroid Build Coastguard Worker    assert os.path.exists(expected_file), expected_file
480*da0073e9SAndroid Build Coastguard Worker    with open(expected_file) as f:
481*da0073e9SAndroid Build Coastguard Worker        return f.read()
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Workerdef compare_image_proto(actual_proto, function_ptr):
484*da0073e9SAndroid Build Coastguard Worker    if expecttest.ACCEPT:
485*da0073e9SAndroid Build Coastguard Worker        expected_file = get_expected_file(function_ptr)
486*da0073e9SAndroid Build Coastguard Worker        with open(expected_file, 'w') as f:
487*da0073e9SAndroid Build Coastguard Worker            f.write(text_format.MessageToString(actual_proto))
488*da0073e9SAndroid Build Coastguard Worker        return True
489*da0073e9SAndroid Build Coastguard Worker    expected_str = read_expected_content(function_ptr)
490*da0073e9SAndroid Build Coastguard Worker    expected_proto = Summary()
491*da0073e9SAndroid Build Coastguard Worker    text_format.Parse(expected_str, expected_proto)
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker    [actual, expected] = [actual_proto.value[0], expected_proto.value[0]]
494*da0073e9SAndroid Build Coastguard Worker    actual_img = Image.open(io.BytesIO(actual.image.encoded_image_string))
495*da0073e9SAndroid Build Coastguard Worker    expected_img = Image.open(io.BytesIO(expected.image.encoded_image_string))
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker    return (
498*da0073e9SAndroid Build Coastguard Worker        actual.tag == expected.tag and
499*da0073e9SAndroid Build Coastguard Worker        actual.image.height == expected.image.height and
500*da0073e9SAndroid Build Coastguard Worker        actual.image.width == expected.image.width and
501*da0073e9SAndroid Build Coastguard Worker        actual.image.colorspace == expected.image.colorspace and
502*da0073e9SAndroid Build Coastguard Worker        actual_img == expected_img
503*da0073e9SAndroid Build Coastguard Worker    )
504*da0073e9SAndroid Build Coastguard Worker
505*da0073e9SAndroid Build Coastguard Workerdef compare_proto(str_to_compare, function_ptr):
506*da0073e9SAndroid Build Coastguard Worker    if expecttest.ACCEPT:
507*da0073e9SAndroid Build Coastguard Worker        write_proto(str_to_compare, function_ptr)
508*da0073e9SAndroid Build Coastguard Worker        return True
509*da0073e9SAndroid Build Coastguard Worker    expected = read_expected_content(function_ptr)
510*da0073e9SAndroid Build Coastguard Worker    str_to_compare = str(str_to_compare)
511*da0073e9SAndroid Build Coastguard Worker    return remove_whitespace(str_to_compare) == remove_whitespace(expected)
512*da0073e9SAndroid Build Coastguard Worker
513*da0073e9SAndroid Build Coastguard Workerdef write_proto(str_to_compare, function_ptr):
514*da0073e9SAndroid Build Coastguard Worker    expected_file = get_expected_file(function_ptr)
515*da0073e9SAndroid Build Coastguard Worker    with open(expected_file, 'w') as f:
516*da0073e9SAndroid Build Coastguard Worker        f.write(str(str_to_compare))
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardPytorchGraph(BaseTestCase):
519*da0073e9SAndroid Build Coastguard Worker    def test_pytorch_graph(self):
520*da0073e9SAndroid Build Coastguard Worker        dummy_input = (torch.zeros(1, 3),)
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker        class myLinear(torch.nn.Module):
523*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
524*da0073e9SAndroid Build Coastguard Worker                super().__init__()
525*da0073e9SAndroid Build Coastguard Worker                self.l = torch.nn.Linear(3, 5)
526*da0073e9SAndroid Build Coastguard Worker
527*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
528*da0073e9SAndroid Build Coastguard Worker                return self.l(x)
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker        with self.createSummaryWriter() as w:
531*da0073e9SAndroid Build Coastguard Worker            w.add_graph(myLinear(), dummy_input)
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker        actual_proto, _ = graph(myLinear(), dummy_input)
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker        expected_str = read_expected_content(self)
536*da0073e9SAndroid Build Coastguard Worker        expected_proto = GraphDef()
537*da0073e9SAndroid Build Coastguard Worker        text_format.Parse(expected_str, expected_proto)
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
540*da0073e9SAndroid Build Coastguard Worker        for i in range(len(expected_proto.node)):
541*da0073e9SAndroid Build Coastguard Worker            expected_node = expected_proto.node[i]
542*da0073e9SAndroid Build Coastguard Worker            actual_node = actual_proto.node[i]
543*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_node.name, actual_node.name)
544*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_node.op, actual_node.op)
545*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_node.input, actual_node.input)
546*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_node.device, actual_node.device)
547*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
548*da0073e9SAndroid Build Coastguard Worker                sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker    def test_nested_nn_squential(self):
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker        dummy_input = torch.randn(2, 3)
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker        class InnerNNSquential(torch.nn.Module):
555*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dim1, dim2):
556*da0073e9SAndroid Build Coastguard Worker                super().__init__()
557*da0073e9SAndroid Build Coastguard Worker                self.inner_nn_squential = torch.nn.Sequential(
558*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(dim1, dim2),
559*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(dim2, dim1),
560*da0073e9SAndroid Build Coastguard Worker                )
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
563*da0073e9SAndroid Build Coastguard Worker                x = self.inner_nn_squential(x)
564*da0073e9SAndroid Build Coastguard Worker                return x
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker        class OuterNNSquential(torch.nn.Module):
567*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dim1=3, dim2=4, depth=2):
568*da0073e9SAndroid Build Coastguard Worker                super().__init__()
569*da0073e9SAndroid Build Coastguard Worker                layers = []
570*da0073e9SAndroid Build Coastguard Worker                for _ in range(depth):
571*da0073e9SAndroid Build Coastguard Worker                    layers.append(InnerNNSquential(dim1, dim2))
572*da0073e9SAndroid Build Coastguard Worker                self.outer_nn_squential = torch.nn.Sequential(*layers)
573*da0073e9SAndroid Build Coastguard Worker
574*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
575*da0073e9SAndroid Build Coastguard Worker                x = self.outer_nn_squential(x)
576*da0073e9SAndroid Build Coastguard Worker                return x
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker        with self.createSummaryWriter() as w:
579*da0073e9SAndroid Build Coastguard Worker            w.add_graph(OuterNNSquential(), dummy_input)
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker        actual_proto, _ = graph(OuterNNSquential(), dummy_input)
582*da0073e9SAndroid Build Coastguard Worker
583*da0073e9SAndroid Build Coastguard Worker        expected_str = read_expected_content(self)
584*da0073e9SAndroid Build Coastguard Worker        expected_proto = GraphDef()
585*da0073e9SAndroid Build Coastguard Worker        text_format.Parse(expected_str, expected_proto)
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
588*da0073e9SAndroid Build Coastguard Worker        for i in range(len(expected_proto.node)):
589*da0073e9SAndroid Build Coastguard Worker            expected_node = expected_proto.node[i]
590*da0073e9SAndroid Build Coastguard Worker            actual_node = actual_proto.node[i]
591*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_node.name, actual_node.name)
592*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_node.op, actual_node.op)
593*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_node.input, actual_node.input)
594*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected_node.device, actual_node.device)
595*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
596*da0073e9SAndroid Build Coastguard Worker                sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker    def test_pytorch_graph_dict_input(self):
599*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
600*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
601*da0073e9SAndroid Build Coastguard Worker                super().__init__()
602*da0073e9SAndroid Build Coastguard Worker                self.l = torch.nn.Linear(3, 5)
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
605*da0073e9SAndroid Build Coastguard Worker                return self.l(x)
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker        class ModelDict(torch.nn.Module):
608*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
609*da0073e9SAndroid Build Coastguard Worker                super().__init__()
610*da0073e9SAndroid Build Coastguard Worker                self.l = torch.nn.Linear(3, 5)
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
613*da0073e9SAndroid Build Coastguard Worker                return {"out": self.l(x)}
614*da0073e9SAndroid Build Coastguard Worker
615*da0073e9SAndroid Build Coastguard Worker
616*da0073e9SAndroid Build Coastguard Worker        dummy_input = torch.zeros(1, 3)
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker        with self.createSummaryWriter() as w:
619*da0073e9SAndroid Build Coastguard Worker            w.add_graph(Model(), dummy_input)
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker        with self.createSummaryWriter() as w:
622*da0073e9SAndroid Build Coastguard Worker            w.add_graph(Model(), dummy_input, use_strict_trace=True)
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker        # expect error: Encountering a dict at the output of the tracer...
625*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
626*da0073e9SAndroid Build Coastguard Worker            with self.createSummaryWriter() as w:
627*da0073e9SAndroid Build Coastguard Worker                w.add_graph(ModelDict(), dummy_input, use_strict_trace=True)
628*da0073e9SAndroid Build Coastguard Worker
629*da0073e9SAndroid Build Coastguard Worker        with self.createSummaryWriter() as w:
630*da0073e9SAndroid Build Coastguard Worker            w.add_graph(ModelDict(), dummy_input, use_strict_trace=False)
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker
633*da0073e9SAndroid Build Coastguard Worker    def test_mlp_graph(self):
634*da0073e9SAndroid Build Coastguard Worker        dummy_input = (torch.zeros(2, 1, 28, 28),)
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker        # This MLP class with the above input is expected
637*da0073e9SAndroid Build Coastguard Worker        # to fail JIT optimizations as seen at
638*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/18903
639*da0073e9SAndroid Build Coastguard Worker        #
640*da0073e9SAndroid Build Coastguard Worker        # However, it should not raise an error during
641*da0073e9SAndroid Build Coastguard Worker        # the add_graph call and still continue.
642*da0073e9SAndroid Build Coastguard Worker        class myMLP(torch.nn.Module):
643*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
644*da0073e9SAndroid Build Coastguard Worker                super().__init__()
645*da0073e9SAndroid Build Coastguard Worker                self.input_len = 1 * 28 * 28
646*da0073e9SAndroid Build Coastguard Worker                self.fc1 = torch.nn.Linear(self.input_len, 1200)
647*da0073e9SAndroid Build Coastguard Worker                self.fc2 = torch.nn.Linear(1200, 1200)
648*da0073e9SAndroid Build Coastguard Worker                self.fc3 = torch.nn.Linear(1200, 10)
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, update_batch_stats=True):
651*da0073e9SAndroid Build Coastguard Worker                h = torch.nn.functional.relu(
652*da0073e9SAndroid Build Coastguard Worker                    self.fc1(x.view(-1, self.input_len)))
653*da0073e9SAndroid Build Coastguard Worker                h = self.fc2(h)
654*da0073e9SAndroid Build Coastguard Worker                h = torch.nn.functional.relu(h)
655*da0073e9SAndroid Build Coastguard Worker                h = self.fc3(h)
656*da0073e9SAndroid Build Coastguard Worker                return h
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker        with self.createSummaryWriter() as w:
659*da0073e9SAndroid Build Coastguard Worker            w.add_graph(myMLP(), dummy_input)
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker    def test_wrong_input_size(self):
662*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError) as e_info:
663*da0073e9SAndroid Build Coastguard Worker            dummy_input = torch.rand(1, 9)
664*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.Linear(3, 5)
665*da0073e9SAndroid Build Coastguard Worker            with self.createSummaryWriter() as w:
666*da0073e9SAndroid Build Coastguard Worker                w.add_graph(model, dummy_input)  # error
667*da0073e9SAndroid Build Coastguard Worker
668*da0073e9SAndroid Build Coastguard Worker    @skipIfNoTorchVision
669*da0073e9SAndroid Build Coastguard Worker    def test_torchvision_smoke(self):
670*da0073e9SAndroid Build Coastguard Worker        model_input_shapes = {
671*da0073e9SAndroid Build Coastguard Worker            'alexnet': (2, 3, 224, 224),
672*da0073e9SAndroid Build Coastguard Worker            'resnet34': (2, 3, 224, 224),
673*da0073e9SAndroid Build Coastguard Worker            'resnet152': (2, 3, 224, 224),
674*da0073e9SAndroid Build Coastguard Worker            'densenet121': (2, 3, 224, 224),
675*da0073e9SAndroid Build Coastguard Worker            'vgg16': (2, 3, 224, 224),
676*da0073e9SAndroid Build Coastguard Worker            'vgg19': (2, 3, 224, 224),
677*da0073e9SAndroid Build Coastguard Worker            'vgg16_bn': (2, 3, 224, 224),
678*da0073e9SAndroid Build Coastguard Worker            'vgg19_bn': (2, 3, 224, 224),
679*da0073e9SAndroid Build Coastguard Worker            'mobilenet_v2': (2, 3, 224, 224),
680*da0073e9SAndroid Build Coastguard Worker        }
681*da0073e9SAndroid Build Coastguard Worker        for model_name, input_shape in model_input_shapes.items():
682*da0073e9SAndroid Build Coastguard Worker            with self.createSummaryWriter() as w:
683*da0073e9SAndroid Build Coastguard Worker                model = getattr(torchvision.models, model_name)()
684*da0073e9SAndroid Build Coastguard Worker                w.add_graph(model, torch.zeros(input_shape))
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardFigure(BaseTestCase):
687*da0073e9SAndroid Build Coastguard Worker    @skipIfNoMatplotlib
688*da0073e9SAndroid Build Coastguard Worker    def test_figure(self):
689*da0073e9SAndroid Build Coastguard Worker        writer = self.createSummaryWriter()
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Worker        figure, axes = plt.figure(), plt.gca()
692*da0073e9SAndroid Build Coastguard Worker        circle1 = plt.Circle((0.2, 0.5), 0.2, color='r')
693*da0073e9SAndroid Build Coastguard Worker        circle2 = plt.Circle((0.8, 0.5), 0.2, color='g')
694*da0073e9SAndroid Build Coastguard Worker        axes.add_patch(circle1)
695*da0073e9SAndroid Build Coastguard Worker        axes.add_patch(circle2)
696*da0073e9SAndroid Build Coastguard Worker        plt.axis('scaled')
697*da0073e9SAndroid Build Coastguard Worker        plt.tight_layout()
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker        writer.add_figure("add_figure/figure", figure, 0, close=False)
700*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(plt.fignum_exists(figure.number))
701*da0073e9SAndroid Build Coastguard Worker
702*da0073e9SAndroid Build Coastguard Worker        writer.add_figure("add_figure/figure", figure, 1)
703*da0073e9SAndroid Build Coastguard Worker        if matplotlib.__version__ != '3.3.0':
704*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(plt.fignum_exists(figure.number))
705*da0073e9SAndroid Build Coastguard Worker        else:
706*da0073e9SAndroid Build Coastguard Worker            print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker        writer.close()
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker    @skipIfNoMatplotlib
711*da0073e9SAndroid Build Coastguard Worker    def test_figure_list(self):
712*da0073e9SAndroid Build Coastguard Worker        writer = self.createSummaryWriter()
713*da0073e9SAndroid Build Coastguard Worker
714*da0073e9SAndroid Build Coastguard Worker        figures = []
715*da0073e9SAndroid Build Coastguard Worker        for i in range(5):
716*da0073e9SAndroid Build Coastguard Worker            figure = plt.figure()
717*da0073e9SAndroid Build Coastguard Worker            plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i))
718*da0073e9SAndroid Build Coastguard Worker            plt.xlabel("X")
719*da0073e9SAndroid Build Coastguard Worker            plt.xlabel("Y")
720*da0073e9SAndroid Build Coastguard Worker            plt.legend()
721*da0073e9SAndroid Build Coastguard Worker            plt.tight_layout()
722*da0073e9SAndroid Build Coastguard Worker            figures.append(figure)
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker        writer.add_figure("add_figure/figure_list", figures, 0, close=False)
725*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(plt.fignum_exists(figure.number) is True for figure in figures))  # noqa: F812
726*da0073e9SAndroid Build Coastguard Worker
727*da0073e9SAndroid Build Coastguard Worker        writer.add_figure("add_figure/figure_list", figures, 1)
728*da0073e9SAndroid Build Coastguard Worker        if matplotlib.__version__ != '3.3.0':
729*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(all(plt.fignum_exists(figure.number) is False for figure in figures))  # noqa: F812
730*da0073e9SAndroid Build Coastguard Worker        else:
731*da0073e9SAndroid Build Coastguard Worker            print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Worker        writer.close()
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Workerclass TestTensorBoardNumpy(BaseTestCase):
736*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipping on windows, see https://github.com/pytorch/pytorch/pull/109349 ")
737*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
738*da0073e9SAndroid Build Coastguard Worker    def test_scalar(self):
739*da0073e9SAndroid Build Coastguard Worker        res = make_np(1.1)
740*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
741*da0073e9SAndroid Build Coastguard Worker        res = make_np(1 << 64 - 1)  # uint64_max
742*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
743*da0073e9SAndroid Build Coastguard Worker        res = make_np(np.float16(1.00000087))
744*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
745*da0073e9SAndroid Build Coastguard Worker        res = make_np(np.float128(1.00008 + 9))
746*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
747*da0073e9SAndroid Build Coastguard Worker        res = make_np(np.int64(100000000000))
748*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker    def test_pytorch_np_expect_fail(self):
751*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(NotImplementedError):
752*da0073e9SAndroid Build Coastguard Worker            res = make_np({'pytorch': 1.0})
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker
class TestTensorProtoSummary(BaseTestCase):
    """Tests for ``tensor_proto``: the value field and dtype it fills for each tensor kind."""

    @parametrize(
        "tensor_type,proto_type",
        [
            (torch.float16, DataType.DT_HALF),
            (torch.bfloat16, DataType.DT_BFLOAT16),
        ],
    )
    @skipIfTorchDynamo("Unsuitable test for Dynamo, behavior changes with version")
    def test_half_tensor_proto(self, tensor_type, proto_type):
        float_values = [1.0, 2.0, 3.0]
        actual_proto = tensor_proto(
            "dummy",
            torch.tensor(float_values, dtype=tensor_type),
        ).value[0].tensor
        # Half-precision values are stored as integers in ``half_val``;
        # decode each one back to a float before comparing.
        self.assertSequenceEqual(
            [int_to_half(x) for x in actual_proto.half_val],
            float_values,
        )
        # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
        self.assertEqual(actual_proto.dtype, proto_type)

    def test_float_tensor_proto(self):
        float_values = [1.0, 2.0, 3.0]
        actual_proto = (
            tensor_proto("dummy", torch.tensor(float_values)).value[0].tensor
        )
        self.assertEqual(actual_proto.float_val, float_values)
        self.assertEqual(actual_proto.dtype, DataType.DT_FLOAT)

    def test_int_tensor_proto(self):
        int_values = [1, 2, 3]
        actual_proto = (
            tensor_proto("dummy", torch.tensor(int_values, dtype=torch.int32))
            .value[0]
            .tensor
        )
        self.assertEqual(actual_proto.int_val, int_values)
        self.assertEqual(actual_proto.dtype, DataType.DT_INT32)

    def test_scalar_tensor_proto(self):
        scalar_value = 0.1
        actual_proto = (
            tensor_proto("dummy", torch.tensor(scalar_value)).value[0].tensor
        )
        # Approximate comparison: 0.1 is not exactly representable in float32.
        self.assertAlmostEqual(actual_proto.float_val[0], scalar_value)

    def test_complex_tensor_proto(self):
        real = torch.tensor([1.0, 2.0])
        imag = torch.tensor([3.0, 4.0])
        actual_proto = (
            tensor_proto("dummy", torch.complex(real, imag)).value[0].tensor
        )
        # ``scomplex_val`` interleaves (real, imag) pairs element by element.
        self.assertEqual(actual_proto.scomplex_val, [1.0, 3.0, 2.0, 4.0])

    def test_empty_tensor_proto(self):
        actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
        self.assertEqual(actual_proto.float_val, [])


instantiate_parametrized_tests(TestTensorProtoSummary)
815*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Only runs when this module is executed directly as a script.
    run_tests()
818