1 /* 2 * Copyright 2016 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.auth; 18 19 import static com.google.common.base.Charsets.US_ASCII; 20 import static org.junit.Assert.assertArrayEquals; 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertNull; 23 import static org.junit.Assert.assertTrue; 24 import static org.mockito.ArgumentMatchers.any; 25 import static org.mockito.ArgumentMatchers.eq; 26 import static org.mockito.Mockito.doAnswer; 27 import static org.mockito.Mockito.times; 28 import static org.mockito.Mockito.verify; 29 import static org.mockito.Mockito.when; 30 31 import com.google.auth.Credentials; 32 import com.google.auth.RequestMetadataCallback; 33 import com.google.auth.http.HttpTransportFactory; 34 import com.google.auth.oauth2.AccessToken; 35 import com.google.auth.oauth2.GoogleCredentials; 36 import com.google.auth.oauth2.OAuth2Credentials; 37 import com.google.auth.oauth2.ServiceAccountCredentials; 38 import com.google.common.collect.Iterables; 39 import com.google.common.collect.LinkedListMultimap; 40 import com.google.common.collect.ListMultimap; 41 import com.google.common.collect.Multimaps; 42 import com.google.common.io.BaseEncoding; 43 import io.grpc.Attributes; 44 import io.grpc.CallCredentials; 45 import io.grpc.CallCredentials.MetadataApplier; 46 import io.grpc.Metadata; 47 import io.grpc.MethodDescriptor; 48 import io.grpc.SecurityLevel; 49 import io.grpc.Status; 50 import io.grpc.internal.JsonParser; 51 import io.grpc.testing.TestMethodDescriptors; 52 import java.io.IOException; 53 import java.net.URI; 54 import java.security.KeyPair; 55 import java.security.KeyPairGenerator; 56 import java.util.ArrayList; 57 import java.util.Date; 58 import java.util.List; 59 import java.util.Map; 60 import java.util.concurrent.Executor; 61 import org.junit.After; 62 import org.junit.Before; 63 import org.junit.Rule; 64 import org.junit.Test; 65 import org.junit.runner.RunWith; 66 import org.junit.runners.JUnit4; 67 import org.mockito.ArgumentCaptor; 68 import org.mockito.Captor; 69 import org.mockito.Mock; 70 import org.mockito.Mockito; 71 import org.mockito.invocation.InvocationOnMock; 72 import org.mockito.junit.MockitoJUnit; 73 import org.mockito.junit.MockitoRule; 74 import org.mockito.stubbing.Answer; 75 76 /** 77 * Tests for {@link GoogleAuthLibraryCallCredentials}. 78 */ 79 @RunWith(JUnit4.class) 80 public class GoogleAuthLibraryCallCredentialsTest { 81 82 @Rule 83 public final MockitoRule mocks = MockitoJUnit.rule(); 84 85 private static final Metadata.Key<String> AUTHORIZATION = Metadata.Key.of("Authorization", 86 Metadata.ASCII_STRING_MARSHALLER); 87 private static final Metadata.Key<byte[]> EXTRA_AUTHORIZATION = Metadata.Key.of( 88 "Extra-Authorization-bin", Metadata.BINARY_BYTE_MARSHALLER); 89 90 @Mock 91 private Credentials credentials; 92 93 @Mock 94 private MetadataApplier applier; 95 96 private Executor executor = new Executor() { 97 @Override public void execute(Runnable r) { 98 pendingRunnables.add(r); 99 } 100 }; 101 102 @Captor 103 private ArgumentCaptor<Metadata> headersCaptor; 104 105 @Captor 106 private ArgumentCaptor<Status> statusCaptor; 107 108 private MethodDescriptor<Void, Void> method = MethodDescriptor.<Void, Void>newBuilder() 109 .setType(MethodDescriptor.MethodType.UNKNOWN) 110 .setFullMethodName("a.service/method") 111 .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) 112 .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) 113 .build(); 114 private URI expectedUri = URI.create("https://testauthority/a.service"); 115 116 private static final String AUTHORITY = "testauthority"; 117 private static final SecurityLevel SECURITY_LEVEL = SecurityLevel.PRIVACY_AND_INTEGRITY; 118 119 private ArrayList<Runnable> pendingRunnables = new ArrayList<>(); 120 121 @Before setUp()122 public void setUp() throws Exception { 123 doAnswer(new Answer<Void>() { 124 @Override 125 public Void answer(InvocationOnMock invocation) { 126 Credentials mock = (Credentials) invocation.getMock(); 127 URI uri = (URI) invocation.getArguments()[0]; 128 RequestMetadataCallback callback = (RequestMetadataCallback) invocation.getArguments()[2]; 129 Map<String, List<String>> metadata; 130 try { 131 // Default to calling the blocking method, since it is easier to mock 132 metadata = mock.getRequestMetadata(uri); 133 } catch (Exception ex) { 134 callback.onFailure(ex); 135 return null; 136 } 137 callback.onSuccess(metadata); 138 return null; 139 } 140 }).when(credentials).getRequestMetadata( 141 any(URI.class), 142 any(Executor.class), 143 any(RequestMetadataCallback.class)); 144 } 145 146 @After tearDown()147 public void tearDown() { 148 assertEquals(0, pendingRunnables.size()); 149 } 150 151 @Test copyCredentialsToHeaders()152 public void copyCredentialsToHeaders() throws Exception { 153 ListMultimap<String, String> values = LinkedListMultimap.create(); 154 values.put("Authorization", "token1"); 155 values.put("Authorization", "token2"); 156 values.put("Extra-Authorization-bin", "dG9rZW4z"); // bytes "token3" in base64 157 values.put("Extra-Authorization-bin", "dG9rZW40"); // bytes "token4" in base64 158 when(credentials.getRequestMetadata(eq(expectedUri))).thenReturn(Multimaps.asMap(values)); 159 160 GoogleAuthLibraryCallCredentials callCredentials = 161 new GoogleAuthLibraryCallCredentials(credentials); 162 callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier); 163 164 verify(credentials).getRequestMetadata(eq(expectedUri)); 165 verify(applier).apply(headersCaptor.capture()); 166 Metadata headers = headersCaptor.getValue(); 167 Iterable<String> authorization = headers.getAll(AUTHORIZATION); 168 assertArrayEquals(new String[]{"token1", "token2"}, 169 Iterables.toArray(authorization, String.class)); 170 Iterable<byte[]> extraAuthorization = headers.getAll(EXTRA_AUTHORIZATION); 171 assertEquals(2, Iterables.size(extraAuthorization)); 172 assertArrayEquals("token3".getBytes(US_ASCII), Iterables.get(extraAuthorization, 0)); 173 assertArrayEquals("token4".getBytes(US_ASCII), Iterables.get(extraAuthorization, 1)); 174 } 175 176 @Test invalidBase64()177 public void invalidBase64() throws Exception { 178 ListMultimap<String, String> values = LinkedListMultimap.create(); 179 values.put("Extra-Authorization-bin", "dG9rZW4z1"); // invalid base64 180 when(credentials.getRequestMetadata(eq(expectedUri))).thenReturn(Multimaps.asMap(values)); 181 182 GoogleAuthLibraryCallCredentials callCredentials = 183 new GoogleAuthLibraryCallCredentials(credentials); 184 callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier); 185 186 verify(credentials).getRequestMetadata(eq(expectedUri)); 187 verify(applier).fail(statusCaptor.capture()); 188 Status status = statusCaptor.getValue(); 189 assertEquals(Status.Code.UNAUTHENTICATED, status.getCode()); 190 assertEquals(IllegalArgumentException.class, status.getCause().getClass()); 191 } 192 193 @Test credentialsFailsWithIoException()194 public void credentialsFailsWithIoException() throws Exception { 195 Exception exception = new IOException("Broken"); 196 when(credentials.getRequestMetadata(eq(expectedUri))).thenThrow(exception); 197 198 GoogleAuthLibraryCallCredentials callCredentials = 199 new GoogleAuthLibraryCallCredentials(credentials); 200 callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier); 201 202 verify(credentials).getRequestMetadata(eq(expectedUri)); 203 verify(applier).fail(statusCaptor.capture()); 204 Status status = statusCaptor.getValue(); 205 assertEquals(Status.Code.UNAVAILABLE, status.getCode()); 206 assertEquals(exception, status.getCause()); 207 } 208 209 @Test credentialsFailsWithRuntimeException()210 public void credentialsFailsWithRuntimeException() throws Exception { 211 Exception exception = new RuntimeException("Broken"); 212 when(credentials.getRequestMetadata(eq(expectedUri))).thenThrow(exception); 213 214 GoogleAuthLibraryCallCredentials callCredentials = 215 new GoogleAuthLibraryCallCredentials(credentials); 216 callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier); 217 218 verify(credentials).getRequestMetadata(eq(expectedUri)); 219 verify(applier).fail(statusCaptor.capture()); 220 Status status = statusCaptor.getValue(); 221 assertEquals(Status.Code.UNAUTHENTICATED, status.getCode()); 222 assertEquals(exception, status.getCause()); 223 } 224 225 @Test 226 @SuppressWarnings("unchecked") credentialsReturnNullMetadata()227 public void credentialsReturnNullMetadata() throws Exception { 228 ListMultimap<String, String> values = LinkedListMultimap.create(); 229 values.put("Authorization", "token1"); 230 when(credentials.getRequestMetadata(eq(expectedUri))) 231 .thenReturn(null, Multimaps.asMap(values), null); 232 233 GoogleAuthLibraryCallCredentials callCredentials = 234 new GoogleAuthLibraryCallCredentials(credentials); 235 for (int i = 0; i < 3; i++) { 236 callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier); 237 } 238 239 verify(credentials, times(3)).getRequestMetadata(eq(expectedUri)); 240 241 verify(applier, times(3)).apply(headersCaptor.capture()); 242 List<Metadata> headerList = headersCaptor.getAllValues(); 243 assertEquals(3, headerList.size()); 244 245 assertEquals(0, headerList.get(0).keys().size()); 246 247 Iterable<String> authorization = headerList.get(1).getAll(AUTHORIZATION); 248 assertArrayEquals(new String[]{"token1"}, Iterables.toArray(authorization, String.class)); 249 250 assertEquals(0, headerList.get(2).keys().size()); 251 } 252 253 @Test oauth2Credential()254 public void oauth2Credential() { 255 final AccessToken token = new AccessToken("allyourbase", new Date(Long.MAX_VALUE)); 256 OAuth2Credentials credentials = new OAuth2Credentials() { 257 @Override 258 public AccessToken refreshAccessToken() throws IOException { 259 return token; 260 } 261 }; 262 263 GoogleAuthLibraryCallCredentials callCredentials = 264 new GoogleAuthLibraryCallCredentials(credentials); 265 callCredentials.applyRequestMetadata( 266 new RequestInfoImpl(SecurityLevel.NONE), executor, applier); 267 assertEquals(1, runPendingRunnables()); 268 269 verify(applier).apply(headersCaptor.capture()); 270 Metadata headers = headersCaptor.getValue(); 271 Iterable<String> authorization = headers.getAll(AUTHORIZATION); 272 assertArrayEquals(new String[]{"Bearer allyourbase"}, 273 Iterables.toArray(authorization, String.class)); 274 } 275 276 @Test googleCredential_privacyAndIntegrityAllowed()277 public void googleCredential_privacyAndIntegrityAllowed() { 278 final AccessToken token = new AccessToken("allyourbase", new Date(Long.MAX_VALUE)); 279 final Credentials credentials = GoogleCredentials.create(token); 280 281 GoogleAuthLibraryCallCredentials callCredentials = 282 new GoogleAuthLibraryCallCredentials(credentials); 283 callCredentials.applyRequestMetadata( 284 new RequestInfoImpl(SecurityLevel.PRIVACY_AND_INTEGRITY), executor, applier); 285 runPendingRunnables(); 286 287 verify(applier).apply(headersCaptor.capture()); 288 Metadata headers = headersCaptor.getValue(); 289 Iterable<String> authorization = headers.getAll(AUTHORIZATION); 290 assertArrayEquals(new String[]{"Bearer allyourbase"}, 291 Iterables.toArray(authorization, String.class)); 292 } 293 294 @Test googleCredential_integrityDenied()295 public void googleCredential_integrityDenied() { 296 final AccessToken token = new AccessToken("allyourbase", new Date(Long.MAX_VALUE)); 297 final Credentials credentials = GoogleCredentials.create(token); 298 // Anything less than PRIVACY_AND_INTEGRITY should fail 299 300 GoogleAuthLibraryCallCredentials callCredentials = 301 new GoogleAuthLibraryCallCredentials(credentials); 302 callCredentials.applyRequestMetadata( 303 new RequestInfoImpl(SecurityLevel.INTEGRITY), executor, applier); 304 runPendingRunnables(); 305 306 verify(applier).fail(statusCaptor.capture()); 307 Status status = statusCaptor.getValue(); 308 assertEquals(Status.Code.UNAUTHENTICATED, status.getCode()); 309 } 310 311 @Test serviceUri()312 public void serviceUri() throws Exception { 313 GoogleAuthLibraryCallCredentials callCredentials = 314 new GoogleAuthLibraryCallCredentials(credentials); 315 callCredentials.applyRequestMetadata( 316 new RequestInfoImpl("example.com:443"), executor, applier); 317 verify(credentials).getRequestMetadata(eq(new URI("https://example.com/a.service"))); 318 319 callCredentials.applyRequestMetadata( 320 new RequestInfoImpl("example.com:123"), executor, applier); 321 verify(credentials).getRequestMetadata(eq(new URI("https://example.com:123/a.service"))); 322 } 323 324 @Test serviceAccountToJwt()325 public void serviceAccountToJwt() throws Exception { 326 KeyPair pair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); 327 328 HttpTransportFactory factory = Mockito.mock(HttpTransportFactory.class); 329 Mockito.when(factory.create()).thenThrow(new AssertionError()); 330 331 ServiceAccountCredentials credentials = 332 ServiceAccountCredentials.newBuilder() 333 .setClientEmail("[email protected]") 334 .setPrivateKey(pair.getPrivate()) 335 .setPrivateKeyId("test-private-key-id") 336 .setHttpTransportFactory(factory) 337 .build(); 338 339 GoogleAuthLibraryCallCredentials callCredentials = 340 new GoogleAuthLibraryCallCredentials(credentials); 341 callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier); 342 assertEquals(0, runPendingRunnables()); 343 344 verify(applier).apply(headersCaptor.capture()); 345 Metadata headers = headersCaptor.getValue(); 346 String[] authorization = Iterables.toArray(headers.getAll(AUTHORIZATION), String.class); 347 assertEquals(1, authorization.length); 348 assertTrue(authorization[0], authorization[0].startsWith("Bearer ")); 349 // JWT is reasonably long. Normal tokens aren't. 350 assertTrue(authorization[0], authorization[0].length() > 300); 351 } 352 353 @Test oauthClassesNotInClassPath()354 public void oauthClassesNotInClassPath() throws Exception { 355 ListMultimap<String, String> values = LinkedListMultimap.create(); 356 values.put("Authorization", "token1"); 357 when(credentials.getRequestMetadata(eq(expectedUri))).thenReturn(Multimaps.asMap(values)); 358 359 assertNull(GoogleAuthLibraryCallCredentials.createJwtHelperOrNull(null)); 360 GoogleAuthLibraryCallCredentials callCredentials = 361 new GoogleAuthLibraryCallCredentials(credentials, null); 362 callCredentials.applyRequestMetadata(new RequestInfoImpl(), executor, applier); 363 364 verify(credentials).getRequestMetadata(eq(expectedUri)); 365 verify(applier).apply(headersCaptor.capture()); 366 Metadata headers = headersCaptor.getValue(); 367 Iterable<String> authorization = headers.getAll(AUTHORIZATION); 368 assertArrayEquals(new String[]{"token1"}, 369 Iterables.toArray(authorization, String.class)); 370 } 371 372 @Test jwtAccessCredentialsInRequestMetadata()373 public void jwtAccessCredentialsInRequestMetadata() throws Exception { 374 KeyPair pair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); 375 376 ServiceAccountCredentials credentials = 377 ServiceAccountCredentials.newBuilder() 378 .setClientId("test-client") 379 .setClientEmail("[email protected]") 380 .setPrivateKey(pair.getPrivate()) 381 .setPrivateKeyId("test-private-key-id") 382 .setQuotaProjectId("test-quota-project-id") 383 .build(); 384 GoogleAuthLibraryCallCredentials callCredentials = 385 new GoogleAuthLibraryCallCredentials(credentials); 386 callCredentials.applyRequestMetadata(new RequestInfoImpl("example.com:123"), executor, applier); 387 388 verify(applier).apply(headersCaptor.capture()); 389 Metadata headers = headersCaptor.getValue(); 390 String token = 391 Iterables.getOnlyElement(headers.getAll(AUTHORIZATION)).substring("Bearer ".length()); 392 String[] parts = token.split("\\.", 3); 393 String jsonHeader = new String(BaseEncoding.base64Url().decode(parts[0]), US_ASCII); 394 String jsonPayload = new String(BaseEncoding.base64Url().decode(parts[1]), US_ASCII); 395 Map<?, ?> header = (Map<?, ?>) JsonParser.parse(jsonHeader); 396 assertEquals("test-private-key-id", header.get("kid")); 397 Map<?, ?> payload = (Map<?, ?>) JsonParser.parse(jsonPayload); 398 // google-auth-library-java 0.25.2 began stripping the grpc service name from the audience. 399 // Allow tests to pass with both the old and new versions for a while to avoid an atomic upgrade 400 // everywhere google-auth-library-java is used. 401 assertTrue("https://example.com/".equals(payload.get("aud")) 402 || "https://example.com:123/a.service".equals(payload.get("aud"))); 403 assertEquals("[email protected]", payload.get("iss")); 404 assertEquals("[email protected]", payload.get("sub")); 405 406 Metadata.Key<String> quotaProject = Metadata.Key 407 .of("X-Goog-User-Project", Metadata.ASCII_STRING_MARSHALLER); 408 assertEquals("test-quota-project-id", Iterables.getOnlyElement(headers.getAll(quotaProject))); 409 } 410 runPendingRunnables()411 private int runPendingRunnables() { 412 ArrayList<Runnable> savedPendingRunnables = pendingRunnables; 413 pendingRunnables = new ArrayList<>(); 414 for (Runnable r : savedPendingRunnables) { 415 r.run(); 416 } 417 return savedPendingRunnables.size(); 418 } 419 420 private final class RequestInfoImpl extends CallCredentials.RequestInfo { 421 final String authority; 422 final SecurityLevel securityLevel; 423 RequestInfoImpl()424 RequestInfoImpl() { 425 this(AUTHORITY, SECURITY_LEVEL); 426 } 427 RequestInfoImpl(SecurityLevel securityLevel)428 RequestInfoImpl(SecurityLevel securityLevel) { 429 this(AUTHORITY, securityLevel); 430 } 431 RequestInfoImpl(String authority)432 RequestInfoImpl(String authority) { 433 this(authority, SECURITY_LEVEL); 434 } 435 RequestInfoImpl(String authority, SecurityLevel securityLevel)436 RequestInfoImpl(String authority, SecurityLevel securityLevel) { 437 this.authority = authority; 438 this.securityLevel = securityLevel; 439 } 440 441 @Override getMethodDescriptor()442 public MethodDescriptor<?, ?> getMethodDescriptor() { 443 return method; 444 } 445 446 @Override getSecurityLevel()447 public SecurityLevel getSecurityLevel() { 448 return securityLevel; 449 } 450 451 @Override getAuthority()452 public String getAuthority() { 453 return authority; 454 } 455 456 @Override getTransportAttrs()457 public Attributes getTransportAttrs() { 458 return Attributes.EMPTY; 459 } 460 } 461 } 462