# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for federated_data_source."""

from absl.testing import absltest
import tensorflow as tf
import tensorflow_federated as tff

from fcp.demo import federated_data_source as fds
from fcp.protos import plan_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2

# Short alias for the deeply nested TaskAssignmentMode enum used below.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

POPULATION_NAME = 'test/name'
EXAMPLE_SELECTOR = plan_pb2.ExampleSelector(collection_uri='app://test')


class FederatedDataSourceTest(absltest.TestCase):
  """Tests for `fds.FederatedDataSource` and the iterator it creates."""

  def test_invalid_population_name(self):
    """A population name not matching the required pattern is rejected."""
    with self.assertRaisesRegex(ValueError, r'population_name must match ".+"'):
      fds.FederatedDataSource('^invalid^', EXAMPLE_SELECTOR)

  def test_population_name(self):
    """The population name passed at construction is exposed unchanged."""
    ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    self.assertEqual(ds.population_name, POPULATION_NAME)

  def test_example_selector(self):
    """The example selector passed at construction is exposed unchanged."""
    ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    self.assertEqual(ds.example_selector, EXAMPLE_SELECTOR)

  def test_default_task_assignment_mode(self):
    """Without an explicit mode, TASK_ASSIGNMENT_MODE_SINGLE is used."""
    ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    self.assertEqual(
        ds.task_assignment_mode, _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE
    )

  def test_task_assignment_mode(self):
    """An explicitly passed task assignment mode is exposed unchanged."""
    ds = fds.FederatedDataSource(
        POPULATION_NAME,
        EXAMPLE_SELECTOR,
        task_assignment_mode=_TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
    )
    self.assertEqual(
        ds.task_assignment_mode,
        _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
    )

  def test_federated_type(self):
    """A single selector yields a CLIENTS-placed sequence of strings."""
    ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    self.assertEqual(
        ds.federated_type,
        tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS),
    )

  def test_federated_type_nested(self):
    """Nested selectors yield a matching nested struct of string sequences."""
    nested_example_selector = {
        'a': EXAMPLE_SELECTOR,
        'b': EXAMPLE_SELECTOR,
        'c': {
            '1': EXAMPLE_SELECTOR,
            '2': EXAMPLE_SELECTOR,
        },
    }
    ds = fds.FederatedDataSource(POPULATION_NAME, nested_example_selector)
    self.assertEqual(
        ds.federated_type,
        tff.FederatedType(
            tff.StructType([
                ('a', tff.SequenceType(tf.string)),
                ('b', tff.SequenceType(tf.string)),
                (
                    'c',
                    tff.StructType([
                        ('1', tff.SequenceType(tf.string)),
                        ('2', tff.SequenceType(tf.string)),
                    ]),
                ),
            ]),
            tff.CLIENTS,
        ),
    )

  def test_capabilities(self):
    """The data source advertises exactly the SUPPORTS_REUSE capability."""
    ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    self.assertListEqual(
        ds.capabilities, [tff.program.Capability.SUPPORTS_REUSE]
    )

  def test_iterator_federated_type(self):
    """The iterator reports the same federated type as its data source."""
    ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    self.assertEqual(ds.iterator().federated_type, ds.federated_type)

  def test_iterator_select(self):
    """select() returns a config echoing the source and the client count."""
    ds = fds.FederatedDataSource(
        POPULATION_NAME,
        EXAMPLE_SELECTOR,
        task_assignment_mode=_TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
    )
    self.assertEqual(
        ds.iterator().select(10),
        fds.DataSelectionConfig(
            POPULATION_NAME,
            EXAMPLE_SELECTOR,
            _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE,
            10,
        ),
    )

  def test_iterator_select_with_invalid_num_clients(self):
    """select() rejects a missing, negative, or zero `num_clients`."""
    ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR)
    # subTest keeps checking the remaining values even if one case fails,
    # and reports which value triggered the failure.
    for num_clients in (None, -5, 0):
      with self.subTest(num_clients=num_clients):
        with self.assertRaisesRegex(
            ValueError, 'num_clients must be positive'
        ):
          ds.iterator().select(num_clients=num_clients)


if __name__ == '__main__':
  absltest.main()