1#
2# Copyright 2015 Google Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""The mock module allows easy mocking of apitools clients.
17
18This module allows you to mock out the constructor of a particular apitools
19client, for a specific API and version. Then, when the client is created, it
20will be run against an expected session that you define. This way code that is
21not aware of the testing framework can construct new clients as normal, as long
22as it's all done within the context of a mock.
23"""
24
25import difflib
26import sys
27
28import six
29
30from apitools.base.protorpclite import messages
31from apitools.base.py import base_api
32from apitools.base.py import encoding
33from apitools.base.py import exceptions
34
35
class Error(Exception):

    """Base class for the exception types defined in this module."""
39
40
def _MessagesEqual(msg1, msg2):
    """Recursively decide whether two request values are equivalent.

    The plain == operator is not relied on for messages because, per the
    original authors, it does not work in all cases when a list is involved.

    Args:
      msg1: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, the first value to compare.
      msg2: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, the second value to compare.

    Returns:
      True if the two values are isomorphic, False otherwise.
    """
    # Two lists match when they have equal length and match pairwise.
    if isinstance(msg1, list) and isinstance(msg2, list):
        if len(msg1) != len(msg2):
            return False
        for left, right in zip(msg1, msg2):
            if not _MessagesEqual(left, right):
                return False
        return True

    # Unless both sides are messages, ordinary equality decides.
    both_are_messages = (isinstance(msg1, messages.Message) and
                         isinstance(msg2, messages.Message))
    if not both_are_messages:
        return msg1 == msg2

    # Walk the fields declared on msg1 and compare each value recursively.
    return all(
        _MessagesEqual(getattr(msg1, field.name), getattr(msg2, field.name))
        for field in msg1.all_fields())
70
71
class UnexpectedRequestException(Error):

    """Raised when a received call does not match the expected call."""

    def __init__(self, received_call, expected_call):
        """Build a message contrasting the received and expected calls.

        Args:
          received_call: (key, request) pair describing the call that was
              actually made.
          expected_call: (key, request) pair describing the call that was
              expected instead.
        """
        expected_key, expected_request = expected_call
        received_key, received_request = received_call

        expected_repr = encoding.MessageToRepr(
            expected_request, multiline=True)
        received_repr = encoding.MessageToRepr(
            received_request, multiline=True)

        if expected_key != received_key:
            # A different method was called; a field-level diff would not
            # be meaningful, so just show both calls.
            msg = '\n'.join((
                'expected: {expected_key}({expected_request})',
                'received: {received_key}({received_request})',
                '',
            )).format(
                expected_key=expected_key,
                expected_request=expected_repr,
                received_key=received_key,
                received_request=received_repr)
        else:
            # The inputs come from splitlines() and therefore carry no
            # trailing newlines. lineterm='' keeps the diff control lines
            # newline-free as well, so joining with '\n' does not produce
            # blank lines after the '---', '+++' and '@@' headers.
            diff = '\n'.join(difflib.unified_diff(
                expected_repr.splitlines(),
                received_repr.splitlines(),
                lineterm=''))
            msg = '\n'.join((
                'for request to {key},',
                'expected: {expected_request}',
                'received: {received_request}',
                'diff: {diff}',
                '',
            )).format(
                key=expected_key,
                expected_request=expected_repr,
                received_request=received_repr,
                diff=diff)
        super(UnexpectedRequestException, self).__init__(msg)
113
114
class ExpectedRequestsException(Error):

    """Raised to report expected calls that were never received."""

    def __init__(self, expected_calls):
        """Build a message listing every outstanding expected call.

        Args:
          expected_calls: iterable of (key, request) pairs that were
              expected but not consumed.
        """
        rendered_calls = [
            '{key}({request})\n'.format(
                key=key,
                request=encoding.MessageToRepr(request, multiline=True))
            for key, request in expected_calls
        ]
        super(ExpectedRequestsException, self).__init__(
            'expected:\n' + ''.join(rendered_calls))
124
125
class _ExpectedRequestResponse(object):

    """A single expected call together with the outcome to produce for it."""

    def __init__(self, key, request, response=None, exception=None):
        """Record an expectation after validating its outcome.

        Args:
          key: str, Identifier of the method expected to be called.
          request: The request the method is expected to receive.
          response: The value to return when the expectation is met, or
              None.
          exception: An exceptions.Error instance to raise when the
              expectation is met, or None.

        Raises:
          exceptions.ConfigurationValueError: If both a response and an
              exception are given, if the response is an Error instance, or
              if the exception is not an Error instance.
        """
        if response:
            if exception:
                raise exceptions.ConfigurationValueError(
                    'Should specify at most one of response and exception')
            if isinstance(response, exceptions.Error):
                raise exceptions.ConfigurationValueError(
                    'Responses should not be an instance of Error')
        elif exception and not isinstance(exception, exceptions.Error):
            raise exceptions.ConfigurationValueError(
                'Exceptions must be instances of Error')

        self.__key = key
        self.__request = request
        self.__response = response
        self.__exception = exception

    @property
    def key(self):
        """The method key this expectation applies to."""
        return self.__key

    @property
    def request(self):
        """The request this expectation is waiting for."""
        return self.__request

    def ValidateAndRespond(self, key, request):
        """Check an actual call against this expectation and answer it.

        Args:
          key: str, Actual key to compare against expectations.
          request: protorpc.messages.Message or [protorpc.messages.Message]
            or number or string, Actual request to compare against
            expectations.

        Raises:
          UnexpectedRequestException: If key or request don't match
              expectations.
          exceptions.Error: If a non-None exception was specified to be
              thrown.

        Returns:
          The response that was specified to be returned.

        """
        if key != self.__key:
            matches = False
        else:
            # Try plain equality first; fall back to the field-by-field
            # comparison only when that fails.
            matches = (self.__request == request or
                       _MessagesEqual(request, self.__request))
        if not matches:
            raise UnexpectedRequestException((key, request),
                                             (self.__key, self.__request))

        if self.__exception:
            # The constructor only accepts exceptions.Error instances here.
            raise self.__exception  # pylint: disable=raising-bad-type

        return self.__response
183
184
class _MockedMethod(object):

    """Callable stand-in for one API service method.

    Expectations are registered through Expect() and consumed, in order,
    each time the instance is called.
    """

    def __init__(self, key, mocked_client, real_method):
        """Wrap a real service method.

        Args:
          key: str, Identifier used when matching calls to expectations.
          mocked_client: The mock client that owns the shared queue of
              expectations (its _request_responses list).
          real_method: The service method being replaced; its name and
              method_config are copied onto this object.
        """
        self.__name__ = real_method.__name__
        self.__key = key
        self.__mocked_client = mocked_client
        self.__real_method = real_method
        self.method_config = real_method.method_config
        # Resolve the request/response message classes by the names the
        # method's config declares.
        method_config = self.method_config()
        messages_module = self.__mocked_client.MESSAGES_MODULE
        self.__request_type = getattr(
            messages_module, method_config.request_type_name)
        self.__response_type = getattr(
            messages_module, method_config.response_type_name)

    def _TypeCheck(self, msg, is_request):
        """Ensure the given message is of the expected type of this method.

        Args:
          msg: The message instance to check.
          is_request: True to validate against the expected request type,
             False to validate against the expected response type.

        Raises:
          exceptions.ConfigurationValueError: If the type of the message was
             not correct.
        """
        mode = 'request' if is_request else 'response'
        real_type = (
            self.__request_type if is_request else self.__response_type)
        if isinstance(msg, real_type):
            return

        raise exceptions.ConfigurationValueError(
            'Expected {} is not of the correct type for method [{}].\n'
            '   Required: [{}]\n'
            '   Given:    [{}]'.format(
                mode, self.__key, real_type, type(msg)))

    def Expect(self, request, response=None, exception=None,
               enable_type_checking=True, **unused_kwargs):
        """Add an expectation on the mocked method.

        At most one of response and exception may be specified.

        Args:
          request: The request that should be expected
          response: The response that should be returned or None if
              exception is provided.
          exception: An exception that should be thrown, or None.
          enable_type_checking: When true, the message type of the request
              and response (if provided) will be checked against the types
              required by this method.
        """
        # TODO(jasmuth): the unused_kwargs provides a placeholder for
        # future things that can be passed to Expect(), like special
        # params to the method call.

        # Reject expectations whose messages could never be accepted or
        # returned by the real method.
        if enable_type_checking:
            self._TypeCheck(request, is_request=True)
            if response:
                self._TypeCheck(response, is_request=False)

        expectation = _ExpectedRequestResponse(self.__key,
                                               request,
                                               response=response,
                                               exception=exception)
        # pylint: disable=protected-access
        # Class in same module.
        self.__mocked_client._request_responses.append(expectation)
        # pylint: enable=protected-access

    def __call__(self, request, **unused_kwargs):
        """Consume the next expectation and answer the call.

        Args:
          request: The request passed by the code under test.

        Raises:
          UnexpectedRequestException: If no expectation is queued, or the
              next queued expectation does not match this call.

        Returns:
          The expected response. When that is None and a real method is
          available, the request is sent to the real method instead; its
          response is printed and returned.
        """
        # TODO(jasmuth): allow the testing code to expect certain
        # values in these currently unused_kwargs, especially the
        # upload parameter used by media-heavy services like bigquery
        # or bigstore.

        # pylint: disable=protected-access
        # Class in same module.
        pending = self.__mocked_client._request_responses
        # pylint: enable=protected-access
        if not pending:
            raise UnexpectedRequestException(
                (self.__key, request), (None, None))
        expectation = pending.pop(0)

        response = expectation.ValidateAndRespond(self.__key, request)
        if response is not None or not self.__real_method:
            return response

        # No canned response: fall through to the real method and echo
        # what it returned.
        response = self.__real_method(request)
        print(encoding.MessageToRepr(
            response, multiline=True, shortstrings=True))
        return response
286
287
def _MakeMockedService(api_name, collection_name,
                       mock_client, service, real_service):
    """Create a service class whose methods are all _MockedMethod objects.

    Args:
      api_name: str, First component of each method key.
      collection_name: str, Second component of each method key.
      mock_client: The mock client handed to every _MockedMethod.
      service: Object whose GetMethodsList() names the methods to mock.
      real_service: Service instance supplying the real methods, or a
          falsy value to mock without real methods.

    Returns:
      A new base_api.BaseApiService subclass carrying the mocked methods.
    """
    class MockedService(base_api.BaseApiService):
        pass

    for method_name in service.GetMethodsList():
        if real_service:
            real_method = getattr(real_service, method_name)
        else:
            real_method = None
        # Keys have the form <api_name>.<collection_name>.<method>.
        method_key = '.'.join((api_name, collection_name, method_name))
        setattr(MockedService,
                method_name,
                _MockedMethod(method_key, mock_client, real_method))
    return MockedService
303
304
class Client(object):

    """Mock an apitools client.

    Usable as a context manager: entering calls Mock(), leaving calls
    Unmock().
    """

    def __init__(self, client_class, real_client=None):
        """Mock an apitools API, given its class.

        Args:
          client_class: The class for the API. eg, if you
                from apis.sqladmin import v1beta3
              then you can pass v1beta3.SqladminV1beta3 to this class
              and anything within its context will use your mocked
              version.
          real_client: apitools Client, The client to make requests
              against when the expected response is None. When falsy, an
              instance of client_class is created with
              get_credentials=False.

        """
        self.__real_client = real_client or client_class(
            get_credentials=False)
        self.__client_class = client_class
        # Remembered so Unmock() can undo the class swap done by Mock().
        self.__orig_class = self.__class__
        self.__real_service_classes = {}
        self.__real_include_fields = None

        # Queue of expectations shared by every mocked method.
        self._request_responses = []

    def __enter__(self):
        """Start mocking and return this mock client."""
        return self.Mock()

    def Mock(self):
        """Stub out the client class with mocked services.

        Returns:
          This object, which now also behaves as an instance of the client
          class.
        """
        client_class = self.__client_class
        client = self.__real_client or client_class(get_credentials=False)

        # Make this object an instance of both its own class and the
        # client class.
        class Patched(self.__class__, client_class):
            pass
        self.__class__ = Patched

        for attr_name in dir(client_class):
            candidate = getattr(client_class, attr_name)
            is_service_class = (
                isinstance(candidate, type) and
                issubclass(candidate, base_api.BaseApiService))
            if not is_service_class:
                continue

            # Keep the real class so Unmock() can put it back.
            self.__real_service_classes[attr_name] = candidate
            # pylint: disable=protected-access
            collection_name = candidate._NAME
            api_name = '%s_%s' % (client_class._PACKAGE,
                                  client_class._URL_VERSION)
            # pylint: enable=protected-access
            if self.__real_client:
                real_service = candidate(client)
            else:
                real_service = None
            mocked_service_class = _MakeMockedService(
                api_name, collection_name, self, candidate, real_service)

            # Replace the service class on the client class, and expose a
            # mocked service instance on this object under the collection
            # name.
            setattr(client_class, attr_name, mocked_service_class)
            setattr(self, collection_name, mocked_service_class(self))

        self.__real_include_fields = client_class.IncludeFields
        client_class.IncludeFields = self.IncludeFields

        # pylint: disable=attribute-defined-outside-init,protected-access
        self._url = client._url
        self._http = client._http
        # pylint: enable=attribute-defined-outside-init,protected-access

        return self

    def __exit__(self, exc_type, value, traceback):
        """Stop mocking, re-raising any exception from the with-body."""
        exception_in_flight = value is not None
        # Leftover expectations are not reported when the body already
        # failed.
        self.Unmock(suppress=exception_in_flight)
        if exception_in_flight:
            six.reraise(exc_type, value, traceback)
        return True

    def Unmock(self, suppress=False):
        """Undo Mock() and verify that every expectation was consumed.

        Args:
          suppress: When true, leftover expectations are discarded without
              raising.

        Raises:
          ExpectedRequestsException: If expectations remain, suppress is
              false and no exception is currently being handled.
        """
        self.__class__ = self.__orig_class
        for attr_name, real_class in self.__real_service_classes.items():
            setattr(self.__client_class, attr_name, real_class)
            # pylint: disable=protected-access
            delattr(self, real_class._NAME)
            # pylint: enable=protected-access
        self.__real_service_classes = {}
        del self._url
        del self._http

        self.__client_class.IncludeFields = self.__real_include_fields
        self.__real_include_fields = None

        # Empty the queue before raising so the mock is left clean.
        leftover = [(expectation.key, expectation.request)
                    for expectation in self._request_responses]
        self._request_responses = []

        if leftover and not suppress and sys.exc_info()[1] is None:
            raise ExpectedRequestsException(leftover)

    def IncludeFields(self, include_fields):
        """Forward IncludeFields to the real client, if there is one.

        Args:
          include_fields: Passed through unchanged to the client class's
              original IncludeFields.

        Returns:
          Whatever the original IncludeFields returns, or None when there
          is no real client.
        """
        if not self.__real_client:
            return None
        return self.__real_include_fields(self.__real_client,
                                          include_fields)
406