# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action handlers for file upload and download.

In a production system, downloads would likely be handled by an external
service. Uploads, however, should not be handled by a separate service:
handling them here helps ensure that unaggregated client data is only held
ephemerally.
"""

import contextlib
import http
import threading
from typing import Callable, Iterator, Optional
import uuid

from fcp.demo import http_actions
from fcp.protos.federatedcompute import common_pb2


class DownloadGroup:
  """A group of downloadable files.

  Every file added to the group is exposed under the group's path `prefix`.
  """

  def __init__(self, prefix: str, add_fn: Callable[[str, bytes, str], None]):
    """Initializes the group.

    Args:
      prefix: The path prefix prepended to file names to form full paths.
      add_fn: Callback invoked with (name, data, content_type) to register a
        file.
    """
    self._prefix = prefix
    self._add_fn = add_fn

  @property
  def prefix(self) -> str:
    """The path prefix for all files in this group."""
    return self._prefix

  def add(self,
          name: str,
          data: bytes,
          content_type: str = 'application/octet-stream') -> str:
    """Adds a file to the group.

    Args:
      name: The name of the new file.
      data: The bytes to make available.
      content_type: The content type to include in the response.

    Returns:
      The full path to the new file.

    Raises:
      KeyError: If a file with that name has already been registered, or (for
        groups created by `Service.create_download_group`) if the group has
        already been closed.
    """
    self._add_fn(name, data, content_type)
    return self._prefix + name


class Service:
  """Implements a service for uploading and downloading data over HTTP."""

  def __init__(self, forwarding_info: Callable[[], common_pb2.ForwardingInfo]):
    """Initializes the service.

    Args:
      forwarding_info: Callable returning the ForwardingInfo whose
        `target_uri_prefix` is used to build download paths.
    """
    self._forwarding_info = forwarding_info
    # Guards both `_downloads` and `_uploads`.
    self._lock = threading.Lock()
    # Download group id -> file name -> prebuilt response for that file.
    self._downloads: dict[str, dict[str, http_actions.HttpResponse]] = {}
    # Upload resource name -> uploaded bytes (None until data is received).
    self._uploads: dict[str, Optional[bytes]] = {}

  @contextlib.contextmanager
  def create_download_group(self) -> Iterator[DownloadGroup]:
    """Creates a new group of downloadable files.

    Files can be added to this group using `DownloadGroup.add`. All files in
    the group will be unregistered when the context manager exits.

    Yields:
      The download group to which files should be added.
    """
    group = str(uuid.uuid4())

    def add_file(name: str, data: bytes, content_type: str) -> None:
      with self._lock:
        files = self._downloads.get(group)
        if files is None:
          # The context manager has already exited. Report that explicitly
          # instead of leaking a bare KeyError holding the opaque group id.
          raise KeyError(f'download group {group} is no longer active')
        if name in files:
          raise KeyError(f'{name} already exists')
        files[name] = http_actions.HttpResponse(
            body=data,
            headers={
                'Content-Length': len(data),
                'Content-Type': content_type,
            })

    with self._lock:
      self._downloads[group] = {}
    try:
      yield DownloadGroup(
          f'{self._forwarding_info().target_uri_prefix}data/{group}/', add_file)
    finally:
      # Runs even if the body (or `forwarding_info`) raises, so the group's
      # files never outlive the context.
      with self._lock:
        del self._downloads[group]

  def register_upload(self) -> str:
    """Registers a path for single-use upload, returning the resource name."""
    name = str(uuid.uuid4())
    with self._lock:
      self._uploads[name] = None
    return name

  def finalize_upload(self, name: str) -> Optional[bytes]:
    """Returns the data from an upload, if any, and unregisters the upload.

    Args:
      name: A resource name previously returned by `register_upload`.

    Returns:
      The uploaded bytes, or None if no data was uploaded.

    Raises:
      KeyError: If `name` is not a registered upload (including an upload that
        has already been finalized).
    """
    with self._lock:
      return self._uploads.pop(name)

  @http_actions.http_action(method='GET', pattern='/data/{group}/{name}')
  def download(self, body: bytes, group: str,
               name: str) -> http_actions.HttpResponse:
    """Handles a download request.

    Args:
      body: The request body (ignored).
      group: The download group id from the request path.
      name: The file name from the request path.

    Returns:
      The response registered for the file.

    Raises:
      http_actions.HttpError: NOT_FOUND if the group or file does not exist.
    """
    del body
    try:
      with self._lock:
        return self._downloads[group][name]
    except KeyError as e:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e

  @http_actions.http_action(
      method='POST', pattern='/upload/v1/media/{name}?upload_protocol=raw')
  def upload(self, body: bytes, name: str) -> http_actions.HttpResponse:
    """Handles an upload request; each registered path accepts data once.

    Args:
      body: The uploaded bytes.
      name: The upload resource name from the request path.

    Returns:
      An empty response on success.

    Raises:
      http_actions.HttpError: UNAUTHORIZED if `name` was never registered, has
        already been finalized, or has already received data.
    """
    with self._lock:
      if name not in self._uploads or self._uploads[name] is not None:
        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED)
      self._uploads[name] = body
    return http_actions.HttpResponse(b'')