# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for federated_computation."""

from unittest import mock

from absl.testing import absltest
import tensorflow as tf
import tensorflow_federated as tff

from fcp.demo import federated_computation as fc


@tff.tf_computation(tf.int32, tf.int32)
def add_values(x, y):
  """TF computation returning the sum of two int32 values."""
  total = x + y
  return total


@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def count_clients(state, client_data):
  """Example TFF computation that counts clients.

  Every client contributes the constant 1. The contributions are summed at
  the server, and the sum is added to `state`. The client data itself is
  never read.

  Args:
    state: Server-placed int32 running total.
    client_data: Client-placed string sequences (ignored).

  Returns:
    A `(new_state, metrics)` pair, where `metrics` is an empty server-placed
    struct.
  """
  del client_data  # Only the number of participating clients matters.
  ones_at_clients = tff.federated_value(1, tff.CLIENTS)
  num_clients = tff.federated_sum(ones_at_clients)
  empty_metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
  new_state = tff.federated_map(add_values, (state, num_clients))
  return new_state, empty_metrics


@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def count_examples(state, client_data):
  """Example TFF computation that counts client examples.

  Each client counts the elements of its own sequence. The per-client counts
  are summed at the server, and the total is added to `state`.

  Args:
    state: Server-placed int32 running total.
    client_data: Client-placed string sequences whose elements are counted.

  Returns:
    A `(new_state, metrics)` pair, where `metrics` is an empty server-placed
    struct.
  """

  @tff.tf_computation
  def count_elements(dataset):
    # Fold over the sequence, bumping the accumulator once per element.
    return dataset.reduce(0, lambda running_count, _: running_count + 1)

  per_client_counts = tff.federated_map(count_elements, client_data)
  total_count = tff.federated_sum(per_client_counts)
  empty_metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
  new_state = tff.federated_map(add_values, (state, total_count))
  return new_state, empty_metrics
class FederatedComputationTest(absltest.TestCase):
  """Unit tests for `fc.FederatedComputation`."""

  def _assert_result_on_three_clients(self, build_computation, form,
                                      expected):
    """Converts `form` to a computation, invokes it, and checks the result.

    The computation is called with an initial server state of 0 and three
    clients, each holding a two-element sequence of empty strings.

    Args:
      build_computation: Function turning `form` into a callable computation.
      form: The form object passed to `build_computation`.
      expected: The value the invocation must return.
    """
    three_clients = [['', '']] * 3
    self.assertEqual(build_computation(form)(0, three_clients), expected)

  def test_invalid_name(self):
    expected_message = r'name must match ".+"'
    with self.assertRaisesRegex(ValueError, expected_message):
      fc.FederatedComputation(count_clients, name='^invalid^')

  def test_incompatible_computation(self):
    # `increment` returns a single server value instead of the
    # (state, metrics) pair that MRF conversion requires.
    @tff.federated_computation(tff.type_at_server(tf.int32))
    def increment(value):
      return value + tff.federated_value(1, tff.SERVER)

    with self.assertRaises(TypeError):
      fc.FederatedComputation(increment, name='comp')

  @tff.test.with_context(
      tff.backends.test.create_sync_test_cpp_execution_context
  )
  def test_map_reduce_form(self):
    clients_comp = fc.FederatedComputation(count_clients, name='comp1')
    examples_comp = fc.FederatedComputation(count_examples, name='comp2')
    self.assertNotEqual(
        clients_comp.map_reduce_form, examples_comp.map_reduce_form
    )

    # The MRF contents are an implementation detail, so instead of inspecting
    # them we check what the equivalent computation returns when invoked.
    to_computation = tff.backends.mapreduce.get_computation_for_map_reduce_form
    # count_clients yields the number of clients.
    self._assert_result_on_three_clients(
        to_computation, clients_comp.map_reduce_form, (3, ()))
    # count_examples yields the number of examples across all clients.
    self._assert_result_on_three_clients(
        to_computation, examples_comp.map_reduce_form, (6, ()))

  @tff.test.with_context(
      tff.backends.native.create_sync_local_cpp_execution_context
  )
  def test_distribute_aggregate_form(self):
    clients_comp = fc.FederatedComputation(count_clients, name='comp1')
    examples_comp = fc.FederatedComputation(count_examples, name='comp2')
    self.assertNotEqual(
        clients_comp.distribute_aggregate_form,
        examples_comp.distribute_aggregate_form,
    )

    # As with MRF, the DAF contents are an implementation detail; verify the
    # behavior of the equivalent computation instead.
    to_computation = (
        tff.backends.mapreduce.get_computation_for_distribute_aggregate_form)
    # count_clients yields the number of clients.
    self._assert_result_on_three_clients(
        to_computation, clients_comp.distribute_aggregate_form, (3, ()))
    # count_examples yields the number of examples across all clients.
    self._assert_result_on_three_clients(
        to_computation, examples_comp.distribute_aggregate_form, (6, ()))

  def test_wrapped_computation(self):
    federated_comp = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(federated_comp.wrapped_computation, count_clients)

  def test_name(self):
    federated_comp = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(federated_comp.name, 'comp')

  def test_type_signature(self):
    federated_comp = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(
        federated_comp.type_signature, count_clients.type_signature)

  def test_call(self):
    comp = fc.FederatedComputation(count_clients, name='comp')
    fake_context = mock.create_autospec(
        tff.program.FederatedContext, instance=True)
    fake_context.invoke.return_value = 1234
    with tff.framework.get_context_stack().install(fake_context):
      result = comp(1, 2, 3, kw1='a', kw2='b')
      self.assertEqual(result, 1234)
    # Positional arguments become unnamed Struct entries and keyword
    # arguments become named ones, all passed as a single argument.
    expected_arg = tff.structure.Struct([
        (None, 1), (None, 2), (None, 3), ('kw1', 'a'), ('kw2', 'b'),
    ])
    fake_context.invoke.assert_called_once_with(comp, expected_arg)

  def test_hash(self):

    def make(computation, name):
      return fc.FederatedComputation(computation, name=name)

    reference = make(count_clients, 'comp')
    # Equivalent objects must hash identically.
    self.assertEqual(hash(reference), hash(make(count_clients, 'comp')))
    # Changing either the name or the wrapped computation changes the hash.
    self.assertNotEqual(hash(reference), hash(make(count_clients, 'other')))
    self.assertNotEqual(hash(reference), hash(make(count_examples, 'comp')))


if __name__ == '__main__':
  absltest.main()