xref: /aosp_15_r20/external/federated-compute/fcp/demo/http_actions_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for http_actions."""
15
16import gzip
17import http
18import http.client
19import http.server
20import socket
21import threading
22from unittest import mock
23
24from absl.testing import absltest
25
26from fcp.demo import http_actions
27from fcp.protos.federatedcompute import common_pb2
28from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
29
30
class TestService:
  """Fake service whose decorated handlers forward to recording mocks.

  Tests configure `proto_action`, `get_action` and `post_action` (return
  values or side effects) and afterwards assert on how they were called.
  """

  def __init__(self):
    # One recording mock per decorated handler defined below.
    self.proto_action, self.get_action, self.post_action = (
        mock.Mock(), mock.Mock(), mock.Mock())

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.EligibilityEvalTasks',
      method='RequestEligibilityEvalTask')
  def handle_proto_action(self, *args, **kwargs):
    """Forwards the RequestEligibilityEvalTask call to `self.proto_action`."""
    delegate = self.proto_action
    return delegate(*args, **kwargs)

  @http_actions.http_action(method='get', pattern='/get/{arg1}/{arg2}')
  def handle_get_action(self, *args, **kwargs):
    """Forwards GET /get/{arg1}/{arg2} to `self.get_action`."""
    delegate = self.get_action
    return delegate(*args, **kwargs)

  @http_actions.http_action(method='post', pattern='/post/{arg1}/{arg2}')
  def handle_post_action(self, *args, **kwargs):
    """Forwards POST /post/{arg1}/{arg2} to `self.post_action`."""
    delegate = self.post_action
    return delegate(*args, **kwargs)
51
52
class TestHttpServer(http.server.HTTPServer):
  """Plain `HTTPServer` subclass used as the test fixture's server type."""
55
56
class HttpActionsTest(absltest.TestCase):
  """End-to-end tests for handlers built by `http_actions.create_handler`.

  Each test sends a real HTTP request to a `TestHttpServer` that runs on a
  background thread, bound to an OS-assigned localhost port. The
  `TestService` mocks record how the handler dispatched the request.
  """

  def setUp(self):
    super().setUp()
    self.service = TestService()
    handler = http_actions.create_handler(self.service)
    # Port 0 lets the OS pick a free port.
    self._httpd = TestHttpServer(('localhost', 0), handler)
    self._server_thread = threading.Thread(
        target=self._httpd.serve_forever, daemon=True)
    self._server_thread.start()
    # Connect to the exact address the listening socket is bound to.
    # `server_name` is `socket.getfqdn(host)`, which on some machines is a
    # name that does not resolve back to the bound loopback address.
    host, port = self._httpd.server_address[:2]
    self.conn = http.client.HTTPConnection(host, port=port)

  def tearDown(self):
    # Close the client connection; previously it was never closed, leaking
    # one socket per test.
    self.conn.close()
    self._httpd.shutdown()
    self._server_thread.join()
    self._httpd.server_close()
    super().tearDown()

  def test_not_found(self):
    """A path matching no registered action yields 404."""
    self.conn.request('GET', '/no-match')
    self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.NOT_FOUND)

  def test_proto_success(self):
    """A proto request is parsed, dispatched, and its response serialized."""
    expected_response = (
        eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse(
            session_id='test'))
    self.service.proto_action.return_value = expected_response

    request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
        client_version=common_pb2.ClientVersion(version_code='test123'))
    self.conn.request(
        'POST',
        '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto',
        request.SerializeToString())
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.OK)
    response_proto = (
        eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse.FromString(
            response.read()))
    self.assertEqual(response_proto, expected_response)
    # `population_name` should be set from the URL.
    request.population_name = 'test/population'
    self.service.proto_action.assert_called_once_with(request)

  def test_proto_error(self):
    """An HttpError raised by a proto action sets the response status."""
    self.service.proto_action.side_effect = http_actions.HttpError(
        code=http.HTTPStatus.UNAUTHORIZED)

    self.conn.request(
        'POST',
        '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto', b'')
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.UNAUTHORIZED)

  def test_proto_with_invalid_payload(self):
    """A body that is not a valid serialized proto yields 400."""
    self.conn.request(
        'POST',
        '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto',
        b'invalid')
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.BAD_REQUEST)

  def test_proto_with_gzip_encoding(self):
    """A gzip-encoded proto body is decompressed before parsing."""
    self.service.proto_action.return_value = (
        eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse())

    request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
        client_version=common_pb2.ClientVersion(version_code='test123'))
    self.conn.request('POST',
                      '/v1/eligibilityevaltasks/test:request?%24alt=proto',
                      gzip.compress(request.SerializeToString()),
                      {'Content-Encoding': 'gzip'})
    self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.OK)
    request.population_name = 'test'
    self.service.proto_action.assert_called_once_with(request)

  def test_proto_with_invalid_gzip_encoding(self):
    """A body labelled gzip that fails to decompress yields 400."""
    self.conn.request('POST',
                      '/v1/eligibilityevaltasks/test:request?%24alt=proto',
                      b'invalid', {'Content-Encoding': 'gzip'})
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.BAD_REQUEST)

  def test_proto_with_unsupport_encoding(self):
    """An unsupported Content-Encoding yields 400 without dispatching."""
    self.conn.request('POST',
                      '/v1/eligibilityevaltasks/test:request?%24alt=proto', b'',
                      {'Content-Encoding': 'compress'})
    self.assertEqual(self.conn.getresponse().status,
                     http.HTTPStatus.BAD_REQUEST)
    self.service.proto_action.assert_not_called()

  def test_get_success(self):
    """GET dispatch passes URL args and returns body and headers."""
    self.service.get_action.return_value = http_actions.HttpResponse(
        body=b'body',
        headers={
            'Content-Length': 4,
            'Content-Type': 'application/x-test',
        })

    self.conn.request('GET', '/get/foo/bar')
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.OK)
    self.assertEqual(response.headers['Content-Length'], '4')
    self.assertEqual(response.headers['Content-Type'], 'application/x-test')
    self.assertEqual(response.read(), b'body')
    self.service.get_action.assert_called_once_with(b'', arg1='foo', arg2='bar')

  def test_get_error(self):
    """An HttpError raised by a GET action sets the response status."""
    self.service.get_action.side_effect = http_actions.HttpError(
        code=http.HTTPStatus.UNAUTHORIZED)

    self.conn.request('GET', '/get/foo/bar')
    self.assertEqual(self.conn.getresponse().status,
                     http.HTTPStatus.UNAUTHORIZED)

  def test_post_success(self):
    """POST dispatch passes the body and URL args, returns the response."""
    self.service.post_action.return_value = http_actions.HttpResponse(
        body=b'body',
        headers={
            'Content-Length': 4,
            'Content-Type': 'application/x-test',
        })

    self.conn.request('POST', '/post/foo/bar', b'request-body')
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.OK)
    self.assertEqual(response.headers['Content-Length'], '4')
    self.assertEqual(response.headers['Content-Type'], 'application/x-test')
    self.assertEqual(response.read(), b'body')
    self.service.post_action.assert_called_once_with(
        b'request-body', arg1='foo', arg2='bar')

  def test_post_error(self):
    """An HttpError raised by a POST action sets the response status."""
    self.service.post_action.side_effect = http_actions.HttpError(
        code=http.HTTPStatus.UNAUTHORIZED)

    self.conn.request('POST', '/post/foo/bar', b'request-body')
    self.assertEqual(self.conn.getresponse().status,
                     http.HTTPStatus.UNAUTHORIZED)

  def test_post_with_gzip_encoding(self):
    """A gzip-encoded POST body is decompressed before dispatch."""
    self.service.post_action.return_value = http_actions.HttpResponse(body=b'')

    self.conn.request('POST', '/post/foo/bar', gzip.compress(b'request-body'),
                      {'Content-Encoding': 'gzip'})
    self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.OK)
    self.service.post_action.assert_called_once_with(
        b'request-body', arg1='foo', arg2='bar')

  def test_post_with_unsupport_encoding(self):
    """An unsupported Content-Encoding on POST yields 400, no dispatch."""
    self.conn.request('POST', '/post/foo/bar', b'',
                      {'Content-Encoding': 'compress'})
    self.assertEqual(self.conn.getresponse().status,
                     http.HTTPStatus.BAD_REQUEST)
    self.service.post_action.assert_not_called()
213
214
if __name__ == '__main__':
  # Run every TestCase in this module through absl's test runner.
  absltest.main()
217