# xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/task_eligibility_info_ops_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import tensorflow as tf

from fcp.protos import federated_api_pb2
from fcp.tensorflow.task_eligibility_info_ops import create_task_eligibility_info
19
20
class TaskEligibilityInfoOpsTest(tf.test.TestCase):
  """Tests for the `create_task_eligibility_info` op wrapper."""

  def test_create_task_eligibility_info_succeeds(self):
    """Tests that valid inputs yield a serialized `TaskEligibilityInfo`."""
    # Run the op and parse its result into the expected proto type.
    actual_serialized_value = create_task_eligibility_info(
        version=555,
        task_names=['foo_task', 'bar_task'],
        task_weights=[123.456, 789.012])
    tf.debugging.assert_scalar(actual_serialized_value)
    tf.debugging.assert_type(actual_serialized_value, tf.string)

    actual_value = federated_api_pb2.TaskEligibilityInfo()
    # Note: the .numpy() call converts the scalar string tensor to a Python
    # `bytes` object we can parse the proto from.
    actual_value.ParseFromString(actual_serialized_value.numpy())

    # Ensure the resulting proto contains the expected data. Use the TestCase
    # assertion rather than a bare `assert`: a bare `assert` is skipped when
    # Python runs with -O (making the test vacuous) and gives no diff on
    # failure.
    expected_value = federated_api_pb2.TaskEligibilityInfo(
        version=555,
        task_weights=[
            federated_api_pb2.TaskWeight(task_name='foo_task', weight=123.456),
            federated_api_pb2.TaskWeight(task_name='bar_task', weight=789.012),
        ])
    self.assertProtoEquals(expected_value, actual_value)

  def test_create_task_eligibility_info_empty_task_list_succeeds(self):
    """Tests that an empty `task_names` input is allowed & handled correctly."""
    actual_serialized_value = create_task_eligibility_info(
        version=555, task_names=[], task_weights=[])
    actual_value = federated_api_pb2.TaskEligibilityInfo()
    actual_value.ParseFromString(actual_serialized_value.numpy())

    # Ensure the resulting proto contains only the version (no task weights).
    expected_value = federated_api_pb2.TaskEligibilityInfo(version=555)
    self.assertProtoEquals(expected_value, actual_value)

  def test_create_task_eligibility_info_non_scalar_version_raises_error(self):
    """Tests that a non-scalar `version` is rejected."""
    with self.assertRaises(tf.errors.InvalidArgumentError):
      create_task_eligibility_info(
          version=[555], task_names=['foo_task'], task_weights=[123.456])

  def test_create_task_eligibility_info_non_vector_task_names_list_raises_error(
      self):
    """Tests that a rank-2 `task_names` input is rejected."""
    with self.assertRaises(tf.errors.InvalidArgumentError):
      create_task_eligibility_info(
          version=555, task_names=[['foo_task']], task_weights=[123.456])

  def test_create_task_eligibility_info_non_vector_task_weights_list_raises_error(
      self):
    """Tests that a rank-2 `task_weights` input is rejected."""
    with self.assertRaises(tf.errors.InvalidArgumentError):
      create_task_eligibility_info(
          version=555, task_names=['foo_task'], task_weights=[[123.456]])

  def test_create_task_eligibility_info_differing_names_weights_length_raises_error(
      self):
    """Tests that `task_names` and `task_weights` must have equal lengths."""
    # Use a float weight (like every other test) so only the length mismatch,
    # not an implicit int-to-float conversion, is under test.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      create_task_eligibility_info(
          version=555,
          task_names=['foo_task', 'bar_task'],
          task_weights=[123.456])

  def test_create_task_eligibility_info_invalid_task_names_type_raises_error(
      self):
    """Tests that non-string `task_names` are rejected at conversion time."""
    with self.assertRaises(TypeError):
      create_task_eligibility_info(
          version=555, task_names=[111, 222], task_weights=[123.456, 789.012])

  def test_create_task_eligibility_info_invalid_task_weights_type_raises_error(
      self):
    """Tests that non-numeric `task_weights` are rejected at conversion time."""
    with self.assertRaises(TypeError):
      create_task_eligibility_info(
          version=555,
          task_names=['foo_task', 'bar_task'],
          task_weights=['hello', 'world'])
93
94
if __name__ == '__main__':
  # Run the tests in this module with TensorFlow's test runner.
  tf.test.main()