xref: /aosp_15_r20/external/federated-compute/fcp/demo/server.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""An in-process federated compute server."""
15
16import contextlib
17import gzip
18import http.server
19import socket
20import socketserver
21import ssl
22from typing import Optional
23
24from absl import logging
25
26from fcp.demo import aggregations
27from fcp.demo import eligibility_eval_tasks
28from fcp.demo import http_actions
29from fcp.demo import media
30from fcp.demo import plan_utils
31from fcp.demo import task_assignments
32from fcp.protos import plan_pb2
33from fcp.protos.federatedcompute import common_pb2
34from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
35
# Shorthand alias for the deeply nested TaskAssignmentMode enum used in
# `run_computation` signatures below.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)

# Template for file name for federated select slices. See
# `FederatedSelectUriInfo.uri_template` for the meaning of the "{served_at_id}"
# and "{key_base10}" substrings.
_FEDERATED_SELECT_NAME_TEMPLATE = '{served_at_id}_{key_base10}'

# Content type used for serialized and compressed Plan messages.
_PLAN_CONTENT_TYPE = 'application/x-protobuf+gzip'

# Content type used for serialized and compressed TensorFlow checkpoints.
_CHECKPOINT_CONTENT_TYPE = 'application/octet-stream+gzip'
50
51
class InProcessServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
  """An in-process HTTP server implementing the Federated Compute protocol.

  Requests are handled on separate threads (via `ThreadingMixIn`) so that
  client traffic can be served while `run_computation` awaits aggregation
  results. The server wires together the demo media, aggregations,
  task-assignments, and eligibility-eval-tasks services and routes HTTP
  requests to them through a handler built by `http_actions.create_handler`.
  """

  def __init__(self,
               *,
               population_name: str,
               host: str,
               port: int,
               address_family: Optional[socket.AddressFamily] = None):
    """Initializes the server and its component services.

    Args:
      population_name: Name of the population this server serves; passed to
        the task-assignments and eligibility-eval-tasks services.
      host: Host name or address to bind the HTTP server to.
      port: Port to bind to (0 lets the OS pick a free port).
      address_family: Optional socket address family (e.g. `socket.AF_INET6`);
        when set, overrides `HTTPServer`'s default before binding.
    """
    # Services are constructed before the HTTPServer is bound; they receive
    # `self._get_forwarding_info` as a callable so URLs they hand out always
    # reflect the server's actual bound name/port at request time.
    self._media_service = media.Service(self._get_forwarding_info)
    self._aggregations_service = aggregations.Service(self._get_forwarding_info,
                                                      self._media_service)
    self._task_assignments_service = task_assignments.Service(
        population_name, self._get_forwarding_info, self._aggregations_service)
    self._eligibility_eval_tasks_service = eligibility_eval_tasks.Service(
        population_name, self._get_forwarding_info
    )
    handler = http_actions.create_handler(
        self._media_service,
        self._aggregations_service,
        self._task_assignments_service,
        self._eligibility_eval_tasks_service,
    )
    # `address_family` must be set on the class/instance before binding, which
    # happens inside HTTPServer.__init__.
    if address_family is not None:
      self.address_family = address_family
    http.server.HTTPServer.__init__(self, (host, port), handler)

  async def run_computation(
      self,
      task_name: str,
      plan: plan_pb2.Plan,
      server_checkpoint: bytes,
      task_assignment_mode: _TaskAssignmentMode,
      number_of_clients: int,
  ) -> bytes:
    """Runs a computation, returning the resulting checkpoint.

    If there's already a computation in progress, the new computation will
    not start until the previous one has completed (either successfully or not).

    Args:
      task_name: The name of the task.
      plan: The Plan proto containing the client and server computations.
      server_checkpoint: The starting server checkpoint.
      task_assignment_mode: The task assignment mode to use for the computation.
      number_of_clients: The minimum number of clients to include.

    Returns:
      A TensorFlow checkpoint containing the aggregated results.

    Raises:
      ValueError: If the aggregation session fails or is aborted before it
        can be completed with the requested number of clients.
    """
    requirements = aggregations.AggregationRequirements(
        minimum_clients_in_server_published_aggregate=number_of_clients,
        plan=plan)
    session_id = self._aggregations_service.create_session(requirements)
    with contextlib.ExitStack() as stack:
      # Ensure the aggregation session is aborted if anything below fails;
      # the callback is detached via `stack.pop_all()` on the success path.
      stack.callback(
          lambda: self._aggregations_service.abort_session(session_id))
      with plan_utils.Session(plan, server_checkpoint) as session:
        with self._media_service.create_download_group() as group:
          # Serve the (gzip-compressed) client plan and initial client
          # checkpoint so clients can download them over HTTP.
          plan_url = group.add(
              'plan',
              gzip.compress(session.client_plan),
              content_type=_PLAN_CONTENT_TYPE,
          )
          checkpoint_url = group.add(
              'checkpoint',
              gzip.compress(session.client_checkpoint),
              content_type=_CHECKPOINT_CONTENT_TYPE,
          )
          # Publish federated select slices; names follow the URI template
          # the clients will expand (slice index serves as `key_base10`).
          for served_at_id, slices in session.slices.items():
            for i, slice_data in enumerate(slices):
              group.add(
                  _FEDERATED_SELECT_NAME_TEMPLATE.format(
                      served_at_id=served_at_id, key_base10=str(i)
                  ),
                  gzip.compress(slice_data),
                  content_type=_CHECKPOINT_CONTENT_TYPE,
              )
          # Only advertise the task once all of its resources are available
          # for download.
          self._eligibility_eval_tasks_service.add_task(
              task_name, task_assignment_mode
          )
          self._task_assignments_service.add_task(
              task_name,
              task_assignment_mode,
              session_id,
              common_pb2.Resource(uri=plan_url),
              common_pb2.Resource(uri=checkpoint_url),
              group.prefix + _FEDERATED_SELECT_NAME_TEMPLATE,
          )
          try:
            # Block until enough client updates have been aggregated (or the
            # session stops being PENDING for some other reason).
            status = await self._aggregations_service.wait(
                session_id,
                num_inputs_aggregated_and_included=number_of_clients)
            # NOTE(review): any state other than PENDING here is treated as
            # failure — presumably FAILED/ABORTED; confirm against
            # aggregations.AggregationStatus.
            if status.status != aggregations.AggregationStatus.PENDING:
              raise ValueError('Aggregation failed.')
          finally:
            # Stop advertising the task whether or not the wait succeeded.
            self._task_assignments_service.remove_task(session_id)
            self._eligibility_eval_tasks_service.remove_task(task_name)

          # Success path: detach the abort callback so completing the session
          # below isn't followed by an abort on exiting the ExitStack.
          stack.pop_all()
          status, intermedia_update = (
              self._aggregations_service.complete_session(session_id))
          if (status.status != aggregations.AggregationStatus.COMPLETED or
              intermedia_update is None):
            raise ValueError('Aggregation failed.')
        logging.debug('%s aggregation complete: %s', task_name, status)
        # Run the server-side finalization over the aggregated update and
        # return the resulting checkpoint bytes.
        return session.finalize(intermedia_update)

  def _get_forwarding_info(self) -> common_pb2.ForwardingInfo:
    """Returns forwarding info pointing at this server's bound address.

    The scheme is https only when the underlying socket has been wrapped in
    TLS (i.e. is an `ssl.SSLSocket`); otherwise plain http.
    """
    protocol = 'https' if isinstance(self.socket, ssl.SSLSocket) else 'http'
    return common_pb2.ForwardingInfo(
        target_uri_prefix=(
            f'{protocol}://{self.server_name}:{self.server_port}/'))
165