#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The mock module allows easy mocking of apitools clients.

This module allows you to mock out the constructor of a particular apitools
client, for a specific API and version. Then, when the client is created, it
will be run against an expected session that you define. This way code that is
not aware of the testing framework can construct new clients as normal, as long
as it's all done within the context of a mock.
"""

import difflib
import sys

import six

from apitools.base.protorpclite import messages
from apitools.base.py import base_api
from apitools.base.py import encoding
from apitools.base.py import exceptions


# Sentinel distinguishing "attribute absent" from an attribute whose value is
# None when comparing message fields in _MessagesEqual.
_MISSING_FIELD = object()


class Error(Exception):

    """Exceptions for this module."""


def _MessagesEqual(msg1, msg2):
    """Compare two protorpc messages for equality.

    Using python's == operator does not work in all cases, specifically when
    there is a list involved.

    Args:
      msg1: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.
      msg2: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.

    Returns:
      True if the messages are isomorphic, False otherwise. Fields are taken
      from msg1; if msg2 lacks one of them the messages are unequal.
    """
    if isinstance(msg1, list) and isinstance(msg2, list):
        if len(msg1) != len(msg2):
            return False
        return all(_MessagesEqual(x, y) for x, y in zip(msg1, msg2))

    if (not isinstance(msg1, messages.Message) or
            not isinstance(msg2, messages.Message)):
        return msg1 == msg2
    for field in msg1.all_fields():
        field1 = getattr(msg1, field.name)
        # msg2 may be a different message type without this field; report a
        # mismatch rather than letting AttributeError escape to the caller.
        field2 = getattr(msg2, field.name, _MISSING_FIELD)
        if field2 is _MISSING_FIELD or not _MessagesEqual(field1, field2):
            return False
    return True


class UnexpectedRequestException(Error):

    """A mocked method was called in a way that does not match expectations.

    Raised when the received (key, request) pair differs from the next queued
    expectation, or when no expectation is queued at all (in which case the
    expected call is (None, None)).
    """

    def __init__(self, received_call, expected_call):
        expected_key, expected_request = expected_call
        received_key, received_request = received_call

        expected_repr = encoding.MessageToRepr(
            expected_request, multiline=True)
        received_repr = encoding.MessageToRepr(
            received_request, multiline=True)

        if expected_key != received_key:
            msg = '\n'.join((
                'expected: {expected_key}({expected_request})',
                'received: {received_key}({received_request})',
                '',
            )).format(
                expected_key=expected_key,
                expected_request=expected_repr,
                received_key=received_key,
                received_request=received_repr)
        else:
            # Same method, different request: a line diff of the two reprs
            # makes the mismatch much easier to spot.
            diff = '\n'.join(difflib.unified_diff(
                expected_repr.splitlines(), received_repr.splitlines()))
            msg = '\n'.join((
                'for request to {key},',
                'expected: {expected_request}',
                'received: {received_request}',
                'diff: {diff}',
                '',
            )).format(
                key=expected_key,
                expected_request=expected_repr,
                received_request=received_repr,
                diff=diff)
        super(UnexpectedRequestException, self).__init__(msg)


class ExpectedRequestsException(Error):

    """Expected calls were registered but never made.

    Args:
      expected_calls: iterable of (key, request) pairs still outstanding.
    """

    def __init__(self, expected_calls):
        msg = 'expected:\n' + ''.join(
            '{key}({request})\n'.format(
                key=key,
                request=encoding.MessageToRepr(request, multiline=True))
            for (key, request) in expected_calls)
        super(ExpectedRequestsException, self).__init__(msg)


class _ExpectedRequestResponse(object):

    """Encapsulation of an expected request and corresponding response."""

    def __init__(self, key, request, response=None, exception=None):
        """Record one expectation.

        Args:
          key: str, the fully-qualified method key the call must use.
          request: the request the call must carry.
          response: the value to return, or None.
          exception: an exceptions.Error instance to raise, or None.

        Raises:
          exceptions.ConfigurationValueError: if both response and exception
              are given, if response is an Error, or if exception is not one.
        """
        self.__key = key
        self.__request = request

        if response is not None and exception is not None:
            raise exceptions.ConfigurationValueError(
                'Should specify at most one of response and exception')
        if isinstance(response, exceptions.Error):
            raise exceptions.ConfigurationValueError(
                'Responses should not be an instance of Error')
        if (exception is not None and
                not isinstance(exception, exceptions.Error)):
            raise exceptions.ConfigurationValueError(
                'Exceptions must be instances of Error')

        self.__response = response
        self.__exception = exception

    @property
    def key(self):
        return self.__key

    @property
    def request(self):
        return self.__request

    def ValidateAndRespond(self, key, request):
        """Validate that key and request match expectations, and respond if so.

        Args:
          key: str, Actual key to compare against expectations.
          request: protorpc.messages.Message or [protorpc.messages.Message]
            or number or string, Actual request to compare against
            expectations.

        Raises:
          UnexpectedRequestException: If key or request don't match
              expectations.
          apitools_base.Error: If a non-None exception is specified to
              be thrown.

        Returns:
          The response that was specified to be returned.

        """
        if key != self.__key or not (self.__request == request or
                                     _MessagesEqual(request, self.__request)):
            raise UnexpectedRequestException((key, request),
                                             (self.__key, self.__request))

        if self.__exception is not None:
            # Can only throw apitools_base.Error.
            raise self.__exception  # pylint: disable=raising-bad-type

        return self.__response


class _MockedMethod(object):

    """A mocked API service method."""

    def __init__(self, key, mocked_client, real_method):
        self.__name__ = real_method.__name__
        self.__key = key
        self.__mocked_client = mocked_client
        self.__real_method = real_method
        self.method_config = real_method.method_config
        config = self.method_config()
        self.__request_type = getattr(self.__mocked_client.MESSAGES_MODULE,
                                      config.request_type_name)
        self.__response_type = getattr(self.__mocked_client.MESSAGES_MODULE,
                                       config.response_type_name)

    def _TypeCheck(self, msg, is_request):
        """Ensure the given message is of the expected type of this method.

        Args:
          msg: The message instance to check.
          is_request: True to validate against the expected request type,
             False to validate against the expected response type.

        Raises:
          exceptions.ConfigurationValueError: If the type of the message was
              not correct.
        """
        if is_request:
            mode = 'request'
            real_type = self.__request_type
        else:
            mode = 'response'
            real_type = self.__response_type

        if not isinstance(msg, real_type):
            raise exceptions.ConfigurationValueError(
                'Expected {} is not of the correct type for method [{}].\n'
                '   Required: [{}]\n'
                '   Given:    [{}]'.format(
                    mode, self.__key, real_type, type(msg)))

    def Expect(self, request, response=None, exception=None,
               enable_type_checking=True, **unused_kwargs):
        """Add an expectation on the mocked method.

        At most one of response and exception should be specified. If neither
        is given, the matching call is forwarded to the real service method
        when one is available.

        Args:
          request: The request that should be expected
          response: The response that should be returned or None if
              exception is provided.
          exception: An exception that should be thrown, or None.
          enable_type_checking: When true, the message type of the request
              and response (if provided) will be checked against the types
              required by this method.
        """
        # TODO(jasmuth): the unused_kwargs provides a placeholder for
        # future things that can be passed to Expect(), like special
        # params to the method call.

        # Ensure that the registered request and response mocks actually
        # match what this method accepts and returns.
        if enable_type_checking:
            self._TypeCheck(request, is_request=True)
            if response is not None:
                self._TypeCheck(response, is_request=False)

        # pylint: disable=protected-access
        # Class in same module.
        self.__mocked_client._request_responses.append(
            _ExpectedRequestResponse(self.__key,
                                     request,
                                     response=response,
                                     exception=exception))
        # pylint: enable=protected-access

    def __call__(self, request, **unused_kwargs):
        """Consume the next expectation and return its response.

        Raises:
          UnexpectedRequestException: if no expectation is queued, or the
              next one does not match this call.
        """
        # TODO(jasmuth): allow the testing code to expect certain
        # values in these currently unused_kwargs, especially the
        # upload parameter used by media-heavy services like bigquery
        # or bigstore.

        # pylint: disable=protected-access
        # Class in same module.
        pending = self.__mocked_client._request_responses
        # pylint: enable=protected-access
        if not pending:
            raise UnexpectedRequestException(
                (self.__key, request), (None, None))
        request_response = pending.pop(0)

        response = request_response.ValidateAndRespond(self.__key, request)

        if response is None and self.__real_method:
            # No canned response: hit the real service and echo the result so
            # it can be pasted into a future expectation.
            response = self.__real_method(request)
            print(encoding.MessageToRepr(
                response, multiline=True, shortstrings=True))

        return response


def _MakeMockedService(api_name, collection_name,
                       mock_client, service, real_service):
    """Build a BaseApiService subclass whose methods are _MockedMethods.

    Each method name from service.GetMethodsList() is keyed as
    '<api_name>.<collection_name>.<method>'. When real_service is given, the
    corresponding bound method is used as the fallback real method.
    """
    class MockedService(base_api.BaseApiService):
        pass

    for method in service.GetMethodsList():
        real_method = None
        if real_service:
            real_method = getattr(real_service, method)
        setattr(MockedService,
                method,
                _MockedMethod(api_name + '.' + collection_name + '.' + method,
                              mock_client,
                              real_method))
    return MockedService


class Client(object):

    """Mock an apitools client."""

    def __init__(self, client_class, real_client=None):
        """Mock an apitools API, given its class.

        Args:
          client_class: The class for the API. eg, if you
                from apis.sqladmin import v1beta3
              then you can pass v1beta3.SqladminV1beta3 to this class
              and anything within its context will use your mocked
              version.
          real_client: apitools Client, The client to make requests
              against when the expected response is None. If omitted, one
              is constructed as client_class(get_credentials=False).

        """

        if not real_client:
            real_client = client_class(get_credentials=False)

        self.__orig_class = self.__class__
        self.__client_class = client_class
        self.__real_service_classes = {}
        self.__real_client = real_client

        self._request_responses = []
        self.__real_include_fields = None

    def __enter__(self):
        return self.Mock()

    def Mock(self):
        """Stub out the client class with mocked services."""
        client = self.__real_client or self.__client_class(
            get_credentials=False)

        class Patched(self.__class__, self.__client_class):
            pass
        self.__class__ = Patched

        # The key prefix depends only on the client class, not on the service.
        # pylint: disable=protected-access
        api_name = '%s_%s' % (self.__client_class._PACKAGE,
                              self.__client_class._URL_VERSION)
        # pylint: enable=protected-access

        for name in dir(self.__client_class):
            service_class = getattr(self.__client_class, name)
            if not isinstance(service_class, type):
                continue
            if not issubclass(service_class, base_api.BaseApiService):
                continue
            self.__real_service_classes[name] = service_class
            # pylint: disable=protected-access
            collection_name = service_class._NAME
            # pylint: enable=protected-access
            mocked_service_class = _MakeMockedService(
                api_name, collection_name, self,
                service_class,
                service_class(client) if self.__real_client else None)

            setattr(self.__client_class, name, mocked_service_class)

            setattr(self, collection_name, mocked_service_class(self))

        self.__real_include_fields = self.__client_class.IncludeFields
        self.__client_class.IncludeFields = self.IncludeFields

        # pylint: disable=attribute-defined-outside-init
        self._url = client._url
        self._http = client._http

        return self

    def __exit__(self, exc_type, value, traceback):
        is_active_exception = value is not None
        # Don't let "unconsumed expectations" mask an exception already in
        # flight from the with-body.
        self.Unmock(suppress=is_active_exception)
        if is_active_exception:
            six.reraise(exc_type, value, traceback)
        # Only reached when no exception is active, so nothing is suppressed.
        return True

    def Unmock(self, suppress=False):
        """Restore the real client class and check all expectations were met.

        Args:
          suppress: bool, if True, do not raise for leftover expectations.

        Raises:
          ExpectedRequestsException: if expectations remain unconsumed and
              neither suppress nor an in-flight exception applies.
        """
        self.__class__ = self.__orig_class
        for name, service_class in self.__real_service_classes.items():
            setattr(self.__client_class, name, service_class)
            delattr(self, service_class._NAME)
        self.__real_service_classes = {}
        del self._url
        del self._http

        self.__client_class.IncludeFields = self.__real_include_fields
        self.__real_include_fields = None

        requests = [(rq_rs.key, rq_rs.request)
                    for rq_rs in self._request_responses]
        self._request_responses = []

        if requests and not suppress and sys.exc_info()[1] is None:
            raise ExpectedRequestsException(requests)

    def IncludeFields(self, include_fields):
        """Delegate IncludeFields to the real client's original method."""
        if self.__real_client:
            return self.__real_include_fields(self.__real_client,
                                              include_fields)