1# Copyright 2016 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16import json
17import os
18
19import mock
20
21from google.auth import _helpers
22from google.auth import crypt
23from google.auth import jwt
24from google.auth import transport
25from google.oauth2 import service_account
26
27
# Shared fixtures live in ``../data`` relative to this test module.
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")


def _read_data_file(name):
    """Return the raw bytes of ``name`` from ``DATA_DIR``."""
    path = os.path.join(DATA_DIR, name)
    with open(path, "rb") as data_file:
        return data_file.read()


# PEM material: the private key backing ``SIGNER`` below, the public
# certificate used by the tests to verify its signatures, and a second
# certificate.
PRIVATE_KEY_BYTES = _read_data_file("privatekey.pem")
PUBLIC_CERT_BYTES = _read_data_file("public_cert.pem")
OTHER_CERT_BYTES = _read_data_file("other_cert.pem")

# Parsed service-account key file used by the ``from_service_account_*`` tests.
SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")

with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
    SERVICE_ACCOUNT_INFO = json.load(fh)

# RSA signer over the private key above, with key id "1".
SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1")
45
46
class TestCredentials(object):
    """Tests for ``service_account.Credentials`` (OAuth 2.0 access tokens)."""

    SERVICE_ACCOUNT_EMAIL = "[email protected]"
    TOKEN_URI = "https://example.com/oauth2/token"

    @classmethod
    def make_credentials(cls):
        """Return unscoped credentials built from the module-level ``SIGNER``."""
        return service_account.Credentials(
            SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI
        )

    def test_from_service_account_info(self):
        """Key id, email and token URI are read from the info mapping."""
        credentials = service_account.Credentials.from_service_account_info(
            SERVICE_ACCOUNT_INFO
        )

        assert credentials._signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"]
        assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"]
        assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"]

    def test_from_service_account_info_args(self):
        """Scopes, subject and extra claims are forwarded from keyword args."""
        info = SERVICE_ACCOUNT_INFO.copy()
        scopes = ["email", "profile"]
        subject = "subject"
        additional_claims = {"meta": "data"}

        credentials = service_account.Credentials.from_service_account_info(
            info, scopes=scopes, subject=subject, additional_claims=additional_claims
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials.project_id == info["project_id"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]
        assert credentials._scopes == scopes
        assert credentials._subject == subject
        assert credentials._additional_claims == additional_claims

    def test_from_service_account_file(self):
        """Loading from the JSON key file yields the same fields as the info."""
        info = SERVICE_ACCOUNT_INFO.copy()

        credentials = service_account.Credentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials.project_id == info["project_id"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]

    def test_from_service_account_file_args(self):
        """Keyword args are forwarded when loading from the JSON key file."""
        info = SERVICE_ACCOUNT_INFO.copy()
        scopes = ["email", "profile"]
        subject = "subject"
        additional_claims = {"meta": "data"}

        credentials = service_account.Credentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE,
            subject=subject,
            scopes=scopes,
            additional_claims=additional_claims,
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials.project_id == info["project_id"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]
        assert credentials._scopes == scopes
        assert credentials._subject == subject
        assert credentials._additional_claims == additional_claims

    def test_default_state(self):
        """Fresh credentials are invalid, not expired, and require scopes."""
        credentials = self.make_credentials()
        assert not credentials.valid
        # Expiration hasn't been set yet
        assert not credentials.expired
        # Scopes haven't been specified yet
        assert credentials.requires_scopes

    def test_sign_bytes(self):
        """Signatures verify against the matching public certificate."""
        credentials = self.make_credentials()
        to_sign = b"123"
        signature = credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)

    def test_signer(self):
        """The ``signer`` property exposes a ``crypt.Signer``."""
        credentials = self.make_credentials()
        assert isinstance(credentials.signer, crypt.Signer)

    def test_signer_email(self):
        """``signer_email`` is the service account email."""
        credentials = self.make_credentials()
        assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL

    def test_create_scoped(self):
        """``with_scopes`` stores the requested scopes."""
        credentials = self.make_credentials()
        scopes = ["email", "profile"]
        credentials = credentials.with_scopes(scopes)
        assert credentials._scopes == scopes

    def test_with_claims(self):
        """``with_claims`` stores the additional claims."""
        credentials = self.make_credentials()
        new_credentials = credentials.with_claims({"meep": "moop"})
        assert new_credentials._additional_claims == {"meep": "moop"}

    def test_with_quota_project(self):
        """``with_quota_project`` sets the quota project and its header."""
        credentials = self.make_credentials()
        new_credentials = credentials.with_quota_project("new-project-456")
        assert new_credentials.quota_project_id == "new-project-456"
        hdrs = {}
        new_credentials.apply(hdrs, token="tok")
        assert "x-goog-user-project" in hdrs

    def test__with_always_use_jwt_access(self):
        """``with_always_use_jwt_access(True)`` turns on the flag."""
        credentials = self.make_credentials()
        assert not credentials._always_use_jwt_access

        new_credentials = credentials.with_always_use_jwt_access(True)
        assert new_credentials._always_use_jwt_access

    def test__make_authorization_grant_assertion(self):
        """The grant assertion is issued by the SA for the token endpoint."""
        credentials = self.make_credentials()
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        assert payload["aud"] == service_account._GOOGLE_OAUTH2_TOKEN_ENDPOINT

    def test__make_authorization_grant_assertion_scoped(self):
        """Scopes appear space-separated in the ``scope`` claim."""
        credentials = self.make_credentials()
        scopes = ["email", "profile"]
        credentials = credentials.with_scopes(scopes)
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["scope"] == "email profile"

    def test__make_authorization_grant_assertion_subject(self):
        """The subject appears in the ``sub`` claim."""
        credentials = self.make_credentials()
        subject = "[email protected]"
        credentials = credentials.with_subject(subject)
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["sub"] == subject

    def test_apply_with_quota_project_id(self):
        """``apply`` adds the quota project header when one is configured."""
        credentials = service_account.Credentials(
            SIGNER,
            self.SERVICE_ACCOUNT_EMAIL,
            self.TOKEN_URI,
            quota_project_id="quota-project-123",
        )

        headers = {}
        credentials.apply(headers, token="token")

        assert headers["x-goog-user-project"] == "quota-project-123"
        assert "token" in headers["authorization"]

    def test_apply_with_no_quota_project_id(self):
        """``apply`` omits the quota project header when none is configured."""
        credentials = service_account.Credentials(
            SIGNER, self.SERVICE_ACCOUNT_EMAIL, self.TOKEN_URI
        )

        headers = {}
        credentials.apply(headers, token="token")

        assert "x-goog-user-project" not in headers
        assert "token" in headers["authorization"]

    # The mocked ``google.auth.jwt.Credentials`` class is named
    # ``jwt_credentials`` (not ``jwt``) so it does not shadow the ``jwt``
    # module imported at the top of this file.
    @mock.patch("google.auth.jwt.Credentials", instance=True, autospec=True)
    def test__create_self_signed_jwt(self, jwt_credentials):
        """Without scopes, the self-signed JWT is created for the audience."""
        credentials = service_account.Credentials(
            SIGNER, self.SERVICE_ACCOUNT_EMAIL, self.TOKEN_URI
        )

        audience = "https://pubsub.googleapis.com"
        credentials._create_self_signed_jwt(audience)
        jwt_credentials.from_signing_credentials.assert_called_once_with(
            credentials, audience
        )

    @mock.patch("google.auth.jwt.Credentials", instance=True, autospec=True)
    def test__create_self_signed_jwt_with_user_scopes(self, jwt_credentials):
        """User-supplied scopes suppress self-signed JWT creation."""
        credentials = service_account.Credentials(
            SIGNER, self.SERVICE_ACCOUNT_EMAIL, self.TOKEN_URI, scopes=["foo"]
        )

        audience = "https://pubsub.googleapis.com"
        credentials._create_self_signed_jwt(audience)

        # JWT should not be created if there are user-defined scopes
        jwt_credentials.from_signing_credentials.assert_not_called()

    @mock.patch("google.auth.jwt.Credentials", instance=True, autospec=True)
    def test__create_self_signed_jwt_always_use_jwt_access_with_audience(
        self, jwt_credentials
    ):
        """With only default scopes and an audience, the audience is used."""
        credentials = service_account.Credentials(
            SIGNER,
            self.SERVICE_ACCOUNT_EMAIL,
            self.TOKEN_URI,
            default_scopes=["bar", "foo"],
            always_use_jwt_access=True,
        )

        audience = "https://pubsub.googleapis.com"
        credentials._create_self_signed_jwt(audience)
        jwt_credentials.from_signing_credentials.assert_called_once_with(
            credentials, audience
        )

    @mock.patch("google.auth.jwt.Credentials", instance=True, autospec=True)
    def test__create_self_signed_jwt_always_use_jwt_access_with_scopes(
        self, jwt_credentials
    ):
        """User scopes are sent as a ``scope`` claim with no audience."""
        credentials = service_account.Credentials(
            SIGNER,
            self.SERVICE_ACCOUNT_EMAIL,
            self.TOKEN_URI,
            scopes=["bar", "foo"],
            always_use_jwt_access=True,
        )

        audience = "https://pubsub.googleapis.com"
        credentials._create_self_signed_jwt(audience)
        jwt_credentials.from_signing_credentials.assert_called_once_with(
            credentials, None, additional_claims={"scope": "bar foo"}
        )

    @mock.patch("google.auth.jwt.Credentials", instance=True, autospec=True)
    def test__create_self_signed_jwt_always_use_jwt_access_with_default_scopes(
        self, jwt_credentials
    ):
        """Without an audience, default scopes become the ``scope`` claim."""
        credentials = service_account.Credentials(
            SIGNER,
            self.SERVICE_ACCOUNT_EMAIL,
            self.TOKEN_URI,
            default_scopes=["bar", "foo"],
            always_use_jwt_access=True,
        )

        credentials._create_self_signed_jwt(None)
        jwt_credentials.from_signing_credentials.assert_called_once_with(
            credentials, None, additional_claims={"scope": "bar foo"}
        )

    @mock.patch("google.auth.jwt.Credentials", instance=True, autospec=True)
    def test__create_self_signed_jwt_always_use_jwt_access(self, jwt_credentials):
        """With neither scopes nor an audience, no JWT is created."""
        credentials = service_account.Credentials(
            SIGNER,
            self.SERVICE_ACCOUNT_EMAIL,
            self.TOKEN_URI,
            always_use_jwt_access=True,
        )

        credentials._create_self_signed_jwt(None)
        jwt_credentials.from_signing_credentials.assert_not_called()

    @mock.patch("google.oauth2._client.jwt_grant", autospec=True)
    def test_refresh_success(self, jwt_grant):
        """``refresh`` exchanges a signed assertion via ``jwt_grant``."""
        credentials = self.make_credentials()
        token = "token"
        jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            {},
        )
        request = mock.create_autospec(transport.Request, instance=True)

        # Refresh credentials
        credentials.refresh(request)

        # Check jwt grant call.
        assert jwt_grant.called

        called_request, token_uri, assertion = jwt_grant.call_args[0]
        assert called_request == request
        assert token_uri == credentials._token_uri
        assert jwt.decode(assertion, PUBLIC_CERT_BYTES)
        # No further assertion done on the token, as there are separate tests
        # for checking the authorization grant assertion.

        # Check that the credentials have the token.
        assert credentials.token == token

        # Check that the credentials are valid (have a token and are not
        # expired)
        assert credentials.valid

    @mock.patch("google.oauth2._client.jwt_grant", autospec=True)
    def test_before_request_refreshes(self, jwt_grant):
        """``before_request`` refreshes invalid credentials."""
        credentials = self.make_credentials()
        token = "token"
        jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            None,
        )
        request = mock.create_autospec(transport.Request, instance=True)

        # Credentials should start as invalid
        assert not credentials.valid

        # before_request should cause a refresh
        credentials.before_request(request, "GET", "http://example.com?a=1#3", {})

        # The refresh endpoint should've been called.
        assert jwt_grant.called

        # Credentials should now be valid.
        assert credentials.valid

    @mock.patch("google.auth.jwt.Credentials._make_jwt")
    def test_refresh_with_jwt_credentials(self, make_jwt):
        """With a self-signed JWT, token and expiry come from ``_make_jwt``."""
        credentials = self.make_credentials()
        credentials._create_self_signed_jwt("https://pubsub.googleapis.com")

        request = mock.create_autospec(transport.Request, instance=True)

        token = "token"
        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        make_jwt.return_value = (token, expiry)

        # Credentials should start as invalid
        assert not credentials.valid

        # before_request should cause a refresh
        credentials.before_request(request, "GET", "http://example.com?a=1#3", {})

        # Credentials should now be valid.
        assert credentials.valid

        # ``make_jwt.called_once()`` would NOT be an assertion. Older mock
        # versions return a truthy child Mock for it, so it can never fail,
        # and newer versions raise AttributeError. Use the real assertion.
        make_jwt.assert_called_once()

        assert credentials.token == token
        assert credentials.expiry == expiry

    @mock.patch("google.oauth2._client.jwt_grant", autospec=True)
    @mock.patch("google.auth.jwt.Credentials.refresh", autospec=True)
    def test_refresh_jwt_not_used_for_domain_wide_delegation(
        self, self_signed_jwt_refresh, jwt_grant
    ):
        """A subject forces ``jwt_grant`` even when JWT access is enabled."""
        # Create a domain wide delegation credentials by setting the subject.
        credentials = service_account.Credentials(
            SIGNER,
            self.SERVICE_ACCOUNT_EMAIL,
            self.TOKEN_URI,
            always_use_jwt_access=True,
            subject="subject",
        )
        credentials._create_self_signed_jwt("https://pubsub.googleapis.com")
        jwt_grant.return_value = (
            "token",
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            {},
        )
        request = mock.create_autospec(transport.Request, instance=True)

        # Refresh credentials
        credentials.refresh(request)

        # Make sure we are using jwt_grant and not self signed JWT refresh
        # method to obtain the token.
        assert jwt_grant.called
        assert not self_signed_jwt_refresh.called
402
403
class TestIDTokenCredentials(object):
    """Tests for ``service_account.IDTokenCredentials`` (Google ID tokens)."""

    SERVICE_ACCOUNT_EMAIL = "[email protected]"
    TOKEN_URI = "https://example.com/oauth2/token"
    TARGET_AUDIENCE = "https://example.com"

    @classmethod
    def make_credentials(cls):
        """Return ID token credentials for ``TARGET_AUDIENCE`` using ``SIGNER``."""
        return service_account.IDTokenCredentials(
            SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI, cls.TARGET_AUDIENCE
        )

    def test_from_service_account_info(self):
        """Fields come from the info mapping; the audience from the kwarg."""
        credentials = service_account.IDTokenCredentials.from_service_account_info(
            SERVICE_ACCOUNT_INFO, target_audience=self.TARGET_AUDIENCE
        )

        assert credentials._signer.key_id == SERVICE_ACCOUNT_INFO["private_key_id"]
        assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"]
        assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"]
        assert credentials._target_audience == self.TARGET_AUDIENCE

    def test_from_service_account_file(self):
        """Loading from the JSON key file matches the info mapping."""
        info = SERVICE_ACCOUNT_INFO.copy()

        credentials = service_account.IDTokenCredentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE, target_audience=self.TARGET_AUDIENCE
        )

        assert credentials.service_account_email == info["client_email"]
        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._token_uri == info["token_uri"]
        assert credentials._target_audience == self.TARGET_AUDIENCE

    def test_default_state(self):
        """Fresh credentials are invalid and not expired."""
        credentials = self.make_credentials()
        assert not credentials.valid
        # Expiration hasn't been set yet
        assert not credentials.expired

    def test_sign_bytes(self):
        """Signatures verify against the matching public certificate."""
        credentials = self.make_credentials()
        to_sign = b"123"
        signature = credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)

    def test_signer(self):
        """The ``signer`` property exposes a ``crypt.Signer``."""
        credentials = self.make_credentials()
        assert isinstance(credentials.signer, crypt.Signer)

    def test_signer_email(self):
        """``signer_email`` is the service account email."""
        credentials = self.make_credentials()
        assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL

    def test_with_target_audience(self):
        """``with_target_audience`` replaces the target audience."""
        credentials = self.make_credentials()
        new_credentials = credentials.with_target_audience("https://new.example.com")
        assert new_credentials._target_audience == "https://new.example.com"

    def test_with_quota_project(self):
        """``with_quota_project`` sets the quota project and its header."""
        credentials = self.make_credentials()
        new_credentials = credentials.with_quota_project("project-foo")
        # Check through the public property rather than the private
        # ``_quota_project_id`` attribute, and verify it reaches the headers.
        assert new_credentials.quota_project_id == "project-foo"
        headers = {}
        new_credentials.apply(headers, token="tok")
        assert headers["x-goog-user-project"] == "project-foo"

    def test__make_authorization_grant_assertion(self):
        """The grant assertion carries issuer, audience and target audience."""
        credentials = self.make_credentials()
        token = credentials._make_authorization_grant_assertion()
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        assert payload["aud"] == service_account._GOOGLE_OAUTH2_TOKEN_ENDPOINT
        assert payload["target_audience"] == self.TARGET_AUDIENCE

    @mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True)
    def test_refresh_success(self, id_token_jwt_grant):
        """``refresh`` exchanges a signed assertion via ``id_token_jwt_grant``."""
        credentials = self.make_credentials()
        token = "token"
        id_token_jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            {},
        )
        request = mock.create_autospec(transport.Request, instance=True)

        # Refresh credentials
        credentials.refresh(request)

        # Check jwt grant call.
        assert id_token_jwt_grant.called

        called_request, token_uri, assertion = id_token_jwt_grant.call_args[0]
        assert called_request == request
        assert token_uri == credentials._token_uri
        assert jwt.decode(assertion, PUBLIC_CERT_BYTES)
        # No further assertion done on the token, as there are separate tests
        # for checking the authorization grant assertion.

        # Check that the credentials have the token.
        assert credentials.token == token

        # Check that the credentials are valid (have a token and are not
        # expired)
        assert credentials.valid

    @mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True)
    def test_before_request_refreshes(self, id_token_jwt_grant):
        """``before_request`` refreshes invalid credentials."""
        credentials = self.make_credentials()
        token = "token"
        id_token_jwt_grant.return_value = (
            token,
            _helpers.utcnow() + datetime.timedelta(seconds=500),
            None,
        )
        request = mock.create_autospec(transport.Request, instance=True)

        # Credentials should start as invalid
        assert not credentials.valid

        # before_request should cause a refresh
        credentials.before_request(request, "GET", "http://example.com?a=1#3", {})

        # The refresh endpoint should've been called.
        assert id_token_jwt_grant.called

        # Credentials should now be valid.
        assert credentials.valid
528