# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the `create_task_eligibility_info` op."""

import tensorflow as tf

from fcp.protos import federated_api_pb2
from fcp.tensorflow.task_eligibility_info_ops import create_task_eligibility_info


def _parse_task_eligibility_info(serialized_tensor):
  """Parses the op's output tensor into a `TaskEligibilityInfo` proto.

  Args:
    serialized_tensor: The tensor returned by `create_task_eligibility_info`.
      It must support `.numpy()`, which yields the serialized proto bytes.

  Returns:
    The `federated_api_pb2.TaskEligibilityInfo` message parsed from the bytes.
  """
  parsed = federated_api_pb2.TaskEligibilityInfo()
  # Note: the .numpy() call converts the string tensor to a Python string we
  # can parse the proto from.
  parsed.ParseFromString(serialized_tensor.numpy())
  return parsed


class TaskEligibilityInfoOpsTest(tf.test.TestCase):
  """Exercises `create_task_eligibility_info` with valid and invalid inputs."""

  def test_create_task_eligibility_info_succeeds(self):
    """Tests that names and weights are paired up in the serialized proto."""
    # Run the op and check that it produced a scalar string tensor.
    actual_serialized_value = create_task_eligibility_info(
        version=555,
        task_names=['foo_task', 'bar_task'],
        task_weights=[123.456, 789.012])
    tf.debugging.assert_scalar(actual_serialized_value)
    tf.debugging.assert_type(actual_serialized_value, tf.string)

    actual_value = _parse_task_eligibility_info(actual_serialized_value)

    # Ensure the resulting proto contains the expected data.
    expected_value = federated_api_pb2.TaskEligibilityInfo(
        version=555,
        task_weights=[
            federated_api_pb2.TaskWeight(task_name='foo_task', weight=123.456),
            federated_api_pb2.TaskWeight(task_name='bar_task', weight=789.012)
        ])
    # Use the TestCase assertion rather than a bare `assert`: it is not
    # stripped under `python -O` and it reports both values on failure.
    self.assertEqual(actual_value, expected_value)

  def test_create_task_eligibility_info_empty_task_list_succeeds(self):
    """Tests that an empty `task_names` input is allowed & handled correctly."""
    actual_serialized_value = create_task_eligibility_info(
        version=555, task_names=[], task_weights=[])
    actual_value = _parse_task_eligibility_info(actual_serialized_value)

    # Ensure the resulting proto contains the expected data.
    expected_value = federated_api_pb2.TaskEligibilityInfo(version=555)
    self.assertEqual(actual_value, expected_value)

  def test_create_task_eligibility_info_non_scalar_version_raises_error(self):
    """Tests that a non-scalar `version` is rejected."""
    with self.assertRaises(tf.errors.InvalidArgumentError):
      create_task_eligibility_info(
          version=[555], task_names=['foo_task'], task_weights=[123.456])

  def test_create_task_eligibility_info_non_vector_task_names_list_raises_error(
      self):
    """Tests that a nested (non-vector) `task_names` input is rejected."""
    with self.assertRaises(tf.errors.InvalidArgumentError):
      create_task_eligibility_info(
          version=555, task_names=[['foo_task']], task_weights=[123.456])

  def test_create_task_eligibility_info_non_vector_task_weights_list_raises_error(
      self):
    """Tests that a nested (non-vector) `task_weights` input is rejected."""
    with self.assertRaises(tf.errors.InvalidArgumentError):
      create_task_eligibility_info(
          version=555, task_names=['foo_task'], task_weights=[[123.456]])

  def test_create_task_eligibility_info_differing_names_weights_length_raises_error(
      self):
    """Tests that `task_names` and `task_weights` must have equal length."""
    with self.assertRaises(tf.errors.InvalidArgumentError):
      create_task_eligibility_info(
          version=555, task_names=['foo_task', 'bar_task'], task_weights=[123])

  def test_create_task_eligibility_info_invalid_task_names_type_raises_error(
      self):
    """Tests that integer `task_names` values raise a `TypeError`."""
    with self.assertRaises(TypeError):
      create_task_eligibility_info(
          version=555, task_names=[111, 222], task_weights=[123.456, 789.012])

  def test_create_task_eligibility_info_invalid_task_weights_type_raises_error(
      self):
    """Tests that string `task_weights` values raise a `TypeError`."""
    with self.assertRaises(TypeError):
      create_task_eligibility_info(
          version=555,
          task_names=['foo_task', 'bar_task'],
          task_weights=['hello', 'world'])


if __name__ == '__main__':
  tf.test.main()