xref: /aosp_15_r20/external/federated-compute/fcp/demo/federated_computation_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for federated_computation."""
15
16from unittest import mock
17
18from absl.testing import absltest
19import tensorflow as tf
20import tensorflow_federated as tff
21
22from fcp.demo import federated_computation as fc
23
24
@tff.tf_computation(tf.int32, tf.int32)
def add_values(x, y):
  """Returns the sum of the two int32 arguments `x` and `y`."""
  total = x + y
  return total
28
29
@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def count_clients(state, client_data):
  """Example TFF computation that counts clients.

  Each client contributes the constant 1, so the federated sum at the server
  equals the number of clients; that sum is added to `state` via
  `add_values`. The client datasets are ignored and the returned metrics are
  an empty structure placed at the server.
  """
  del client_data  # Only the number of clients matters, not their data.
  one_per_client = tff.federated_value(1, tff.CLIENTS)
  num_clients = tff.federated_sum(one_per_client)
  empty_metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
  updated_state = tff.federated_map(add_values, (state, num_clients))
  return updated_state, empty_metrics
40
41
@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def count_examples(state, client_data):
  """Example TFF computation that counts client examples.

  Each client counts the elements of its dataset; the per-client counts are
  summed at the server and added to `state` via `add_values`. The returned
  metrics are an empty structure placed at the server.
  """

  @tff.tf_computation
  def client_work(client_data):
    # Fold over the dataset, incrementing the running count once per element.
    return client_data.reduce(0, lambda running_count, _: running_count + 1)

  per_client_counts = tff.federated_map(client_work, client_data)
  total_count = tff.federated_sum(per_client_counts)
  empty_metrics = tff.federated_value(tff.structure.Struct(()), tff.SERVER)
  updated_state = tff.federated_map(add_values, (state, total_count))
  return updated_state, empty_metrics
56
57
class FederatedComputationTest(absltest.TestCase):
  """Tests for `fc.FederatedComputation`."""

  def test_invalid_name(self):
    # A name that does not match the required pattern raises ValueError.
    with self.assertRaisesRegex(ValueError, r'name must match ".+"'):
      fc.FederatedComputation(count_clients, name='^invalid^')

  def test_incompatible_computation(self):
    # This computation returns a single server value, which lacks the return
    # structure required for MRF conversion (compare count_clients, which
    # returns a (state, metrics) pair).
    @tff.federated_computation(tff.type_at_server(tf.int32))
    def add_one(value):
      return value + tff.federated_value(1, tff.SERVER)

    with self.assertRaises(TypeError):
      fc.FederatedComputation(add_one, name='comp')

  @tff.test.with_context(
      tff.backends.test.create_sync_test_cpp_execution_context
  )
  def test_map_reduce_form(self):
    clients_comp = fc.FederatedComputation(count_clients, name='comp1')
    examples_comp = fc.FederatedComputation(count_examples, name='comp2')
    self.assertNotEqual(
        clients_comp.map_reduce_form, examples_comp.map_reduce_form
    )

    # The MRF contents are an implementation detail, so instead check what the
    # equivalent computation returns for 3 clients holding 2 examples each:
    # count_clients yields the client count, count_examples the example count.
    for comp, expected_count in ((clients_comp, 3), (examples_comp, 6)):
      mrf_computation = (
          tff.backends.mapreduce.get_computation_for_map_reduce_form(
              comp.map_reduce_form
          )
      )
      self.assertEqual(
          mrf_computation(0, [['', '']] * 3), (expected_count, ())
      )

  @tff.test.with_context(
      tff.backends.native.create_sync_local_cpp_execution_context
  )
  def test_distribute_aggregate_form(self):
    clients_comp = fc.FederatedComputation(count_clients, name='comp1')
    examples_comp = fc.FederatedComputation(count_examples, name='comp2')
    self.assertNotEqual(
        clients_comp.distribute_aggregate_form,
        examples_comp.distribute_aggregate_form,
    )

    # The DAF contents are an implementation detail, so instead check what the
    # equivalent computation returns for 3 clients holding 2 examples each:
    # count_clients yields the client count, count_examples the example count.
    for comp, expected_count in ((clients_comp, 3), (examples_comp, 6)):
      daf_computation = (
          tff.backends.mapreduce.get_computation_for_distribute_aggregate_form(
              comp.distribute_aggregate_form
          )
      )
      self.assertEqual(
          daf_computation(0, [['', '']] * 3), (expected_count, ())
      )

  def test_wrapped_computation(self):
    wrapper = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(wrapper.wrapped_computation, count_clients)

  def test_name(self):
    wrapper = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(wrapper.name, 'comp')

  def test_type_signature(self):
    wrapper = fc.FederatedComputation(count_clients, name='comp')
    self.assertEqual(wrapper.type_signature, count_clients.type_signature)

  def test_call(self):
    wrapper = fc.FederatedComputation(count_clients, name='comp')
    fake_context = mock.create_autospec(
        tff.program.FederatedContext, instance=True
    )
    fake_context.invoke.return_value = 1234
    # Calling the wrapper delegates to the installed context's `invoke`, with
    # positional and keyword arguments packed into a single Struct.
    with tff.framework.get_context_stack().install(fake_context):
      result = wrapper(1, 2, 3, kw1='a', kw2='b')
    self.assertEqual(result, 1234)
    expected_arg = tff.structure.Struct([
        (None, 1), (None, 2), (None, 3), ('kw1', 'a'), ('kw2', 'b'),
    ])
    fake_context.invoke.assert_called_once_with(wrapper, expected_arg)

  def test_hash(self):
    comp = fc.FederatedComputation(count_clients, name='comp')
    twin = fc.FederatedComputation(count_clients, name='comp')
    # Equivalent objects should have equal hashes.
    self.assertEqual(hash(comp), hash(twin))
    # Different computations or names should produce different hashes.
    differing = (
        fc.FederatedComputation(count_clients, name='other'),  # Other name.
        fc.FederatedComputation(count_examples, name='comp'),  # Other comp.
    )
    for other in differing:
      self.assertNotEqual(hash(comp), hash(other))
155
156
if __name__ == '__main__':
  # Hand control to the absltest runner when this file is executed directly.
  absltest.main()
159