# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for test_utils."""

from absl.testing import absltest

import tensorflow as tf

from fcp.demo import test_utils


class TestUtilsTest(absltest.TestCase):
  """Checks for the checkpoint helpers exposed by `test_utils`."""

  def test_create_checkpoint(self):
    """Values passed to create_checkpoint are read back by name."""
    ckpt = test_utils.create_checkpoint({
        'int': 3,
        'str': 'test',
        'list': [1, 2, 3],
    })
    read = test_utils.read_tensor_from_checkpoint

    # Scalar integer round-trip.
    int_value = read(ckpt, 'int', tf.int32)
    self.assertEqual(int_value, 3)

    # String values come back as bytes.
    str_value = read(ckpt, 'str', tf.string)
    self.assertEqual(str_value, b'test')

    # The list is read back as an object exposing `.tolist()`.
    list_value = read(ckpt, 'list', tf.int32)
    self.assertListEqual(list_value.tolist(), [1, 2, 3])

  def test_read_from_checkpoint_not_found(self):
    """Reading a name that was never written raises."""
    ckpt = test_utils.create_checkpoint({'int': 3})
    # NOTE(review): the concrete exception type is not pinned here; narrowing
    # it once the helper's contract is confirmed would make this stricter.
    with self.assertRaises(Exception):
      test_utils.read_tensor_from_checkpoint(ckpt, 'str', tf.string)

  def test_read_from_checkpoint_wrong_type(self):
    """Reading an existing name with a mismatched dtype raises."""
    ckpt = test_utils.create_checkpoint({'int': 3})
    # 'int' was written as an integer, so asking for tf.string must fail.
    with self.assertRaises(Exception):
      test_utils.read_tensor_from_checkpoint(ckpt, 'int', tf.string)


if __name__ == '__main__':
  absltest.main()