# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/summary/summary_v2_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the API surface of the V1 tf.summary ops when TF2 is enabled.

In eager mode, V1 summary ops invoke the corresponding V2 TensorBoard summary
ops.
"""
19
20from tensorboard.summary import v2 as summary_v2
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import summary_ops_v2
27from tensorflow.python.platform import test
28from tensorflow.python.summary import summary as summary_lib
29from tensorflow.python.training import training_util
30
31
class SummaryV2Test(test.TestCase):
  """Tests that V1 summary ops dispatch to TensorBoard V2 summary ops.

  Each test patches one op on `tensorboard.summary.v2` with an autospec mock,
  calls the matching V1 op from `summary_lib`, and then checks two things:
  the arguments the mock received, and the tensor the V1 op returned.
  """

  def _assert_empty_string_summary(self, tensor):
    """Asserts that `tensor`, a V1 op's return value, is an empty string."""
    self.assertEqual(tensor.numpy(), b'')
    self.assertEqual(tensor.dtype, dtypes.string)

  @test_util.run_v2_only
  def test_scalar_summary_v2__w_writer(self):
    """Tests scalar v2 invocation with a v2 writer."""
    with test.mock.patch.object(
        summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=1):
        i = constant_op.constant(2.5)
        tensor = summary_lib.scalar('float', i)
    self._assert_empty_string_summary(tensor)
    mock_scalar_v2.assert_called_once_with('float', data=i, step=1)

  @test_util.run_v2_only
  def test_scalar_summary_v2__wo_writer(self):
    """Tests scalar v2 invocation with no writer."""
    # No default writer is installed, so the V1 op is expected to warn and
    # never reach the V2 op.
    with self.assertWarnsRegex(
        UserWarning, 'default summary writer not found'):
      with test.mock.patch.object(
          summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
        summary_lib.scalar('float', constant_op.constant(2.5))
    mock_scalar_v2.assert_not_called()

  @test_util.run_v2_only
  def test_scalar_summary_v2__global_step_not_set(self):
    """Tests scalar v2 invocation when global step is not set."""
    # The writer is installed without a `step`, and this test creates no
    # global step, so the V1 op is expected to warn and skip the V2 op.
    with self.assertWarnsRegex(UserWarning, 'global step not set'):
      with test.mock.patch.object(
          summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
        with summary_ops_v2.create_summary_file_writer(
            self.get_temp_dir()).as_default():
          summary_lib.scalar('float', constant_op.constant(2.5))
    mock_scalar_v2.assert_not_called()

  @test_util.run_v2_only
  def test_scalar_summary_v2__family(self):
    """Tests `family` arg handling when scalar v2 is invoked."""
    with test.mock.patch.object(
        summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=1):
        i = constant_op.constant(2.5)
        tensor = summary_lib.scalar('float', i, family='otter')
    self._assert_empty_string_summary(tensor)
    # Expected tag layout: <family>/<family>/<name>.
    # The expected `data` reuses the input tensor, like the other tests.
    # Comparing against a second fresh constant would only work through
    # element-wise tensor `==`, which is well-defined for scalars alone.
    mock_scalar_v2.assert_called_once_with(
        'otter/otter/float', data=i, step=1)

  @test_util.run_v2_only
  def test_scalar_summary_v2__family_w_outer_scope(self):
    """Tests `family` arg handling when there is an outer scope."""
    with test.mock.patch.object(
        summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=1):
        with ops.name_scope_v2('sea'):
          i = constant_op.constant(3.5)
          tensor = summary_lib.scalar('float', i, family='crabnet')
    self._assert_empty_string_summary(tensor)
    # Expected tag layout: <family>/<enclosing scope>/<family>/<name>.
    mock_scalar_v2.assert_called_once_with(
        'crabnet/sea/crabnet/float', data=i, step=1)

  @test_util.run_v2_only
  def test_scalar_summary_v2__v1_set_step(self):
    """Tests scalar v2 invocation when v1 step is set."""
    global_step = training_util.create_global_step()
    global_step.assign(1024)
    with test.mock.patch.object(
        summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
      # No explicit `step` here, so the step forwarded to the V2 op is
      # expected to come from the V1 global step assigned above.
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default():
        i = constant_op.constant(2.5)
        tensor = summary_lib.scalar('float', i)
    self._assert_empty_string_summary(tensor)
    mock_scalar_v2.assert_called_once_with('float', data=i, step=1024)

  @test_util.run_v2_only
  def test_image_summary_v2(self):
    """Tests image v2 invocation."""
    with test.mock.patch.object(
        summary_v2, 'image', autospec=True) as mock_image_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=2):
        i = array_ops.ones((5, 4, 4, 3))
        with ops.name_scope_v2('outer'):
          tensor = summary_lib.image('image', i, max_outputs=3, family='family')
    self._assert_empty_string_summary(tensor)
    mock_image_v2.assert_called_once_with(
        'family/outer/family/image', data=i, step=2, max_outputs=3)

  @test_util.run_v2_only
  def test_histogram_summary_v2(self):
    """Tests histogram v2 invocation."""
    with test.mock.patch.object(
        summary_v2, 'histogram', autospec=True) as mock_histogram_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=3):
        i = array_ops.ones((1024,))
        tensor = summary_lib.histogram('histogram', i, family='family')
    self._assert_empty_string_summary(tensor)
    mock_histogram_v2.assert_called_once_with(
        'family/family/histogram', data=i, step=3)

  @test_util.run_v2_only
  def test_audio_summary_v2(self):
    """Tests audio v2 invocation."""
    with test.mock.patch.object(
        summary_v2, 'audio', autospec=True) as mock_audio_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=10):
        i = array_ops.ones((5, 3, 4))
        with ops.name_scope_v2('dolphin'):
          tensor = summary_lib.audio('wave', i, 0.2, max_outputs=3)
    self._assert_empty_string_summary(tensor)
    mock_audio_v2.assert_called_once_with(
        'dolphin/wave', data=i, sample_rate=0.2, step=10, max_outputs=3)

  @test_util.run_v2_only
  def test_audio_summary_v2__2d_tensor(self):
    """Tests audio v2 invocation with 2-D tensor input."""
    with test.mock.patch.object(
        summary_v2, 'audio', autospec=True) as mock_audio_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=11):
        input_2d = array_ops.ones((5, 3))
        tensor = summary_lib.audio('wave', input_2d, 0.2, max_outputs=3)

    self._assert_empty_string_summary(tensor)

    # The forwarded `data` is not the input object itself, so match it with
    # ANY first and then check its value. The 2-D input is expected to reach
    # the V2 op as 3-D with a trailing axis of size 1.
    mock_audio_v2.assert_called_once_with(
        'wave', data=test.mock.ANY, sample_rate=0.2, step=11, max_outputs=3)
    input_3d = array_ops.ones((5, 3, 1))  # 3-D input tensor
    self.assertAllEqual(mock_audio_v2.call_args[1]['data'], input_3d)

  @test_util.run_v2_only
  def test_text_summary_v2(self):
    """Tests text v2 invocation."""
    with test.mock.patch.object(
        summary_v2, 'text', autospec=True) as mock_text_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=22):
        i = constant_op.constant('lorem ipsum', dtype=dtypes.string)
        tensor = summary_lib.text('text', i)
    self._assert_empty_string_summary(tensor)
    mock_text_v2.assert_called_once_with('text', data=i, step=22)
196
if __name__ == "__main__":
  # Script entry point: run this module's tests through TensorFlow's
  # `test.main()` runner.
  test.main()
199