1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.auth;
18 
19 import static com.google.common.base.Charsets.US_ASCII;
20 import static org.junit.Assert.assertArrayEquals;
21 import static org.junit.Assert.assertEquals;
22 import static org.junit.Assert.assertNull;
23 import static org.junit.Assert.assertTrue;
24 import static org.mockito.ArgumentMatchers.any;
25 import static org.mockito.ArgumentMatchers.eq;
26 import static org.mockito.Mockito.doAnswer;
27 import static org.mockito.Mockito.times;
28 import static org.mockito.Mockito.verify;
29 import static org.mockito.Mockito.when;
30 
31 import com.google.auth.Credentials;
32 import com.google.auth.RequestMetadataCallback;
33 import com.google.auth.http.HttpTransportFactory;
34 import com.google.auth.oauth2.AccessToken;
35 import com.google.auth.oauth2.GoogleCredentials;
36 import com.google.auth.oauth2.OAuth2Credentials;
37 import com.google.auth.oauth2.ServiceAccountCredentials;
38 import com.google.common.collect.Iterables;
39 import com.google.common.collect.LinkedListMultimap;
40 import com.google.common.collect.ListMultimap;
41 import com.google.common.collect.Multimaps;
42 import com.google.common.io.BaseEncoding;
43 import io.grpc.Attributes;
44 import io.grpc.CallCredentials;
45 import io.grpc.CallCredentials.MetadataApplier;
46 import io.grpc.Metadata;
47 import io.grpc.MethodDescriptor;
48 import io.grpc.SecurityLevel;
49 import io.grpc.Status;
50 import io.grpc.internal.JsonParser;
51 import io.grpc.testing.TestMethodDescriptors;
52 import java.io.IOException;
53 import java.net.URI;
54 import java.security.KeyPair;
55 import java.security.KeyPairGenerator;
56 import java.util.ArrayList;
57 import java.util.Date;
58 import java.util.List;
59 import java.util.Map;
60 import java.util.concurrent.Executor;
61 import org.junit.After;
62 import org.junit.Before;
63 import org.junit.Rule;
64 import org.junit.Test;
65 import org.junit.runner.RunWith;
66 import org.junit.runners.JUnit4;
67 import org.mockito.ArgumentCaptor;
68 import org.mockito.Captor;
69 import org.mockito.Mock;
70 import org.mockito.Mockito;
71 import org.mockito.invocation.InvocationOnMock;
72 import org.mockito.junit.MockitoJUnit;
73 import org.mockito.junit.MockitoRule;
74 import org.mockito.stubbing.Answer;
75 
76 /**
77  * Tests for {@link GoogleAuthLibraryCallCredentials}.
78  */
79 @RunWith(JUnit4.class)
80 public class GoogleAuthLibraryCallCredentialsTest {
81 
82   @Rule
83   public final MockitoRule mocks = MockitoJUnit.rule();
84 
85   private static final Metadata.Key<String> AUTHORIZATION = Metadata.Key.of("Authorization",
86       Metadata.ASCII_STRING_MARSHALLER);
87   private static final Metadata.Key<byte[]> EXTRA_AUTHORIZATION = Metadata.Key.of(
88       "Extra-Authorization-bin", Metadata.BINARY_BYTE_MARSHALLER);
89 
90   @Mock
91   private Credentials credentials;
92 
93   @Mock
94   private MetadataApplier applier;
95 
96   private Executor executor = new Executor() {
97     @Override public void execute(Runnable r) {
98       pendingRunnables.add(r);
99     }
100   };
101 
102   @Captor
103   private ArgumentCaptor<Metadata> headersCaptor;
104 
105   @Captor
106   private ArgumentCaptor<Status> statusCaptor;
107 
108   private MethodDescriptor<Void, Void> method = MethodDescriptor.<Void, Void>newBuilder()
109       .setType(MethodDescriptor.MethodType.UNKNOWN)
110       .setFullMethodName("a.service/method")
111       .setRequestMarshaller(TestMethodDescriptors.voidMarshaller())
112       .setResponseMarshaller(TestMethodDescriptors.voidMarshaller())
113       .build();
114   private URI expectedUri = URI.create("https://testauthority/a.service");
115 
116   private static final String AUTHORITY = "testauthority";
117   private static final SecurityLevel SECURITY_LEVEL = SecurityLevel.PRIVACY_AND_INTEGRITY;
118 
119   private ArrayList<Runnable> pendingRunnables = new ArrayList<>();
120 
121   @Before
setUp()122   public void setUp() throws Exception {
123     doAnswer(new Answer<Void>() {
124       @Override
125       public Void answer(InvocationOnMock invocation) {
126         Credentials mock = (Credentials) invocation.getMock();
127         URI uri = (URI) invocation.getArguments()[0];
128         RequestMetadataCallback callback = (RequestMetadataCallback) invocation.getArguments()[2];
129         Map<String, List<String>> metadata;
130         try {
131           // Default to calling the blocking method, since it is easier to mock
132           metadata = mock.getRequestMetadata(uri);
133         } catch (Exception ex) {
134           callback.onFailure(ex);
135           return null;
136         }
137         callback.onSuccess(metadata);
138         return null;
139       }
140     }).when(credentials).getRequestMetadata(
141         any(URI.class),
142         any(Executor.class),
143         any(RequestMetadataCallback.class));
144   }
145 
146   @After
tearDown()147   public void tearDown() {
148     assertEquals(0, pendingRunnables.size());
149   }
150 
151   @Test
copyCredentialsToHeaders()152   public void copyCredentialsToHeaders() throws Exception {
153     ListMultimap<String, String> values = LinkedListMultimap.create();
154     values.put("Authorization", "token1");
155     values.put("Authorization", "token2");
156     values.put("Extra-Authorization-bin", "dG9rZW4z");  // bytes "token3" in base64
157     values.put("Extra-Authorization-bin", "dG9rZW40");  // bytes "token4" in base64
158     when(credentials.getRequestMetadata(eq(expectedUri))).thenReturn(Multimaps.asMap(values));
159 
160     GoogleAuthLibraryCallCredentials callCredentials =
161         new GoogleAuthLibraryCallCredentials(credentials);
162     callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier);
163 
164     verify(credentials).getRequestMetadata(eq(expectedUri));
165     verify(applier).apply(headersCaptor.capture());
166     Metadata headers = headersCaptor.getValue();
167     Iterable<String> authorization = headers.getAll(AUTHORIZATION);
168     assertArrayEquals(new String[]{"token1", "token2"},
169         Iterables.toArray(authorization, String.class));
170     Iterable<byte[]> extraAuthorization = headers.getAll(EXTRA_AUTHORIZATION);
171     assertEquals(2, Iterables.size(extraAuthorization));
172     assertArrayEquals("token3".getBytes(US_ASCII), Iterables.get(extraAuthorization, 0));
173     assertArrayEquals("token4".getBytes(US_ASCII), Iterables.get(extraAuthorization, 1));
174   }
175 
176   @Test
invalidBase64()177   public void invalidBase64() throws Exception {
178     ListMultimap<String, String> values = LinkedListMultimap.create();
179     values.put("Extra-Authorization-bin", "dG9rZW4z1");  // invalid base64
180     when(credentials.getRequestMetadata(eq(expectedUri))).thenReturn(Multimaps.asMap(values));
181 
182     GoogleAuthLibraryCallCredentials callCredentials =
183         new GoogleAuthLibraryCallCredentials(credentials);
184     callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier);
185 
186     verify(credentials).getRequestMetadata(eq(expectedUri));
187     verify(applier).fail(statusCaptor.capture());
188     Status status = statusCaptor.getValue();
189     assertEquals(Status.Code.UNAUTHENTICATED, status.getCode());
190     assertEquals(IllegalArgumentException.class, status.getCause().getClass());
191   }
192 
193   @Test
credentialsFailsWithIoException()194   public void credentialsFailsWithIoException() throws Exception {
195     Exception exception = new IOException("Broken");
196     when(credentials.getRequestMetadata(eq(expectedUri))).thenThrow(exception);
197 
198     GoogleAuthLibraryCallCredentials callCredentials =
199         new GoogleAuthLibraryCallCredentials(credentials);
200     callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier);
201 
202     verify(credentials).getRequestMetadata(eq(expectedUri));
203     verify(applier).fail(statusCaptor.capture());
204     Status status = statusCaptor.getValue();
205     assertEquals(Status.Code.UNAVAILABLE, status.getCode());
206     assertEquals(exception, status.getCause());
207   }
208 
209   @Test
credentialsFailsWithRuntimeException()210   public void credentialsFailsWithRuntimeException() throws Exception {
211     Exception exception = new RuntimeException("Broken");
212     when(credentials.getRequestMetadata(eq(expectedUri))).thenThrow(exception);
213 
214     GoogleAuthLibraryCallCredentials callCredentials =
215         new GoogleAuthLibraryCallCredentials(credentials);
216     callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier);
217 
218     verify(credentials).getRequestMetadata(eq(expectedUri));
219     verify(applier).fail(statusCaptor.capture());
220     Status status = statusCaptor.getValue();
221     assertEquals(Status.Code.UNAUTHENTICATED, status.getCode());
222     assertEquals(exception, status.getCause());
223   }
224 
225   @Test
226   @SuppressWarnings("unchecked")
credentialsReturnNullMetadata()227   public void credentialsReturnNullMetadata() throws Exception {
228     ListMultimap<String, String> values = LinkedListMultimap.create();
229     values.put("Authorization", "token1");
230     when(credentials.getRequestMetadata(eq(expectedUri)))
231         .thenReturn(null, Multimaps.asMap(values), null);
232 
233     GoogleAuthLibraryCallCredentials callCredentials =
234         new GoogleAuthLibraryCallCredentials(credentials);
235     for (int i = 0; i < 3; i++) {
236       callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier);
237     }
238 
239     verify(credentials, times(3)).getRequestMetadata(eq(expectedUri));
240 
241     verify(applier, times(3)).apply(headersCaptor.capture());
242     List<Metadata> headerList = headersCaptor.getAllValues();
243     assertEquals(3, headerList.size());
244 
245     assertEquals(0, headerList.get(0).keys().size());
246 
247     Iterable<String> authorization = headerList.get(1).getAll(AUTHORIZATION);
248     assertArrayEquals(new String[]{"token1"}, Iterables.toArray(authorization, String.class));
249 
250     assertEquals(0, headerList.get(2).keys().size());
251   }
252 
253   @Test
oauth2Credential()254   public void oauth2Credential() {
255     final AccessToken token = new AccessToken("allyourbase", new Date(Long.MAX_VALUE));
256     OAuth2Credentials credentials = new OAuth2Credentials() {
257       @Override
258       public AccessToken refreshAccessToken() throws IOException {
259         return token;
260       }
261     };
262 
263     GoogleAuthLibraryCallCredentials callCredentials =
264         new GoogleAuthLibraryCallCredentials(credentials);
265     callCredentials.applyRequestMetadata(
266         new RequestInfoImpl(SecurityLevel.NONE), executor, applier);
267     assertEquals(1, runPendingRunnables());
268 
269     verify(applier).apply(headersCaptor.capture());
270     Metadata headers = headersCaptor.getValue();
271     Iterable<String> authorization = headers.getAll(AUTHORIZATION);
272     assertArrayEquals(new String[]{"Bearer allyourbase"},
273         Iterables.toArray(authorization, String.class));
274   }
275 
276   @Test
googleCredential_privacyAndIntegrityAllowed()277   public void googleCredential_privacyAndIntegrityAllowed() {
278     final AccessToken token = new AccessToken("allyourbase", new Date(Long.MAX_VALUE));
279     final Credentials credentials = GoogleCredentials.create(token);
280 
281     GoogleAuthLibraryCallCredentials callCredentials =
282         new GoogleAuthLibraryCallCredentials(credentials);
283     callCredentials.applyRequestMetadata(
284         new RequestInfoImpl(SecurityLevel.PRIVACY_AND_INTEGRITY), executor, applier);
285     runPendingRunnables();
286 
287     verify(applier).apply(headersCaptor.capture());
288     Metadata headers = headersCaptor.getValue();
289     Iterable<String> authorization = headers.getAll(AUTHORIZATION);
290     assertArrayEquals(new String[]{"Bearer allyourbase"},
291         Iterables.toArray(authorization, String.class));
292   }
293 
294   @Test
googleCredential_integrityDenied()295   public void googleCredential_integrityDenied() {
296     final AccessToken token = new AccessToken("allyourbase", new Date(Long.MAX_VALUE));
297     final Credentials credentials = GoogleCredentials.create(token);
298     // Anything less than PRIVACY_AND_INTEGRITY should fail
299 
300     GoogleAuthLibraryCallCredentials callCredentials =
301         new GoogleAuthLibraryCallCredentials(credentials);
302     callCredentials.applyRequestMetadata(
303         new RequestInfoImpl(SecurityLevel.INTEGRITY), executor, applier);
304     runPendingRunnables();
305 
306     verify(applier).fail(statusCaptor.capture());
307     Status status = statusCaptor.getValue();
308     assertEquals(Status.Code.UNAUTHENTICATED, status.getCode());
309   }
310 
311   @Test
serviceUri()312   public void serviceUri() throws Exception {
313     GoogleAuthLibraryCallCredentials callCredentials =
314         new GoogleAuthLibraryCallCredentials(credentials);
315     callCredentials.applyRequestMetadata(
316         new RequestInfoImpl("example.com:443"), executor, applier);
317     verify(credentials).getRequestMetadata(eq(new URI("https://example.com/a.service")));
318 
319     callCredentials.applyRequestMetadata(
320         new RequestInfoImpl("example.com:123"), executor, applier);
321     verify(credentials).getRequestMetadata(eq(new URI("https://example.com:123/a.service")));
322   }
323 
324   @Test
serviceAccountToJwt()325   public void serviceAccountToJwt() throws Exception {
326     KeyPair pair = KeyPairGenerator.getInstance("RSA").generateKeyPair();
327 
328     HttpTransportFactory factory = Mockito.mock(HttpTransportFactory.class);
329     Mockito.when(factory.create()).thenThrow(new AssertionError());
330 
331     ServiceAccountCredentials credentials =
332         ServiceAccountCredentials.newBuilder()
333             .setClientEmail("[email protected]")
334             .setPrivateKey(pair.getPrivate())
335             .setPrivateKeyId("test-private-key-id")
336             .setHttpTransportFactory(factory)
337             .build();
338 
339     GoogleAuthLibraryCallCredentials callCredentials =
340         new GoogleAuthLibraryCallCredentials(credentials);
341     callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier);
342     assertEquals(0, runPendingRunnables());
343 
344     verify(applier).apply(headersCaptor.capture());
345     Metadata headers = headersCaptor.getValue();
346     String[] authorization = Iterables.toArray(headers.getAll(AUTHORIZATION), String.class);
347     assertEquals(1, authorization.length);
348     assertTrue(authorization[0], authorization[0].startsWith("Bearer "));
349     // JWT is reasonably long. Normal tokens aren't.
350     assertTrue(authorization[0], authorization[0].length() > 300);
351   }
352 
353   @Test
oauthClassesNotInClassPath()354   public void oauthClassesNotInClassPath() throws Exception {
355     ListMultimap<String, String> values = LinkedListMultimap.create();
356     values.put("Authorization", "token1");
357     when(credentials.getRequestMetadata(eq(expectedUri))).thenReturn(Multimaps.asMap(values));
358 
359     assertNull(GoogleAuthLibraryCallCredentials.createJwtHelperOrNull(null));
360     GoogleAuthLibraryCallCredentials callCredentials =
361         new GoogleAuthLibraryCallCredentials(credentials, null);
362     callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier);
363 
364     verify(credentials).getRequestMetadata(eq(expectedUri));
365     verify(applier).apply(headersCaptor.capture());
366     Metadata headers = headersCaptor.getValue();
367     Iterable<String> authorization = headers.getAll(AUTHORIZATION);
368     assertArrayEquals(new String[]{"token1"},
369         Iterables.toArray(authorization, String.class));
370   }
371 
372   @Test
jwtAccessCredentialsInRequestMetadata()373   public void jwtAccessCredentialsInRequestMetadata() throws Exception {
374     KeyPair pair = KeyPairGenerator.getInstance("RSA").generateKeyPair();
375 
376     ServiceAccountCredentials credentials =
377         ServiceAccountCredentials.newBuilder()
378             .setClientId("test-client")
379             .setClientEmail("[email protected]")
380             .setPrivateKey(pair.getPrivate())
381             .setPrivateKeyId("test-private-key-id")
382             .setQuotaProjectId("test-quota-project-id")
383             .build();
384     GoogleAuthLibraryCallCredentials callCredentials =
385         new GoogleAuthLibraryCallCredentials(credentials);
386     callCredentials.applyRequestMetadata(new RequestInfoImpl("example.com:123"), executor, applier);
387 
388     verify(applier).apply(headersCaptor.capture());
389     Metadata headers = headersCaptor.getValue();
390     String token =
391         Iterables.getOnlyElement(headers.getAll(AUTHORIZATION)).substring("Bearer ".length());
392     String[] parts = token.split("\\.", 3);
393     String jsonHeader = new String(BaseEncoding.base64Url().decode(parts[0]), US_ASCII);
394     String jsonPayload = new String(BaseEncoding.base64Url().decode(parts[1]), US_ASCII);
395     Map<?, ?> header = (Map<?, ?>) JsonParser.parse(jsonHeader);
396     assertEquals("test-private-key-id", header.get("kid"));
397     Map<?, ?> payload = (Map<?, ?>) JsonParser.parse(jsonPayload);
398     // google-auth-library-java 0.25.2 began stripping the grpc service name from the audience.
399     // Allow tests to pass with both the old and new versions for a while to avoid an atomic upgrade
400     // everywhere google-auth-library-java is used.
401     assertTrue("https://example.com/".equals(payload.get("aud"))
402         || "https://example.com:123/a.service".equals(payload.get("aud")));
403     assertEquals("[email protected]", payload.get("iss"));
404     assertEquals("[email protected]", payload.get("sub"));
405 
406     Metadata.Key<String> quotaProject = Metadata.Key
407         .of("X-Goog-User-Project", Metadata.ASCII_STRING_MARSHALLER);
408     assertEquals("test-quota-project-id", Iterables.getOnlyElement(headers.getAll(quotaProject)));
409   }
410 
runPendingRunnables()411   private int runPendingRunnables() {
412     ArrayList<Runnable> savedPendingRunnables = pendingRunnables;
413     pendingRunnables = new ArrayList<>();
414     for (Runnable r : savedPendingRunnables) {
415       r.run();
416     }
417     return savedPendingRunnables.size();
418   }
419 
420   private final class RequestInfoImpl extends CallCredentials.RequestInfo {
421     final String authority;
422     final SecurityLevel securityLevel;
423 
RequestInfoImpl()424     RequestInfoImpl() {
425       this(AUTHORITY, SECURITY_LEVEL);
426     }
427 
RequestInfoImpl(SecurityLevel securityLevel)428     RequestInfoImpl(SecurityLevel securityLevel) {
429       this(AUTHORITY, securityLevel);
430     }
431 
RequestInfoImpl(String authority)432     RequestInfoImpl(String authority) {
433       this(authority, SECURITY_LEVEL);
434     }
435 
RequestInfoImpl(String authority, SecurityLevel securityLevel)436     RequestInfoImpl(String authority, SecurityLevel securityLevel) {
437       this.authority = authority;
438       this.securityLevel = securityLevel;
439     }
440 
441     @Override
getMethodDescriptor()442     public MethodDescriptor<?, ?> getMethodDescriptor() {
443       return method;
444     }
445 
446     @Override
getSecurityLevel()447     public SecurityLevel getSecurityLevel() {
448       return securityLevel;
449     }
450 
451     @Override
getAuthority()452     public String getAuthority() {
453       return authority;
454     }
455 
456     @Override
getTransportAttrs()457     public Attributes getTransportAttrs() {
458       return Attributes.EMPTY;
459     }
460   }
461 }
462