1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.awssdk.services.s3.internal.s3express;
17 
18 import static org.assertj.core.api.Assertions.assertThat;
19 import static org.assertj.core.api.Assertions.assertThatThrownBy;
20 import static org.mockito.ArgumentMatchers.any;
21 import static org.mockito.Mockito.mock;
22 import static org.mockito.Mockito.times;
23 import static org.mockito.Mockito.verify;
24 import static org.mockito.Mockito.when;
25 
26 import java.time.Duration;
27 import java.time.Instant;
28 import java.util.ArrayList;
29 import java.util.List;
30 import java.util.function.Function;
31 import org.junit.jupiter.api.BeforeEach;
32 import org.junit.jupiter.api.Test;
33 import org.junit.jupiter.api.extension.ExtendWith;
34 import org.mockito.Mock;
35 import org.mockito.invocation.InvocationOnMock;
36 import org.mockito.junit.jupiter.MockitoExtension;
37 import org.mockito.stubbing.Answer;
38 import software.amazon.awssdk.core.SdkClient;
39 import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
40 import software.amazon.awssdk.services.s3.model.CreateSessionResponse;
41 import software.amazon.awssdk.services.s3.model.SessionCredentials;
42 
43 
44 @ExtendWith(MockitoExtension.class)
45 class S3ExpressCacheRefreshTest {
46 
47     private static final String ACCESS_KEY = "accessKeyId";
48     private static final String SECRET_KEY = "secretAccessKey";
49     private static final String SESSION_TOKEN = "sessionToken";
50     private static final S3ExpressIdentityKey KEY = S3ExpressIdentityKey.builder()
51                                                                         .bucket("Bucket-1")
52                                                                         .client(mock(SdkClient.class))
53                                                                         .identity(mock(AwsCredentialsIdentity.class))
54                                                                         .build();
55 
56     @Mock
57     Function<S3ExpressIdentityKey, SessionCredentials> identitySupplier;
58 
59     @BeforeEach
methodSetup()60     public void methodSetup() {
61 
62     }
63 
64     @Test
when_supplierIsAccessedMultipleTimesWithinExpirationTime_NoExtraCallsAreMade()65     void when_supplierIsAccessedMultipleTimesWithinExpirationTime_NoExtraCallsAreMade() {
66         when(identitySupplier.apply(any())).thenAnswer(invocation -> {
67             CreateSessionResponse sessionResponse = createSessionResponse(1, Instant.now().plus(Duration.ofSeconds(5)));
68             return sessionResponse.credentials();
69         });
70 
71         CachedS3ExpressCredentials cache = CachedS3ExpressCredentials.builder(identitySupplier)
72                                                                      .key(KEY)
73                                                                      .prefetchTime(Duration.ofSeconds(1))
74                                                                      .staleTime(Duration.ofMillis(200))
75                                                                      .build();
76 
77         SessionCredentials credentials;
78         credentials = cache.get();
79         credentials = cache.get();
80 
81         verify(identitySupplier, times(1)).apply(KEY);
82         verifyCredentialsInSequence(credentials, 1);
83     }
84 
85     @Test
when_credentialsReachPrefetchRange_credentialsAreRefreshed()86     void when_credentialsReachPrefetchRange_credentialsAreRefreshed() throws InterruptedException {
87         when(identitySupplier.apply(KEY)).thenAnswer(new Answer<SessionCredentials>() {
88             private int i = 0;
89             @Override
90             public SessionCredentials answer(InvocationOnMock invocation) {
91                 i++;
92                 CreateSessionResponse sessionResponse = createSessionResponse(i, Instant.now().plus(Duration.ofSeconds(10)));
93                 return sessionResponse.credentials();
94             }
95         });
96 
97         CachedS3ExpressCredentials cache = CachedS3ExpressCredentials.builder(identitySupplier)
98                                                                          .key(KEY)
99                                                                          .prefetchTime(Duration.ofSeconds(2))
100                                                                          .staleTime(Duration.ofMillis(100))
101                                                                          .build();
102 
103 
104         List<SessionCredentials> sessionCredentials = new ArrayList<>();
105         sessionCredentials.add(cache.get());
106         Thread.sleep(1 * 1000);
107         sessionCredentials.add(cache.get());
108         Thread.sleep(10 * 1000);
109         sessionCredentials.add(cache.get());
110         Thread.sleep(2 * 1000);
111         sessionCredentials.add(cache.get());
112 
113         verify(identitySupplier, times(2)).apply(KEY);
114 
115         String firstCredentialAccessKey = stringWithSequenceNumber(ACCESS_KEY, 1);
116         String secondCredentialAccessKey = stringWithSequenceNumber(ACCESS_KEY, 2);
117         assertThat(sessionCredentials.get(0).accessKeyId()).isEqualTo(firstCredentialAccessKey);
118         assertThat(sessionCredentials.get(1).accessKeyId()).isEqualTo(firstCredentialAccessKey);
119         assertThat(sessionCredentials.get(2).accessKeyId()).isEqualTo(secondCredentialAccessKey);
120         assertThat(sessionCredentials.get(3).accessKeyId()).isEqualTo(secondCredentialAccessKey);
121     }
122 
123     @Test
credentials_getRefreshedMultipleTimes()124     void credentials_getRefreshedMultipleTimes() throws InterruptedException {
125         when(identitySupplier.apply(KEY)).thenAnswer(new Answer<SessionCredentials>() {
126             private int sequenceNumber = 0;
127             @Override
128             public SessionCredentials answer(InvocationOnMock invocation) {
129                 sequenceNumber++;
130                 CreateSessionResponse sessionResponse = createSessionResponse(sequenceNumber,
131                                                                               Instant.now().plus(Duration.ofSeconds(1)));
132                 return sessionResponse.credentials();
133             }
134         });
135         CachedS3ExpressCredentials cache = CachedS3ExpressCredentials.builder(identitySupplier)
136                                                                      .key(KEY)
137                                                                      .prefetchTime(Duration.ofMillis(200))
138                                                                      .staleTime(Duration.ofMillis(40))
139                                                                      .build();
140         SessionCredentials sessionCredentials = null;
141         int numGets = 20;
142         int minimumRefreshesExpectedWithMargin = 15;;
143 
144         for (int i = 0; i < numGets; i++) {
145             sessionCredentials = cache.get();
146             Thread.sleep(1000);
147         }
148         assertThat(sessionCredentials).isNotNull();
149         int responseSequenceNumber = parseSequenceNumber(sessionCredentials.accessKeyId());
150         assertThat(responseSequenceNumber).isGreaterThan(minimumRefreshesExpectedWithMargin);
151     }
152 
153     @Test
when_supplierThrowsException_ExceptionIsPropagated()154     void when_supplierThrowsException_ExceptionIsPropagated() {
155         when(identitySupplier.apply(any())).thenAnswer(invocation -> {
156             throw new RuntimeException("Oops");
157         });
158 
159         CachedS3ExpressCredentials cache = CachedS3ExpressCredentials.builder(identitySupplier).key(KEY).build();
160 
161         assertThatThrownBy(() -> cache.get()).hasMessage("Oops");
162     }
163 
verifyCredentialsInSequence(SessionCredentials actualCredentials, int sequenceNumber)164     private void verifyCredentialsInSequence(SessionCredentials actualCredentials, int sequenceNumber) {
165         assertThat(actualCredentials.accessKeyId()).isEqualTo(stringWithSequenceNumber(ACCESS_KEY, sequenceNumber));
166         assertThat(actualCredentials.secretAccessKey()).isEqualTo(stringWithSequenceNumber(SECRET_KEY, sequenceNumber));
167         assertThat(actualCredentials.sessionToken()).isEqualTo(stringWithSequenceNumber(SESSION_TOKEN, sequenceNumber));
168     }
169 
createSessionResponse(int sequenceNumber, Instant expires)170     private static CreateSessionResponse createSessionResponse(int sequenceNumber, Instant expires) {
171         return CreateSessionResponse.builder()
172                                     .credentials(SessionCredentials.builder()
173                                                                   .accessKeyId(stringWithSequenceNumber(ACCESS_KEY, sequenceNumber))
174                                                                   .secretAccessKey(stringWithSequenceNumber(SECRET_KEY, sequenceNumber))
175                                                                   .sessionToken(stringWithSequenceNumber(SESSION_TOKEN, sequenceNumber))
176                                                                   .expiration(expires)
177                                                                   .build())
178                                     .build();
179     }
180 
stringWithSequenceNumber(String value, int sequenceNumber)181     private static String stringWithSequenceNumber(String value, int sequenceNumber) {
182         return value + "_" + sequenceNumber;
183     }
184 
parseSequenceNumber(String s)185     private Integer parseSequenceNumber(String s) {
186         int sequenceNumIndex = s.lastIndexOf('_');
187         String substring = s.substring(sequenceNumIndex + 1);
188         return Integer.parseInt(substring);
189     }
190 }