# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An in-process federated compute server."""

import contextlib
import gzip
import http.server
import socket
import socketserver
import ssl
from typing import Optional

from absl import logging

from fcp.demo import aggregations
from fcp.demo import eligibility_eval_tasks
from fcp.demo import http_actions
from fcp.demo import media
from fcp.demo import plan_utils
from fcp.demo import task_assignments
from fcp.protos import plan_pb2
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2

# Shorthand for the deeply nested TaskAssignmentMode enum in the
# PopulationEligibilitySpec proto.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

# Template for file name for federated select slices. See
# `FederatedSelectUriInfo.uri_template` for the meaning of the "{served_at_id}"
# and "{key_base10}" substrings.
_FEDERATED_SELECT_NAME_TEMPLATE = '{served_at_id}_{key_base10}'

# Content type used for serialized and compressed Plan messages.
_PLAN_CONTENT_TYPE = 'application/x-protobuf+gzip'

# Content type used for serialized and compressed TensorFlow checkpoints.
_CHECKPOINT_CONTENT_TYPE = 'application/octet-stream+gzip'


class InProcessServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
  """An in-process HTTP server implementing the Federated Compute protocol.

  Wires together the demo `media`, `aggregations`, `task_assignments`, and
  `eligibility_eval_tasks` services behind a single threaded HTTP server
  (one handler thread per connection, via `ThreadingMixIn`).
  """

  def __init__(self,
               *,
               population_name: str,
               host: str,
               port: int,
               address_family: Optional[socket.AddressFamily] = None):
    # Each service receives `_get_forwarding_info` so that the URLs it hands
    # out point back at this server (scheme/host/port resolved at call time).
    self._media_service = media.Service(self._get_forwarding_info)
    self._aggregations_service = aggregations.Service(self._get_forwarding_info,
                                                      self._media_service)
    self._task_assignments_service = task_assignments.Service(
        population_name, self._get_forwarding_info, self._aggregations_service)
    self._eligibility_eval_tasks_service = eligibility_eval_tasks.Service(
        population_name, self._get_forwarding_info
    )
    # A single request handler dispatches to all four services.
    handler = http_actions.create_handler(
        self._media_service,
        self._aggregations_service,
        self._task_assignments_service,
        self._eligibility_eval_tasks_service,
    )
    # Must be set before HTTPServer.__init__, which creates the socket using
    # the class-level `address_family` attribute.
    if address_family is not None:
      self.address_family = address_family
    http.server.HTTPServer.__init__(self, (host, port), handler)

  async def run_computation(
      self,
      task_name: str,
      plan: plan_pb2.Plan,
      server_checkpoint: bytes,
      task_assignment_mode: _TaskAssignmentMode,
      number_of_clients: int,
  ) -> bytes:
    """Runs a computation, returning the resulting checkpoint.

    If there's already a computation in progress, the new computation will
    not start until the previous one has completed (either successfully or not).

    Args:
      task_name: The name of the task.
      plan: The Plan proto containing the client and server computations.
      server_checkpoint: The starting server checkpoint.
      task_assignment_mode: The task assignment mode to use for the computation.
      number_of_clients: The minimum number of clients to include.

    Returns:
      A TensorFlow checkpoint containing the aggregated results.

    Raises:
      ValueError: If the aggregation fails or does not complete.
    """
    requirements = aggregations.AggregationRequirements(
        minimum_clients_in_server_published_aggregate=number_of_clients,
        plan=plan)
    session_id = self._aggregations_service.create_session(requirements)
    with contextlib.ExitStack() as stack:
      # Ensure the aggregation session is aborted on any failure below; the
      # `stack.pop_all()` on the success path disarms this callback.
      stack.callback(
          lambda: self._aggregations_service.abort_session(session_id))
      with plan_utils.Session(plan, server_checkpoint) as session:
        with self._media_service.create_download_group() as group:
          # Publish the client plan and starting checkpoint (gzip-compressed)
          # so assigned clients can download them.
          plan_url = group.add(
              'plan',
              gzip.compress(session.client_plan),
              content_type=_PLAN_CONTENT_TYPE,
          )
          checkpoint_url = group.add(
              'checkpoint',
              gzip.compress(session.client_checkpoint),
              content_type=_CHECKPOINT_CONTENT_TYPE,
          )
          # Publish federated select slices; the slice index doubles as the
          # base-10 key in the uri_template (see
          # _FEDERATED_SELECT_NAME_TEMPLATE).
          for served_at_id, slices in session.slices.items():
            for i, slice_data in enumerate(slices):
              group.add(
                  _FEDERATED_SELECT_NAME_TEMPLATE.format(
                      served_at_id=served_at_id, key_base10=str(i)
                  ),
                  gzip.compress(slice_data),
                  content_type=_CHECKPOINT_CONTENT_TYPE,
              )
          # Make the task visible to clients: first for eligibility checks,
          # then for actual task assignment.
          self._eligibility_eval_tasks_service.add_task(
              task_name, task_assignment_mode
          )
          self._task_assignments_service.add_task(
              task_name,
              task_assignment_mode,
              session_id,
              common_pb2.Resource(uri=plan_url),
              common_pb2.Resource(uri=checkpoint_url),
              group.prefix + _FEDERATED_SELECT_NAME_TEMPLATE,
          )
          try:
            # Block until enough client inputs have been aggregated.
            status = await self._aggregations_service.wait(
                session_id,
                num_inputs_aggregated_and_included=number_of_clients)
            # NOTE(review): the session is expected to still be PENDING here
            # (it is only completed below); any other status means the
            # aggregation already failed — confirm against the aggregations
            # service's status semantics.
            if status.status != aggregations.AggregationStatus.PENDING:
              raise ValueError('Aggregation failed.')
          finally:
            # Stop serving the task whether or not the wait succeeded, so no
            # further clients are assigned to this session.
            self._task_assignments_service.remove_task(session_id)
            self._eligibility_eval_tasks_service.remove_task(task_name)

          # Success path: disarm the abort callback before completing the
          # session. (`pop_all()` transfers the callback to a new, discarded
          # ExitStack.)
          stack.pop_all()
          # sic: "intermedia_update" — misspelling of "intermediate update";
          # kept as-is since it's a local name only.
          status, intermedia_update = (
              self._aggregations_service.complete_session(session_id))
          if (status.status != aggregations.AggregationStatus.COMPLETED or
              intermedia_update is None):
            raise ValueError('Aggregation failed.')
          logging.debug('%s aggregation complete: %s', task_name, status)
          # Run the server-side finalization over the aggregated update to
          # produce the result checkpoint.
          return session.finalize(intermedia_update)

  def _get_forwarding_info(self) -> common_pb2.ForwardingInfo:
    """Returns ForwardingInfo pointing at this server's address.

    Uses https iff the server socket has been wrapped with SSL.
    """
    protocol = 'https' if isinstance(self.socket, ssl.SSLSocket) else 'http'
    return common_pb2.ForwardingInfo(
        target_uri_prefix=(
            f'{protocol}://{self.server_name}:{self.server_port}/'))