# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for federated_context.""" import http import http.client import socket import threading import unittest from unittest import mock from absl.testing import absltest import attr import tensorflow as tf import tensorflow_federated as tff from fcp.artifact_building import artifact_constants from fcp.artifact_building import federated_compute_plan_builder from fcp.artifact_building import plan_utils from fcp.artifact_building import variable_helpers from fcp.demo import federated_computation from fcp.demo import federated_context from fcp.demo import federated_data_source from fcp.demo import server from fcp.demo import test_utils from fcp.protos import plan_pb2 ADDRESS_FAMILY = socket.AddressFamily.AF_INET POPULATION_NAME = 'test/population' DATA_SOURCE = federated_data_source.FederatedDataSource( POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/test')) @tff.tf_computation(tf.int32) def add_one(x): return x + 1 @tff.federated_computation( tff.type_at_server(tf.int32), tff.type_at_clients(tff.SequenceType(tf.string))) def count_clients(state, client_data): """Example TFF computation that counts clients.""" del client_data num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS)) non_state = tff.federated_value((), tff.SERVER) return state + num_clients, non_state @tff.federated_computation( tff.type_at_server(tff.StructType([('foo', tf.int32), ('bar', tf.int32)])), tff.type_at_clients(tff.SequenceType(tf.string)), ) def irregular_arrays(state, client_data): """Example TFF computation that returns irregular data.""" del client_data num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS)) non_state = tff.federated_value(1, tff.SERVER) return state, non_state + num_clients @attr.s(eq=False, frozen=True, slots=True) class TestClass: """An attrs class.""" field_one = attr.ib() field_two = attr.ib() @tff.tf_computation def init(): return TestClass(field_one=1, field_two=2) attrs_type = init.type_signature.result @tff.federated_computation( tff.type_at_server(attrs_type), tff.type_at_clients(tff.SequenceType(tf.string)), ) def attrs_computation(state, client_data): """Example TFF computation that returns an attrs class.""" del client_data num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS)) non_state = tff.federated_value(1, tff.SERVER) return state, non_state + num_clients def build_result_checkpoint(state: int) -> bytes: """Helper function to build a result checkpoint for `count_clients`.""" var_names = variable_helpers.variable_names_from_type( count_clients.type_signature.result[0], name=artifact_constants.SERVER_STATE_VAR_PREFIX) return test_utils.create_checkpoint({var_names[0]: state}) class FederatedContextTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): def test_invalid_population_name(self): with self.assertRaisesRegex(ValueError, 'population_name must match ".+"'): federated_context.FederatedContext( '^^invalid^^', address_family=ADDRESS_FAMILY) @mock.patch.object(server.InProcessServer, 'shutdown', autospec=True) @mock.patch.object(server.InProcessServer, 'serve_forever', autospec=True) def test_context_management(self, serve_forever, shutdown): started = threading.Event() serve_forever.side_effect = lambda *args, **kwargs: started.set() ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) self.assertFalse(started.is_set()) shutdown.assert_not_called() with ctx: self.assertTrue(started.wait(0.5)) shutdown.assert_not_called() shutdown.assert_called_once() def test_http(self): with federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) as ctx: conn = http.client.HTTPConnection('localhost', port=ctx.server_port) conn.request('GET', '/does-not-exist') self.assertEqual(conn.getresponse().status, http.HTTPStatus.NOT_FOUND) def test_invoke_non_federated_with_base_context(self): base_context = tff.backends.native.create_sync_local_cpp_execution_context() ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY, base_context=base_context) with tff.framework.get_context_stack().install(ctx): self.assertEqual(add_one(3), 4) def test_invoke_non_federated_without_base_context(self): ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) with tff.framework.get_context_stack().install(ctx): with self.assertRaisesRegex(TypeError, 'computation must be a FederatedComputation'): add_one(3) def test_invoke_with_invalid_state_type(self): comp = federated_computation.FederatedComputation(count_clients, name='x') ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) with tff.framework.get_context_stack().install(ctx): with self.assertRaisesRegex( TypeError, r'arg\[0\] must be a value or structure of values' ): comp(plan_pb2.Plan(), DATA_SOURCE.iterator().select(1)) def test_invoke_with_invalid_data_source_type(self): comp = federated_computation.FederatedComputation(count_clients, name='x') ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) with tff.framework.get_context_stack().install(ctx): with self.assertRaisesRegex( TypeError, r'arg\[1\] must be the result of ' r'FederatedDataSource.iterator\(\).select\(\)'): comp(0, plan_pb2.Plan()) def test_invoke_succeeds_with_structure_state_type(self): comp = federated_computation.FederatedComputation( irregular_arrays, name='x' ) ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY ) with tff.framework.get_context_stack().install(ctx): state = {'foo': (3, 1), 'bar': (4, 5, 6)} comp(state, DATA_SOURCE.iterator().select(1)) def test_invoke_succeeds_with_attrs_state_type(self): comp = federated_computation.FederatedComputation( attrs_computation, name='x' ) ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY ) with tff.framework.get_context_stack().install(ctx): state = TestClass(field_one=1, field_two=2) comp(state, DATA_SOURCE.iterator().select(1)) def test_invoke_with_mismatched_population_names(self): comp = federated_computation.FederatedComputation(count_clients, name='x') ds = federated_data_source.FederatedDataSource('other/name', DATA_SOURCE.example_selector) ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) with tff.framework.get_context_stack().install(ctx): with self.assertRaisesRegex( ValueError, 'FederatedDataSource and FederatedContext ' 'population_names must match'): comp(0, ds.iterator().select(1)) @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True) async def test_invoke_success(self, run_computation): run_computation.return_value = build_result_checkpoint(7) comp = federated_computation.FederatedComputation(count_clients, name='x') ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) release_manager = tff.program.MemoryReleaseManager() with tff.framework.get_context_stack().install(ctx): state, _ = comp(3, DATA_SOURCE.iterator().select(10)) await release_manager.release( state, tff.type_at_server(tf.int32), key='result') self.assertEqual(release_manager.values()['result'][0], 7) run_computation.assert_called_once_with( mock.ANY, comp.name, mock.ANY, mock.ANY, DATA_SOURCE.task_assignment_mode, 10, ) plan = run_computation.call_args.args[2] self.assertIsInstance(plan, plan_pb2.Plan) self.assertNotEmpty(plan.client_tflite_graph_bytes) input_var_names = variable_helpers.variable_names_from_type( count_clients.type_signature.parameter[0], name=artifact_constants.SERVER_STATE_VAR_PREFIX) self.assertLen(input_var_names, 1) self.assertEqual( test_utils.read_tensor_from_checkpoint( run_computation.call_args.args[3], input_var_names[0], tf.int32), 3) @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True) async def test_invoke_with_value_reference(self, run_computation): run_computation.side_effect = [ build_result_checkpoint(1234), build_result_checkpoint(5678) ] comp = federated_computation.FederatedComputation(count_clients, name='x') ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) release_manager = tff.program.MemoryReleaseManager() with tff.framework.get_context_stack().install(ctx): state, _ = comp(3, DATA_SOURCE.iterator().select(10)) state, _ = comp(state, DATA_SOURCE.iterator().select(10)) await release_manager.release( state, tff.type_at_server(tf.int32), key='result') self.assertEqual(release_manager.values()['result'][0], 5678) input_var_names = variable_helpers.variable_names_from_type( count_clients.type_signature.parameter[0], name=artifact_constants.SERVER_STATE_VAR_PREFIX) self.assertLen(input_var_names, 1) # The second invocation should be passed the value returned by the first # invocation. self.assertEqual(run_computation.call_count, 2) self.assertEqual( test_utils.read_tensor_from_checkpoint( run_computation.call_args.args[3], input_var_names[0], tf.int32), 1234) async def test_invoke_without_input_state(self): comp = federated_computation.FederatedComputation(count_clients, name='x') ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) with tff.framework.get_context_stack().install(ctx): with self.assertRaisesRegex( TypeError, r'arg\[0\] must be a value or structure of values' ): comp(None, DATA_SOURCE.iterator().select(1)) @mock.patch.object(server.InProcessServer, 'run_computation', autospec=True) async def test_invoke_with_run_computation_error(self, run_computation): run_computation.side_effect = ValueError('message') comp = federated_computation.FederatedComputation(count_clients, name='x') ctx = federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY) release_manager = tff.program.MemoryReleaseManager() with tff.framework.get_context_stack().install(ctx): state, _ = comp(0, DATA_SOURCE.iterator().select(10)) with self.assertRaisesRegex(ValueError, 'message'): await release_manager.release( state, tff.type_at_server(tf.int32), key='result') class FederatedContextPlanCachingTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): await super().asyncSetUp() @tff.federated_computation( tff.type_at_server(tf.int32), tff.type_at_clients(tff.SequenceType(tf.string))) def identity(state, client_data): del client_data return state, tff.federated_value((), tff.SERVER) self.count_clients_comp1 = federated_computation.FederatedComputation( count_clients, name='count_clients1') self.count_clients_comp2 = federated_computation.FederatedComputation( count_clients, name='count_clients2') self.identity_comp = federated_computation.FederatedComputation( identity, name='identity') self.data_source1 = federated_data_source.FederatedDataSource( POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/1')) self.data_source2 = federated_data_source.FederatedDataSource( POPULATION_NAME, plan_pb2.ExampleSelector(collection_uri='app:/2')) self.run_computation = self.enter_context( mock.patch.object( server.InProcessServer, 'run_computation', autospec=True)) self.run_computation.return_value = build_result_checkpoint(0) self.build_plan = self.enter_context( mock.patch.object( federated_compute_plan_builder, 'build_plan', autospec=True)) self.build_plan.return_value = plan_pb2.Plan() self.generate_and_add_flat_buffer_to_plan = self.enter_context( mock.patch.object( plan_utils, 'generate_and_add_flat_buffer_to_plan', autospec=True)) self.generate_and_add_flat_buffer_to_plan.side_effect = lambda plan: plan self.enter_context(tff.framework.get_context_stack().install( federated_context.FederatedContext( POPULATION_NAME, address_family=ADDRESS_FAMILY))) self.release_manager = tff.program.MemoryReleaseManager() # Run (and therefore cache) count_clients_comp1 with data_source1. await self.release_manager.release( self.count_clients_comp1(0, self.data_source1.iterator().select(1)), self.count_clients_comp1.type_signature.result, key='result') self.build_plan.assert_called_once() self.assertEqual(self.build_plan.call_args.args[0], self.count_clients_comp1.map_reduce_form) self.assertEqual( self.build_plan.call_args.args[1], self.count_clients_comp1.distribute_aggregate_form, ) self.assertEqual( self.build_plan.call_args.args[2].example_selector_proto, self.data_source1.example_selector, ) self.run_computation.assert_called_once() self.build_plan.reset_mock() self.run_computation.reset_mock() async def test_reuse_with_repeat_computation(self): await self.release_manager.release( self.count_clients_comp1(0, self.data_source1.iterator().select(1)), self.count_clients_comp1.type_signature.result, key='result') self.build_plan.assert_not_called() self.run_computation.assert_called_once() async def test_reuse_with_changed_num_clients(self): await self.release_manager.release( self.count_clients_comp1(0, self.data_source1.iterator().select(10)), self.count_clients_comp1.type_signature.result, key='result') self.build_plan.assert_not_called() self.run_computation.assert_called_once() async def test_reuse_with_changed_initial_state(self): await self.release_manager.release( self.count_clients_comp1(3, self.data_source1.iterator().select(1)), self.count_clients_comp1.type_signature.result, key='result') self.build_plan.assert_not_called() self.run_computation.assert_called_once() async def test_reuse_with_equivalent_map_reduce_form(self): await self.release_manager.release( self.count_clients_comp2(0, self.data_source1.iterator().select(1)), self.count_clients_comp2.type_signature.result, key='result') self.build_plan.assert_not_called() self.run_computation.assert_called_once() async def test_rebuild_with_different_computation(self): await self.release_manager.release( self.identity_comp(0, self.data_source1.iterator().select(1)), self.identity_comp.type_signature.result, key='result') self.build_plan.assert_called_once() self.assertEqual(self.build_plan.call_args.args[0], self.identity_comp.map_reduce_form) self.assertEqual( self.build_plan.call_args.args[1], self.identity_comp.distribute_aggregate_form, ) self.assertEqual( self.build_plan.call_args.args[2].example_selector_proto, self.data_source1.example_selector, ) self.run_computation.assert_called_once() async def test_rebuild_with_different_data_source(self): await self.release_manager.release( self.count_clients_comp1(0, self.data_source2.iterator().select(1)), self.count_clients_comp1.type_signature.result, key='result') self.build_plan.assert_called_once() self.assertEqual(self.build_plan.call_args.args[0], self.count_clients_comp1.map_reduce_form) self.assertEqual( self.build_plan.call_args.args[1], self.count_clients_comp1.distribute_aggregate_form, ) self.assertEqual( self.build_plan.call_args.args[2].example_selector_proto, self.data_source2.example_selector, ) self.run_computation.assert_called_once() if __name__ == '__main__': absltest.main()