xref: /aosp_15_r20/external/federated-compute/fcp/demo/federated_data_source_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for federated_data_source."""
15
16from absl.testing import absltest
17import tensorflow as tf
18import tensorflow_federated as tff
19
20from fcp.demo import federated_data_source as fds
21from fcp.protos import plan_pb2
22from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
23
# Short alias for the TaskAssignmentMode enum nested inside
# PopulationEligibilitySpec.TaskInfo, used by the tests in this module.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo
    .TaskAssignmentMode
)

# Fixture values shared by the test cases in this module.
POPULATION_NAME = 'test/name'
EXAMPLE_SELECTOR = plan_pb2.ExampleSelector(
    collection_uri='app://test'
)
30
31
class FederatedDataSourceTest(absltest.TestCase):
  """Unit tests for `fds.FederatedDataSource` and the iterator it creates."""

  def test_invalid_population_name(self):
    # '^invalid^' is expected to fail the population-name validation.
    bad_name = '^invalid^'
    with self.assertRaisesRegex(ValueError, r'population_name must match ".+"'):
      fds.FederatedDataSource(bad_name, EXAMPLE_SELECTOR)

  def test_population_name(self):
    source = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    actual_name = source.population_name
    self.assertEqual(actual_name, POPULATION_NAME)

  def test_example_selector(self):
    source = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    actual_selector = source.example_selector
    self.assertEqual(actual_selector, EXAMPLE_SELECTOR)

  def test_default_task_assignment_mode(self):
    # With no mode argument, the source is expected to report SINGLE.
    source = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    actual_mode = source.task_assignment_mode
    expected_mode = _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE
    self.assertEqual(actual_mode, expected_mode)

  def test_task_assignment_mode(self):
    multiple = _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE
    source = fds.FederatedDataSource(
        POPULATION_NAME, EXAMPLE_SELECTOR, task_assignment_mode=multiple
    )
    self.assertEqual(source.task_assignment_mode, multiple)

  def test_federated_type(self):
    # A single selector yields a CLIENTS-placed sequence of strings.
    source = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    actual_type = source.federated_type
    expected_type = tff.FederatedType(
        tff.SequenceType(tf.string), tff.CLIENTS
    )
    self.assertEqual(actual_type, expected_type)

  def test_federated_type_nested(self):
    # A nested mapping of selectors yields a struct of the same shape whose
    # leaves are all sequences of strings.
    selectors = dict(
        a=EXAMPLE_SELECTOR,
        b=EXAMPLE_SELECTOR,
        c={'1': EXAMPLE_SELECTOR, '2': EXAMPLE_SELECTOR},
    )
    source = fds.FederatedDataSource(POPULATION_NAME, selectors)
    actual_type = source.federated_type

    string_sequence = tff.SequenceType(tf.string)
    inner = tff.StructType([('1', string_sequence), ('2', string_sequence)])
    outer = tff.StructType([
        ('a', string_sequence),
        ('b', string_sequence),
        ('c', inner),
    ])
    expected_type = tff.FederatedType(outer, tff.CLIENTS)
    self.assertEqual(actual_type, expected_type)

  def test_capabilities(self):
    source = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    expected_capabilities = [tff.program.Capability.SUPPORTS_REUSE]
    self.assertListEqual(source.capabilities, expected_capabilities)

  def test_iterator_federated_type(self):
    # The iterator must advertise the same federated type as its source.
    source = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    iterator_type = source.iterator().federated_type
    self.assertEqual(iterator_type, source.federated_type)

  def test_iterator_select(self):
    mode = _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE
    source = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR, mode)
    num_clients = 10
    selected = source.iterator().select(num_clients)
    expected = fds.DataSelectionConfig(
        POPULATION_NAME, EXAMPLE_SELECTOR, mode, num_clients
    )
    self.assertEqual(selected, expected)

  def test_iterator_select_with_invalid_num_clients(self):
    source = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    # Missing, negative and zero client counts must each be rejected; a fresh
    # iterator is created for every attempt.
    for bad_count in (None, -5, 0):
      with self.assertRaisesRegex(ValueError, 'num_clients must be positive'):
        source.iterator().select(num_clients=bad_count)
124      ds.iterator().select(num_clients=0)
125
126
if __name__ == '__main__':
  # Entry point when this module is executed directly: hand off to absltest
  # to run the test cases defined above.
  absltest.main()
129