1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import json
17import os
18import pickle
19import sys
20
21import mock
22import pytest
23
24from google.auth import _helpers
25from google.auth import exceptions
26from google.oauth2 import _credentials_async as _credentials_async
27from google.oauth2 import credentials
28from tests.oauth2 import test_credentials
29
30
class TestCredentials:
    """Tests for ``_credentials_async.Credentials`` (async OAuth2 user creds)."""

    TOKEN_URI = "https://example.com/oauth2/token"
    REFRESH_TOKEN = "refresh_token"
    CLIENT_ID = "client_id"
    CLIENT_SECRET = "client_secret"

    @classmethod
    def make_credentials(cls):
        """Build credentials with no access token and reauth refresh enabled."""
        return _credentials_async.Credentials(
            token=None,
            refresh_token=cls.REFRESH_TOKEN,
            token_uri=cls.TOKEN_URI,
            client_id=cls.CLIENT_ID,
            client_secret=cls.CLIENT_SECRET,
            enable_reauth_refresh=True,
        )

    def test_default_state(self):
        """Freshly built credentials are not yet valid and expose their fields."""
        # Named ``creds`` so the imported ``credentials`` module is not shadowed.
        creds = self.make_credentials()
        assert not creds.valid
        # Expiration hasn't been set yet
        assert not creds.expired
        # Scopes aren't required for these credentials
        assert not creds.requires_scopes
        # Test properties
        assert creds.refresh_token == self.REFRESH_TOKEN
        assert creds.token_uri == self.TOKEN_URI
        assert creds.client_id == self.CLIENT_ID
        assert creds.client_secret == self.CLIENT_SECRET

    @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True)
    @mock.patch(
        "google.auth._helpers.utcnow",
        return_value=datetime.datetime.min + _helpers.REFRESH_THRESHOLD,
    )
    @pytest.mark.asyncio
    async def test_refresh_success(self, unused_utcnow, refresh_grant):
        """A successful refresh stores token, expiry, id_token and rapt_token."""
        token = "token"
        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        grant_response = {"id_token": mock.sentinel.id_token}
        rapt_token = "rapt_token"
        refresh_grant.return_value = (
            # Access token
            token,
            # New refresh token
            None,
            # Expiry,
            expiry,
            # Extra data
            grant_response,
            # Rapt token
            rapt_token,
        )

        request = mock.AsyncMock(spec=["transport.Request"])
        creds = self.make_credentials()

        # Refresh credentials
        await creds.refresh(request)

        # Check the refresh grant call: no scopes, no prior rapt token, and
        # reauth refresh enabled (as configured by make_credentials).
        refresh_grant.assert_called_with(
            request,
            self.TOKEN_URI,
            self.REFRESH_TOKEN,
            self.CLIENT_ID,
            self.CLIENT_SECRET,
            None,
            None,
            True,
        )

        # Check that the credentials have the token and expiry
        assert creds.token == token
        assert creds.expiry == expiry
        assert creds.id_token == mock.sentinel.id_token
        assert creds.rapt_token == rapt_token

        # Check that the credentials are valid (have a token and are not
        # expired)
        assert creds.valid

    @pytest.mark.asyncio
    async def test_refresh_no_refresh_token(self):
        """Refreshing without a refresh token fails before any request is made."""
        request = mock.AsyncMock(spec=["transport.Request"])
        credentials_ = _credentials_async.Credentials(token=None, refresh_token=None)

        with pytest.raises(exceptions.RefreshError, match="necessary fields"):
            await credentials_.refresh(request)

        request.assert_not_called()

    @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True)
    @mock.patch(
        "google.auth._helpers.utcnow",
        return_value=datetime.datetime.min + _helpers.REFRESH_THRESHOLD,
    )
    @pytest.mark.asyncio
    async def test_credentials_with_scopes_requested_refresh_success(
        self, unused_utcnow, refresh_grant
    ):
        """Requested scopes and the existing rapt token are sent on refresh."""
        scopes = ["email", "profile"]
        token = "token"
        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        grant_response = {"id_token": mock.sentinel.id_token}
        rapt_token = "rapt_token"
        refresh_grant.return_value = (
            # Access token
            token,
            # New refresh token
            None,
            # Expiry,
            expiry,
            # Extra data
            grant_response,
            # Rapt token
            rapt_token,
        )

        request = mock.AsyncMock(spec=["transport.Request"])
        creds = _credentials_async.Credentials(
            token=None,
            refresh_token=self.REFRESH_TOKEN,
            token_uri=self.TOKEN_URI,
            client_id=self.CLIENT_ID,
            client_secret=self.CLIENT_SECRET,
            scopes=scopes,
            rapt_token="old_rapt_token",
        )

        # Refresh credentials
        await creds.refresh(request)

        # Check the refresh grant call: scopes and the old rapt token are
        # forwarded; reauth refresh defaults to disabled.
        refresh_grant.assert_called_with(
            request,
            self.TOKEN_URI,
            self.REFRESH_TOKEN,
            self.CLIENT_ID,
            self.CLIENT_SECRET,
            scopes,
            "old_rapt_token",
            False,
        )

        # Check that the credentials have the token and expiry
        assert creds.token == token
        assert creds.expiry == expiry
        assert creds.id_token == mock.sentinel.id_token
        assert creds.has_scopes(scopes)
        assert creds.rapt_token == rapt_token

        # Check that the credentials are valid (have a token and are not
        # expired.)
        assert creds.valid

    @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True)
    @mock.patch(
        "google.auth._helpers.utcnow",
        return_value=datetime.datetime.min + _helpers.REFRESH_THRESHOLD,
    )
    @pytest.mark.asyncio
    async def test_credentials_with_scopes_returned_refresh_success(
        self, unused_utcnow, refresh_grant
    ):
        """Refresh succeeds when the server grants exactly the requested scopes."""
        scopes = ["email", "profile"]
        token = "token"
        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        grant_response = {"id_token": mock.sentinel.id_token, "scope": " ".join(scopes)}
        rapt_token = "rapt_token"
        refresh_grant.return_value = (
            # Access token
            token,
            # New refresh token
            None,
            # Expiry,
            expiry,
            # Extra data
            grant_response,
            # Rapt token
            rapt_token,
        )

        request = mock.AsyncMock(spec=["transport.Request"])
        creds = _credentials_async.Credentials(
            token=None,
            refresh_token=self.REFRESH_TOKEN,
            token_uri=self.TOKEN_URI,
            client_id=self.CLIENT_ID,
            client_secret=self.CLIENT_SECRET,
            scopes=scopes,
        )

        # Refresh credentials
        await creds.refresh(request)

        # Check the refresh grant call.
        refresh_grant.assert_called_with(
            request,
            self.TOKEN_URI,
            self.REFRESH_TOKEN,
            self.CLIENT_ID,
            self.CLIENT_SECRET,
            scopes,
            None,
            False,
        )

        # Check that the credentials have the token and expiry
        assert creds.token == token
        assert creds.expiry == expiry
        assert creds.id_token == mock.sentinel.id_token
        assert creds.has_scopes(scopes)
        assert creds.rapt_token == rapt_token

        # Check that the credentials are valid (have a token and are not
        # expired.)
        assert creds.valid

    @mock.patch("google.oauth2._reauth_async.refresh_grant", autospec=True)
    @mock.patch(
        "google.auth._helpers.utcnow",
        return_value=datetime.datetime.min + _helpers.REFRESH_THRESHOLD,
    )
    @pytest.mark.asyncio
    async def test_credentials_with_scopes_refresh_failure_raises_refresh_error(
        self, unused_utcnow, refresh_grant
    ):
        """A partial scope grant raises RefreshError but still stores the token."""
        scopes = ["email", "profile"]
        scopes_returned = ["email"]
        token = "token"
        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        grant_response = {
            "id_token": mock.sentinel.id_token,
            "scope": " ".join(scopes_returned),
        }
        rapt_token = "rapt_token"
        refresh_grant.return_value = (
            # Access token
            token,
            # New refresh token
            None,
            # Expiry,
            expiry,
            # Extra data
            grant_response,
            # Rapt token
            rapt_token,
        )

        request = mock.AsyncMock(spec=["transport.Request"])
        creds = _credentials_async.Credentials(
            token=None,
            refresh_token=self.REFRESH_TOKEN,
            token_uri=self.TOKEN_URI,
            client_id=self.CLIENT_ID,
            client_secret=self.CLIENT_SECRET,
            scopes=scopes,
            rapt_token=None,
        )

        # Refresh credentials
        with pytest.raises(
            exceptions.RefreshError, match="Not all requested scopes were granted"
        ):
            await creds.refresh(request)

        # Check the refresh grant call.
        refresh_grant.assert_called_with(
            request,
            self.TOKEN_URI,
            self.REFRESH_TOKEN,
            self.CLIENT_ID,
            self.CLIENT_SECRET,
            scopes,
            None,
            False,
        )

        # The token and expiry are still stored even though the error was
        # raised afterwards.
        assert creds.token == token
        assert creds.expiry == expiry
        assert creds.id_token == mock.sentinel.id_token
        assert creds.has_scopes(scopes)

        # Check that the credentials are valid (have a token and are not
        # expired.)
        assert creds.valid

    def test_apply_with_quota_project_id(self):
        """A configured quota project is sent as the x-goog-user-project header."""
        creds = _credentials_async.Credentials(
            token="token",
            refresh_token=self.REFRESH_TOKEN,
            token_uri=self.TOKEN_URI,
            client_id=self.CLIENT_ID,
            client_secret=self.CLIENT_SECRET,
            quota_project_id="quota-project-123",
        )

        headers = {}
        creds.apply(headers)
        assert headers["x-goog-user-project"] == "quota-project-123"

    def test_apply_with_no_quota_project_id(self):
        """Without a quota project, no x-goog-user-project header is added."""
        creds = _credentials_async.Credentials(
            token="token",
            refresh_token=self.REFRESH_TOKEN,
            token_uri=self.TOKEN_URI,
            client_id=self.CLIENT_ID,
            client_secret=self.CLIENT_SECRET,
        )

        headers = {}
        creds.apply(headers)
        assert "x-goog-user-project" not in headers

    def test_with_quota_project(self):
        """with_quota_project returns new creds that send the new quota project."""
        creds = _credentials_async.Credentials(
            token="token",
            refresh_token=self.REFRESH_TOKEN,
            token_uri=self.TOKEN_URI,
            client_id=self.CLIENT_ID,
            client_secret=self.CLIENT_SECRET,
            quota_project_id="quota-project-123",
        )

        new_creds = creds.with_quota_project("new-project-456")
        assert new_creds.quota_project_id == "new-project-456"
        # The original credentials must be left untouched.
        assert creds.quota_project_id == "quota-project-123"

        # Apply the *new* credentials so the header really reflects the new
        # quota project (applying ``creds`` would only re-test the old one).
        headers = {}
        new_creds.apply(headers)
        assert headers["x-goog-user-project"] == "new-project-456"

    def test_from_authorized_user_info(self):
        """Credentials load from an authorized-user info dict, with optional scopes."""
        info = test_credentials.AUTH_USER_INFO.copy()

        creds = _credentials_async.Credentials.from_authorized_user_info(info)
        assert creds.client_secret == info["client_secret"]
        assert creds.client_id == info["client_id"]
        assert creds.refresh_token == info["refresh_token"]
        assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        assert creds.scopes is None

        scopes = ["email", "profile"]
        creds = _credentials_async.Credentials.from_authorized_user_info(info, scopes)
        assert creds.client_secret == info["client_secret"]
        assert creds.client_id == info["client_id"]
        assert creds.refresh_token == info["refresh_token"]
        assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        assert creds.scopes == scopes

    def test_from_authorized_user_file(self):
        """Credentials load from an authorized-user JSON file, with optional scopes."""
        info = test_credentials.AUTH_USER_INFO.copy()

        creds = _credentials_async.Credentials.from_authorized_user_file(
            test_credentials.AUTH_USER_JSON_FILE
        )
        assert creds.client_secret == info["client_secret"]
        assert creds.client_id == info["client_id"]
        assert creds.refresh_token == info["refresh_token"]
        assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        assert creds.scopes is None

        scopes = ["email", "profile"]
        creds = _credentials_async.Credentials.from_authorized_user_file(
            test_credentials.AUTH_USER_JSON_FILE, scopes
        )
        assert creds.client_secret == info["client_secret"]
        assert creds.client_id == info["client_id"]
        assert creds.refresh_token == info["refresh_token"]
        assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        assert creds.scopes == scopes

    def test_to_json(self):
        """to_json serializes the fields and honours the ``strip`` argument."""
        info = test_credentials.AUTH_USER_INFO.copy()
        creds = _credentials_async.Credentials.from_authorized_user_info(info)

        # Test with no `strip` arg
        json_output = creds.to_json()
        json_asdict = json.loads(json_output)
        assert json_asdict.get("token") == creds.token
        assert json_asdict.get("refresh_token") == creds.refresh_token
        assert json_asdict.get("token_uri") == creds.token_uri
        assert json_asdict.get("client_id") == creds.client_id
        assert json_asdict.get("scopes") == creds.scopes
        assert json_asdict.get("client_secret") == creds.client_secret

        # Test with a `strip` arg
        json_output = creds.to_json(strip=["client_secret"])
        json_asdict = json.loads(json_output)
        assert json_asdict.get("token") == creds.token
        assert json_asdict.get("refresh_token") == creds.refresh_token
        assert json_asdict.get("token_uri") == creds.token_uri
        assert json_asdict.get("client_id") == creds.client_id
        assert json_asdict.get("scopes") == creds.scopes
        assert json_asdict.get("client_secret") is None

    def test_pickle_and_unpickle(self):
        """Pickling round-trips every instance attribute."""
        creds = self.make_credentials()
        unpickled = pickle.loads(pickle.dumps(creds))

        # Make sure attributes aren't lost during pickling. Compare the key
        # sets directly: ``list.sort()`` returns None, so comparing two
        # ``.sort()`` results would always pass.
        assert set(creds.__dict__) == set(unpickled.__dict__)

        for attr in creds.__dict__:
            assert getattr(creds, attr) == getattr(unpickled, attr)

    def test_pickle_with_missing_attribute(self):
        """Unpickling fills in attributes absent from older pickles."""
        creds = self.make_credentials()

        # remove an optional attribute before pickling
        # this mimics a pickle created with a previous class definition with
        # fewer attributes
        del creds.__dict__["_quota_project_id"]

        unpickled = pickle.loads(pickle.dumps(creds))

        # Attribute should be initialized by `__setstate__`
        assert unpickled.quota_project_id is None

    # pickles are not compatible across versions
    @pytest.mark.skipif(
        sys.version_info < (3, 5),
        reason="pickle file can only be loaded with Python >= 3.5",
    )
    def test_unpickle_old_credentials_pickle(self):
        """A pickle produced by google-auth==1.5.1 can still be loaded."""
        # make sure a credentials file pickled with an older
        # library version (google-auth==1.5.1) can be unpickled
        with open(
            os.path.join(test_credentials.DATA_DIR, "old_oauth_credentials_py3.pickle"),
            "rb",
        ) as f:
            # Named ``creds`` so the imported ``credentials`` module is not
            # shadowed.
            creds = pickle.load(f)
            assert creds.quota_project_id is None
465
466
class TestUserAccessTokenCredentials:
    """Tests for ``_credentials_async.UserAccessTokenCredentials``."""

    def test_instance(self):
        """No account is set by default; with_account sets one."""
        user_creds = _credentials_async.UserAccessTokenCredentials()
        assert user_creds._account is None

        with_acct = user_creds.with_account("account")
        assert with_acct._account == "account"

    @mock.patch("google.auth._cloud_sdk.get_auth_access_token", autospec=True)
    def test_refresh(self, mock_get_token):
        """refresh stores the token returned by the Cloud SDK helper."""
        mock_get_token.return_value = "access_token"

        user_creds = _credentials_async.UserAccessTokenCredentials()
        user_creds.refresh(None)

        assert user_creds.token == "access_token"

    def test_with_quota_project(self):
        """with_quota_project sets the quota project and keeps the account."""
        base_creds = _credentials_async.UserAccessTokenCredentials()
        derived = base_creds.with_quota_project("project-foo")

        assert derived._quota_project_id == "project-foo"
        assert derived._account == base_creds._account

    @mock.patch(
        "google.oauth2._credentials_async.UserAccessTokenCredentials.apply",
        autospec=True,
    )
    @mock.patch(
        "google.oauth2._credentials_async.UserAccessTokenCredentials.refresh",
        autospec=True,
    )
    def test_before_request(self, mock_refresh, mock_apply):
        """before_request calls both refresh and apply."""
        # The bottom-most patch (refresh) is passed first, then apply.
        user_creds = _credentials_async.UserAccessTokenCredentials()
        fake_request = mock.Mock()

        user_creds.before_request(fake_request, "GET", "https://example.com", {})

        mock_refresh.assert_called()
        mock_apply.assert_called()
502