/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package software.amazon.awssdk.services.s3.internal.s3express;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.stubbing.Answer;
import software.amazon.awssdk.core.SdkClient;
import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
import software.amazon.awssdk.services.s3.model.CreateSessionResponse;
import software.amazon.awssdk.services.s3.model.SessionCredentials;


/**
 * Exercises how {@link CachedS3ExpressCredentials} invokes its credential supplier over time.
 *
 * <p>Every credential handed out by the mocked supplier carries a sequence number as a suffix of its access key,
 * secret key and session token (for example {@code accessKeyId_2}), so a test can tell which supplier invocation
 * produced the credentials it received from the cache.
 *
 * <p>NOTE: these tests use real wall-clock time ({@link Instant#now()} and {@link Thread#sleep(long)}); the two
 * refresh tests sleep for roughly 13 and 20 seconds respectively.
 */
@ExtendWith(MockitoExtension.class)
class S3ExpressCacheRefreshTest {

    private static final String ACCESS_KEY = "accessKeyId";
    private static final String SECRET_KEY = "secretAccessKey";
    private static final String SESSION_TOKEN = "sessionToken";
    private static final S3ExpressIdentityKey KEY = S3ExpressIdentityKey.builder()
                                                                        .bucket("Bucket-1")
                                                                        .client(mock(SdkClient.class))
                                                                        .identity(mock(AwsCredentialsIdentity.class))
                                                                        .build();

    @Mock
    private Function<S3ExpressIdentityKey, SessionCredentials> identitySupplier;

    @BeforeEach
    public void methodSetup() {
        // Intentionally empty: each test stubs the supplier itself.
    }

    /**
     * Two consecutive reads while the credentials are still far from expiring must hit the supplier only once.
     */
    @Test
    void when_supplierIsAccessedMultipleTimesWithinExpirationTime_NoExtraCallsAreMade() {
        when(identitySupplier.apply(any())).thenAnswer(invocation -> {
            CreateSessionResponse sessionResponse = createSessionResponse(1, Instant.now().plus(Duration.ofSeconds(5)));
            return sessionResponse.credentials();
        });

        CachedS3ExpressCredentials cache = CachedS3ExpressCredentials.builder(identitySupplier)
                                                                     .key(KEY)
                                                                     .prefetchTime(Duration.ofSeconds(1))
                                                                     .staleTime(Duration.ofMillis(200))
                                                                     .build();

        // The first read populates the cache; only the result of the second read is inspected.
        cache.get();
        SessionCredentials credentials = cache.get();

        verify(identitySupplier, times(1)).apply(KEY);
        verifyCredentialsInSequence(credentials, 1);
    }

    /**
     * Reads the cache before and after the first credentials (valid for 10 seconds) have run out and expects
     * exactly one additional supplier call.
     */
    @Test
    void when_credentialsReachPrefetchRange_credentialsAreRefreshed() throws InterruptedException {
        when(identitySupplier.apply(KEY)).thenAnswer(sequencedCredentials(Duration.ofSeconds(10)));

        CachedS3ExpressCredentials cache = CachedS3ExpressCredentials.builder(identitySupplier)
                                                                     .key(KEY)
                                                                     .prefetchTime(Duration.ofSeconds(2))
                                                                     .staleTime(Duration.ofMillis(100))
                                                                     .build();

        List<SessionCredentials> sessionCredentials = new ArrayList<>();
        sessionCredentials.add(cache.get());
        // ~1s elapsed: still well inside the 10s lifetime of the first credentials.
        Thread.sleep(1 * 1000);
        sessionCredentials.add(cache.get());
        // ~11s elapsed: past the expiration of the first credentials, so a refresh is expected here.
        Thread.sleep(10 * 1000);
        sessionCredentials.add(cache.get());
        // ~2s after the refresh: the second credentials are expected to still be served.
        Thread.sleep(2 * 1000);
        sessionCredentials.add(cache.get());

        verify(identitySupplier, times(2)).apply(KEY);

        String firstCredentialAccessKey = stringWithSequenceNumber(ACCESS_KEY, 1);
        String secondCredentialAccessKey = stringWithSequenceNumber(ACCESS_KEY, 2);
        assertThat(sessionCredentials.get(0).accessKeyId()).isEqualTo(firstCredentialAccessKey);
        assertThat(sessionCredentials.get(1).accessKeyId()).isEqualTo(firstCredentialAccessKey);
        assertThat(sessionCredentials.get(2).accessKeyId()).isEqualTo(secondCredentialAccessKey);
        assertThat(sessionCredentials.get(3).accessKeyId()).isEqualTo(secondCredentialAccessKey);
    }

    /**
     * Reads the cache 20 times, one second apart, with credentials that live for one second each, and expects the
     * last credentials returned to come from a supplier invocation numbered above 15.
     */
    @Test
    void credentials_getRefreshedMultipleTimes() throws InterruptedException {
        when(identitySupplier.apply(KEY)).thenAnswer(sequencedCredentials(Duration.ofSeconds(1)));

        CachedS3ExpressCredentials cache = CachedS3ExpressCredentials.builder(identitySupplier)
                                                                     .key(KEY)
                                                                     .prefetchTime(Duration.ofMillis(200))
                                                                     .staleTime(Duration.ofMillis(40))
                                                                     .build();
        SessionCredentials sessionCredentials = null;
        int numGets = 20;
        // Lower than numGets on purpose, leaving a margin for timing jitter.
        int minimumRefreshesExpectedWithMargin = 15;

        for (int i = 0; i < numGets; i++) {
            sessionCredentials = cache.get();
            Thread.sleep(1000);
        }

        assertThat(sessionCredentials).isNotNull();
        int responseSequenceNumber = parseSequenceNumber(sessionCredentials.accessKeyId());
        assertThat(responseSequenceNumber).isGreaterThan(minimumRefreshesExpectedWithMargin);
    }

    /**
     * A failure raised by the supplier must surface from {@code get()} with its original message.
     */
    @Test
    void when_supplierThrowsException_ExceptionIsPropagated() {
        when(identitySupplier.apply(any())).thenAnswer(invocation -> {
            throw new RuntimeException("Oops");
        });

        CachedS3ExpressCredentials cache = CachedS3ExpressCredentials.builder(identitySupplier).key(KEY).build();

        assertThatThrownBy(cache::get).hasMessage("Oops");
    }

    /**
     * Builds a stateful Mockito answer that numbers its invocations starting at 1 and returns credentials tagged
     * with that number, expiring {@code lifetime} after the moment of the invocation.
     *
     * @param lifetime how long after each invocation the returned credentials expire
     * @return an answer producing freshly numbered credentials on every call
     */
    private static Answer<SessionCredentials> sequencedCredentials(Duration lifetime) {
        return new Answer<SessionCredentials>() {
            private int sequenceNumber = 0;

            @Override
            public SessionCredentials answer(InvocationOnMock invocation) {
                sequenceNumber++;
                CreateSessionResponse sessionResponse = createSessionResponse(sequenceNumber,
                                                                              Instant.now().plus(lifetime));
                return sessionResponse.credentials();
            }
        };
    }

    /**
     * Asserts that all three parts of {@code actualCredentials} carry the given sequence number suffix.
     *
     * @param actualCredentials the credentials returned by the cache
     * @param sequenceNumber the supplier invocation number the credentials are expected to come from
     */
    private static void verifyCredentialsInSequence(SessionCredentials actualCredentials, int sequenceNumber) {
        assertThat(actualCredentials.accessKeyId()).isEqualTo(stringWithSequenceNumber(ACCESS_KEY, sequenceNumber));
        assertThat(actualCredentials.secretAccessKey()).isEqualTo(stringWithSequenceNumber(SECRET_KEY, sequenceNumber));
        assertThat(actualCredentials.sessionToken()).isEqualTo(stringWithSequenceNumber(SESSION_TOKEN, sequenceNumber));
    }

    /**
     * Creates a {@link CreateSessionResponse} whose credentials are tagged with {@code sequenceNumber}.
     *
     * @param sequenceNumber the number appended to each credential field
     * @param expires the expiration instant to set on the credentials
     * @return the response wrapping the tagged credentials
     */
    private static CreateSessionResponse createSessionResponse(int sequenceNumber, Instant expires) {
        return CreateSessionResponse.builder()
                                    .credentials(SessionCredentials.builder()
                                                                   .accessKeyId(stringWithSequenceNumber(ACCESS_KEY, sequenceNumber))
                                                                   .secretAccessKey(stringWithSequenceNumber(SECRET_KEY, sequenceNumber))
                                                                   .sessionToken(stringWithSequenceNumber(SESSION_TOKEN, sequenceNumber))
                                                                   .expiration(expires)
                                                                   .build())
                                    .build();
    }

    /**
     * Appends {@code _<sequenceNumber>} to {@code value}.
     */
    private static String stringWithSequenceNumber(String value, int sequenceNumber) {
        return value + "_" + sequenceNumber;
    }

    /**
     * Extracts the number that follows the last underscore in {@code s}; the inverse of
     * {@link #stringWithSequenceNumber(String, int)}.
     *
     * @param s a string ending in {@code _<number>}
     * @return the parsed number
     * @throws NumberFormatException if the text after the last underscore is not an integer
     */
    private static int parseSequenceNumber(String s) {
        int sequenceNumIndex = s.lastIndexOf('_');
        String substring = s.substring(sequenceNumIndex + 1);
        return Integer.parseInt(substring);
    }
}