# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the API surface of the V1 tf.summary ops when TF2 is enabled.

V1 summary ops will invoke V2 TensorBoard summary ops in eager mode.
"""

from tensorboard.summary import v2 as summary_v2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import test
from tensorflow.python.summary import summary as summary_lib
from tensorflow.python.training import training_util


class SummaryV2Test(test.TestCase):
  """Checks that V1 `summary_lib` ops forward to the TensorBoard V2 ops.

  Each test patches the relevant `tensorboard.summary.v2` function with an
  autospec'd mock, calls the V1 op, and asserts on the forwarded arguments
  (tag, data, step, and op-specific options) and on the V1 op's result.
  """

  def _assert_empty_string_tensor(self, tensor):
    """Asserts that `tensor` holds b'' and has dtype `tf.string`.

    This is the result every test below expects from a V1 summary op once
    it has dispatched to the V2 implementation.

    Args:
      tensor: The (eager) tensor returned by a V1 `summary_lib` op.
    """
    self.assertEqual(tensor.numpy(), b'')
    self.assertEqual(tensor.dtype, dtypes.string)

  @test_util.run_v2_only
  def test_scalar_summary_v2__w_writer(self):
    """Tests scalar v2 invocation with a v2 writer."""
    with test.mock.patch.object(
        summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=1):
        i = constant_op.constant(2.5)
        tensor = summary_lib.scalar('float', i)
    # Returns empty string.
    self._assert_empty_string_tensor(tensor)
    mock_scalar_v2.assert_called_once_with('float', data=i, step=1)

  @test_util.run_v2_only
  def test_scalar_summary_v2__wo_writer(self):
    """Tests scalar v2 invocation with no writer."""
    with self.assertWarnsRegex(
        UserWarning, 'default summary writer not found'):
      with test.mock.patch.object(
          summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
        summary_lib.scalar('float', constant_op.constant(2.5))
    # Without a default writer the V2 op must not be reached at all.
    mock_scalar_v2.assert_not_called()

  @test_util.run_v2_only
  def test_scalar_summary_v2__global_step_not_set(self):
    """Tests scalar v2 invocation when global step is not set."""
    with self.assertWarnsRegex(UserWarning, 'global step not set'):
      with test.mock.patch.object(
          summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
        # Writer is installed without a `step`, and no global step exists.
        with summary_ops_v2.create_summary_file_writer(
            self.get_temp_dir()).as_default():
          summary_lib.scalar('float', constant_op.constant(2.5))
    mock_scalar_v2.assert_not_called()

  @test_util.run_v2_only
  def test_scalar_summary_v2__family(self):
    """Tests `family` arg handling when scalar v2 is invoked."""
    with test.mock.patch.object(
        summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=1):
        # Bind the input so the assertion checks that this exact tensor is
        # forwarded (as in the other tests), instead of relying on
        # element-wise Tensor.__eq__ against a freshly built constant.
        i = constant_op.constant(2.5)
        tensor = summary_lib.scalar('float', i, family='otter')
    # Returns empty string.
    self._assert_empty_string_tensor(tensor)
    # The family is expected to appear twice in the forwarded tag.
    mock_scalar_v2.assert_called_once_with(
        'otter/otter/float', data=i, step=1)

  @test_util.run_v2_only
  def test_scalar_summary_v2__family_w_outer_scope(self):
    """Tests `family` arg handling when there is an outer scope."""
    with test.mock.patch.object(
        summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=1):
        with ops.name_scope_v2('sea'):
          # Bound for the same reason as in test_scalar_summary_v2__family.
          i = constant_op.constant(3.5)
          tensor = summary_lib.scalar('float', i, family='crabnet')
    # Returns empty string.
    self._assert_empty_string_tensor(tensor)
    # Expected tag shape: family / enclosing scope / family / name.
    mock_scalar_v2.assert_called_once_with(
        'crabnet/sea/crabnet/float', data=i, step=1)

  @test_util.run_v2_only
  def test_scalar_summary_v2__v1_set_step(self):
    """Tests scalar v2 invocation when v1 step is set."""
    global_step = training_util.create_global_step()
    global_step.assign(1024)
    with test.mock.patch.object(
        summary_v2, 'scalar', autospec=True) as mock_scalar_v2:
      # No `step` on the writer, so the V1 global step (1024) is expected
      # to be forwarded instead.
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default():
        i = constant_op.constant(2.5)
        tensor = summary_lib.scalar('float', i)
    # Returns empty string.
    self._assert_empty_string_tensor(tensor)
    mock_scalar_v2.assert_called_once_with('float', data=i, step=1024)

  @test_util.run_v2_only
  def test_image_summary_v2(self):
    """Tests image v2 invocation."""
    with test.mock.patch.object(
        summary_v2, 'image', autospec=True) as mock_image_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=2):
        i = array_ops.ones((5, 4, 4, 3))
        with ops.name_scope_v2('outer'):
          tensor = summary_lib.image('image', i, max_outputs=3, family='family')
    # Returns empty string.
    self._assert_empty_string_tensor(tensor)
    mock_image_v2.assert_called_once_with(
        'family/outer/family/image', data=i, step=2, max_outputs=3)

  @test_util.run_v2_only
  def test_histogram_summary_v2(self):
    """Tests histogram v2 invocation."""
    with test.mock.patch.object(
        summary_v2, 'histogram', autospec=True) as mock_histogram_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=3):
        i = array_ops.ones((1024,))
        tensor = summary_lib.histogram('histogram', i, family='family')
    # Returns empty string.
    self._assert_empty_string_tensor(tensor)
    mock_histogram_v2.assert_called_once_with(
        'family/family/histogram', data=i, step=3)

  @test_util.run_v2_only
  def test_audio_summary_v2(self):
    """Tests audio v2 invocation."""
    with test.mock.patch.object(
        summary_v2, 'audio', autospec=True) as mock_audio_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=10):
        i = array_ops.ones((5, 3, 4))
        with ops.name_scope_v2('dolphin'):
          tensor = summary_lib.audio('wave', i, 0.2, max_outputs=3)
    # Returns empty string.
    self._assert_empty_string_tensor(tensor)
    mock_audio_v2.assert_called_once_with(
        'dolphin/wave', data=i, sample_rate=0.2, step=10, max_outputs=3)

  @test_util.run_v2_only
  def test_audio_summary_v2__2d_tensor(self):
    """Tests audio v2 invocation with 2-D tensor input."""
    with test.mock.patch.object(
        summary_v2, 'audio', autospec=True) as mock_audio_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=11):
        input_2d = array_ops.ones((5, 3))
        tensor = summary_lib.audio('wave', input_2d, 0.2, max_outputs=3)

    # Returns empty string.
    self._assert_empty_string_tensor(tensor)

    # The 2-D input is not forwarded as-is, so match `data` loosely here and
    # compare its value separately below.
    mock_audio_v2.assert_called_once_with(
        'wave', data=test.mock.ANY, sample_rate=0.2, step=11, max_outputs=3)
    input_3d = array_ops.ones((5, 3, 1))  # 3-D input tensor
    self.assertAllEqual(mock_audio_v2.call_args[1]['data'], input_3d)

  @test_util.run_v2_only
  def test_text_summary_v2(self):
    """Tests text v2 invocation."""
    with test.mock.patch.object(
        summary_v2, 'text', autospec=True) as mock_text_v2:
      with summary_ops_v2.create_summary_file_writer(
          self.get_temp_dir()).as_default(step=22):
        i = constant_op.constant('lorem ipsum', dtype=dtypes.string)
        tensor = summary_lib.text('text', i)
    # Returns empty string.
    self._assert_empty_string_tensor(tensor)
    mock_text_v2.assert_called_once_with('text', data=i, step=22)


if __name__ == '__main__':
  test.main()