# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for test_utils."""

from absl.testing import absltest

import tensorflow as tf

from fcp.demo import test_utils


class TestUtilsTest(absltest.TestCase):
    """Exercises the checkpoint helpers exposed by `fcp.demo.test_utils`."""

    def test_create_checkpoint(self):
        """Values written by `create_checkpoint` can be read back by name."""
        checkpoint = test_utils.create_checkpoint({
            'int': 3,
            'str': 'test',
            'list': [1, 2, 3],
        })
        self.assertEqual(
            test_utils.read_tensor_from_checkpoint(checkpoint, 'int', tf.int32),
            3)
        # The string value is expected back as bytes, not as `str`.
        self.assertEqual(
            test_utils.read_tensor_from_checkpoint(
                checkpoint, 'str', tf.string),
            b'test')
        # `.tolist()` turns the returned value into a plain list so it can be
        # compared with `assertListEqual`.
        self.assertListEqual(
            test_utils.read_tensor_from_checkpoint(
                checkpoint, 'list', tf.int32).tolist(),
            [1, 2, 3])

    def test_read_from_checkpoint_not_found(self):
        """Reading a tensor name absent from the checkpoint raises."""
        checkpoint = test_utils.create_checkpoint({'int': 3})
        # NOTE(review): the concrete exception type raised by test_utils is not
        # visible from this file; narrow `Exception` once it is confirmed.
        with self.assertRaises(Exception):
            test_utils.read_tensor_from_checkpoint(
                checkpoint, 'str', tf.string)

    def test_read_from_checkpoint_wrong_type(self):
        """Reading an existing tensor with a mismatched dtype raises."""
        checkpoint = test_utils.create_checkpoint({'int': 3})
        # NOTE(review): same as above -- the exact exception type is not
        # visible here, so the broad `Exception` check is kept.
        with self.assertRaises(Exception):
            test_utils.read_tensor_from_checkpoint(
                checkpoint, 'int', tf.string)


if __name__ == '__main__':
    absltest.main()