# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for http_actions."""

import gzip
import http
import http.client
import http.server
import socket
import threading
from unittest import mock

from absl.testing import absltest

from fcp.demo import http_actions
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2


class TestService:
  """Fake service whose HTTP action handlers delegate to mocks.

  Each decorated handler forwards its arguments to a `mock.Mock`. Tests can
  stub return values or side effects and inspect how the handler was invoked.
  """

  def __init__(self):
    self.proto_action = mock.Mock()
    self.get_action = mock.Mock()
    self.post_action = mock.Mock()

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.EligibilityEvalTasks',
      method='RequestEligibilityEvalTask')
  def handle_proto_action(self, *args, **kwargs):
    return self.proto_action(*args, **kwargs)

  @http_actions.http_action(method='get', pattern='/get/{arg1}/{arg2}')
  def handle_get_action(self, *args, **kwargs):
    return self.get_action(*args, **kwargs)

  @http_actions.http_action(method='post', pattern='/post/{arg1}/{arg2}')
  def handle_post_action(self, *args, **kwargs):
    return self.post_action(*args, **kwargs)


class TestHttpServer(http.server.HTTPServer):
  """Thin `http.server.HTTPServer` subclass used as the server under test."""


class HttpActionsTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.service = TestService()
    handler = http_actions.create_handler(self.service)
    # Port 0 lets the OS pick a free port.
    self._httpd = TestHttpServer(('localhost', 0), handler)
    self._server_thread = threading.Thread(
        target=self._httpd.serve_forever, daemon=True)
    self._server_thread.start()
    # Connect to the address the listening socket is actually bound to.
    # `server_name` is derived via `socket.getfqdn()`, which may perform slow
    # DNS lookups and can yield a name that does not resolve back to the
    # loopback address on some hosts.
    self.conn = http.client.HTTPConnection(
        self._httpd.server_address[0], port=self._httpd.server_port)

  def tearDown(self):
    # Release the client socket before stopping the server.
    self.conn.close()
    self._httpd.shutdown()
    self._server_thread.join()
    self._httpd.server_close()
    super().tearDown()

  def test_not_found(self):
    self.conn.request('GET', '/no-match')
    self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.NOT_FOUND)

  def test_proto_success(self):
    expected_response = (
        eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse(
            session_id='test'))
    self.service.proto_action.return_value = expected_response

    request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
        client_version=common_pb2.ClientVersion(version_code='test123'))
    # `%2F` is a percent-encoded '/', and `%24alt` is `$alt`.
    self.conn.request(
        'POST',
        '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto',
        request.SerializeToString())
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.OK)
    response_proto = (
        eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse.FromString(
            response.read()))
    self.assertEqual(response_proto, expected_response)
    # `population_name` should be set from the URL.
    request.population_name = 'test/population'
    self.service.proto_action.assert_called_once_with(request)

  def test_proto_error(self):
    self.service.proto_action.side_effect = http_actions.HttpError(
        code=http.HTTPStatus.UNAUTHORIZED)

    self.conn.request(
        'POST',
        '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto', b'')
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.UNAUTHORIZED)

  def test_proto_with_invalid_payload(self):
    self.conn.request(
        'POST',
        '/v1/eligibilityevaltasks/test%2Fpopulation:request?%24alt=proto',
        b'invalid')
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.BAD_REQUEST)
    # An unparseable request must be rejected before reaching the handler.
    self.service.proto_action.assert_not_called()

  def test_proto_with_gzip_encoding(self):
    self.service.proto_action.return_value = (
        eligibility_eval_tasks_pb2.EligibilityEvalTaskResponse())

    request = eligibility_eval_tasks_pb2.EligibilityEvalTaskRequest(
        client_version=common_pb2.ClientVersion(version_code='test123'))
    self.conn.request('POST',
                      '/v1/eligibilityevaltasks/test:request?%24alt=proto',
                      gzip.compress(request.SerializeToString()),
                      {'Content-Encoding': 'gzip'})
    self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.OK)
    request.population_name = 'test'
    self.service.proto_action.assert_called_once_with(request)

  def test_proto_with_invalid_gzip_encoding(self):
    self.conn.request('POST',
                      '/v1/eligibilityevaltasks/test:request?%24alt=proto',
                      b'invalid', {'Content-Encoding': 'gzip'})
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.BAD_REQUEST)
    # A body that fails to decompress must not reach the handler.
    self.service.proto_action.assert_not_called()

  def test_proto_with_unsupport_encoding(self):
    self.conn.request('POST',
                      '/v1/eligibilityevaltasks/test:request?%24alt=proto', b'',
                      {'Content-Encoding': 'compress'})
    self.assertEqual(self.conn.getresponse().status,
                     http.HTTPStatus.BAD_REQUEST)
    self.service.proto_action.assert_not_called()

  def test_get_success(self):
    self.service.get_action.return_value = http_actions.HttpResponse(
        body=b'body',
        headers={
            'Content-Length': 4,
            'Content-Type': 'application/x-test',
        })

    self.conn.request('GET', '/get/foo/bar')
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.OK)
    self.assertEqual(response.headers['Content-Length'], '4')
    self.assertEqual(response.headers['Content-Type'], 'application/x-test')
    self.assertEqual(response.read(), b'body')
    # A GET has no body, so the handler receives b'' plus the URL arguments.
    self.service.get_action.assert_called_once_with(b'', arg1='foo', arg2='bar')

  def test_get_error(self):
    self.service.get_action.side_effect = http_actions.HttpError(
        code=http.HTTPStatus.UNAUTHORIZED)

    self.conn.request('GET', '/get/foo/bar')
    self.assertEqual(self.conn.getresponse().status,
                     http.HTTPStatus.UNAUTHORIZED)

  def test_post_success(self):
    self.service.post_action.return_value = http_actions.HttpResponse(
        body=b'body',
        headers={
            'Content-Length': 4,
            'Content-Type': 'application/x-test',
        })

    self.conn.request('POST', '/post/foo/bar', b'request-body')
    response = self.conn.getresponse()
    self.assertEqual(response.status, http.HTTPStatus.OK)
    self.assertEqual(response.headers['Content-Length'], '4')
    self.assertEqual(response.headers['Content-Type'], 'application/x-test')
    self.assertEqual(response.read(), b'body')
    self.service.post_action.assert_called_once_with(
        b'request-body', arg1='foo', arg2='bar')

  def test_post_error(self):
    self.service.post_action.side_effect = http_actions.HttpError(
        code=http.HTTPStatus.UNAUTHORIZED)

    self.conn.request('POST', '/post/foo/bar', b'request-body')
    self.assertEqual(self.conn.getresponse().status,
                     http.HTTPStatus.UNAUTHORIZED)

  def test_post_with_gzip_encoding(self):
    self.service.post_action.return_value = http_actions.HttpResponse(body=b'')

    self.conn.request('POST', '/post/foo/bar', gzip.compress(b'request-body'),
                      {'Content-Encoding': 'gzip'})
    self.assertEqual(self.conn.getresponse().status, http.HTTPStatus.OK)
    # The handler should see the decompressed body.
    self.service.post_action.assert_called_once_with(
        b'request-body', arg1='foo', arg2='bar')

  def test_post_with_unsupport_encoding(self):
    self.conn.request('POST', '/post/foo/bar', b'',
                      {'Content-Encoding': 'compress'})
    self.assertEqual(self.conn.getresponse().status,
                     http.HTTPStatus.BAD_REQUEST)
    self.service.post_action.assert_not_called()


if __name__ == '__main__':
  absltest.main()