1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding;
17 
18 import static java.util.Arrays.copyOf;
19 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
20 import static org.junit.jupiter.api.Assertions.assertEquals;
21 import static software.amazon.awssdk.http.auth.aws.internal.signer.V4CanonicalRequest.getCanonicalHeadersString;
22 import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils.deriveSigningKey;
23 import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils.hash;
24 import static software.amazon.awssdk.utils.BinaryUtils.toHex;
25 
26 import java.io.ByteArrayInputStream;
27 import java.io.ByteArrayOutputStream;
28 import java.io.IOException;
29 import java.io.InputStream;
30 import java.nio.charset.StandardCharsets;
31 import java.time.Instant;
32 import java.util.Arrays;
33 import java.util.Collections;
34 import java.util.List;
35 import java.util.function.Function;
36 import org.junit.jupiter.api.Test;
37 import org.junit.jupiter.params.ParameterizedTest;
38 import org.junit.jupiter.params.provider.ValueSource;
39 import software.amazon.awssdk.http.auth.aws.internal.signer.CredentialScope;
40 import software.amazon.awssdk.http.auth.aws.internal.signer.RollingSigner;
41 import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant;
42 import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
43 import software.amazon.awssdk.utils.Pair;
44 
45 public class ChunkedEncodedInputStreamTest {
46 
47     @Test
ChunkEncodedInputStream_withBasicParams_returnsEncodedChunks()48     public void ChunkEncodedInputStream_withBasicParams_returnsEncodedChunks() throws IOException {
49         byte[] data = "abcdefghij".getBytes();
50         InputStream payload = new ByteArrayInputStream(data);
51         int chunkSize = 3;
52 
53         ChunkedEncodedInputStream inputStream = ChunkedEncodedInputStream
54             .builder()
55             .inputStream(payload)
56             .chunkSize(chunkSize)
57             .header(chunk -> Integer.toHexString(chunk.remaining()).getBytes())
58             .build();
59 
60         byte[] tmp = new byte[64];
61         int bytesRead = readAll(inputStream, tmp);
62 
63         int expectedBytesRead = 35;
64         byte[] expected = new byte[expectedBytesRead];
65         System.arraycopy(
66             "3\r\nabc\r\n3\r\ndef\r\n3\r\nghi\r\n1\r\nj\r\n0\r\n\r\n".getBytes(),
67             0,
68             expected,
69             0,
70             expectedBytesRead
71         );
72         byte[] actual = copyOf(tmp, bytesRead);
73 
74         assertEquals(expectedBytesRead, bytesRead);
75         assertArrayEquals(expected, actual);
76     }
77 
78     @Test
ChunkEncodedInputStream_withExtensions_returnsEncodedExtendedChunks()79     public void ChunkEncodedInputStream_withExtensions_returnsEncodedExtendedChunks() throws IOException {
80         byte[] data = "abcdefghij".getBytes();
81         InputStream payload = new ByteArrayInputStream(data);
82         int chunkSize = 3;
83 
84         ChunkExtensionProvider helloWorldExt = chunk -> Pair.of(
85             "hello".getBytes(StandardCharsets.UTF_8),
86             "world!".getBytes(StandardCharsets.UTF_8)
87         );
88 
89         ChunkedEncodedInputStream inputStream = ChunkedEncodedInputStream
90             .builder()
91             .inputStream(payload)
92             .chunkSize(chunkSize)
93             .header(chunk -> Integer.toHexString(chunk.remaining()).getBytes())
94             .extensions(Collections.singletonList(helloWorldExt))
95             .build();
96 
97         byte[] tmp = new byte[128];
98         int bytesRead = readAll(inputStream, tmp);
99 
100         int expectedBytesRead = 100;
101         byte[] expected = new byte[expectedBytesRead];
102         System.arraycopy(
103             ("3;hello=world!\r\nabc\r\n3;hello=world!\r\ndef\r\n3;hello=world!\r\nghi\r\n"
104              + "1;hello=world!\r\nj\r\n0;hello=world!\r\n\r\n").getBytes(),
105             0,
106             expected,
107             0,
108             expectedBytesRead
109         );
110         byte[] actual = copyOf(tmp, expected.length);
111 
112         assertEquals(expectedBytesRead, bytesRead);
113         assertArrayEquals(expected, actual);
114     }
115 
116     @Test
ChunkEncodedInputStream_withTrailers_returnsEncodedChunksAndTrailerChunk()117     public void ChunkEncodedInputStream_withTrailers_returnsEncodedChunksAndTrailerChunk() throws IOException {
118         byte[] data = "abcdefghij".getBytes();
119         InputStream payload = new ByteArrayInputStream(data);
120         int chunkSize = 3;
121 
122         TrailerProvider helloWorldTrailer = () -> Pair.of(
123             "hello",
124             Collections.singletonList("world!")
125         );
126 
127         ChunkedEncodedInputStream inputStream = ChunkedEncodedInputStream
128             .builder()
129             .inputStream(payload)
130             .chunkSize(chunkSize)
131             .header(chunk -> Integer.toHexString(chunk.remaining()).getBytes())
132             .trailers(Collections.singletonList(helloWorldTrailer))
133             .build();
134 
135         byte[] tmp = new byte[64];
136         int bytesRead = readAll(inputStream, tmp);
137 
138         int expectedBytesRead = 49;
139         byte[] expected = new byte[expectedBytesRead];
140         System.arraycopy(
141             "3\r\nabc\r\n3\r\ndef\r\n3\r\nghi\r\n1\r\nj\r\n0\r\nhello:world!\r\n\r\n".getBytes(),
142             0,
143             expected,
144             0,
145             expectedBytesRead
146         );
147         byte[] actual = copyOf(tmp, expected.length);
148 
149         assertEquals(expectedBytesRead, bytesRead);
150         assertArrayEquals(expected, actual);
151     }
152 
153     @Test
ChunkEncodedInputStream_withExtensionsAndTrailers_EncodedExtendedChunksAndTrailerChunk()154     public void ChunkEncodedInputStream_withExtensionsAndTrailers_EncodedExtendedChunksAndTrailerChunk() throws IOException {
155         byte[] data = "abcdefghij".getBytes();
156         InputStream payload = new ByteArrayInputStream(data);
157         int chunkSize = 3;
158 
159         ChunkExtensionProvider aExt = chunk -> Pair.of("a".getBytes(StandardCharsets.UTF_8),
160                                                        "1".getBytes(StandardCharsets.UTF_8));
161         ChunkExtensionProvider bExt = chunk -> Pair.of("b".getBytes(StandardCharsets.UTF_8),
162                                                        "2".getBytes(StandardCharsets.UTF_8));
163 
164         TrailerProvider aTrailer = () -> Pair.of("a", Collections.singletonList("1"));
165         TrailerProvider bTrailer = () -> Pair.of("b", Collections.singletonList("2"));
166 
167         ChunkedEncodedInputStream inputStream = ChunkedEncodedInputStream
168             .builder()
169             .inputStream(payload)
170             .chunkSize(chunkSize)
171             .header(chunk -> Integer.toHexString(chunk.remaining()).getBytes())
172             .addExtension(aExt)
173             .addExtension(bExt)
174             .addTrailer(aTrailer)
175             .addTrailer(bTrailer)
176             .build();
177 
178         byte[] tmp = new byte[128];
179         int bytesRead = readAll(inputStream, tmp);
180 
181         int expectedBytesRead = 85;
182         byte[] expected = new byte[expectedBytesRead];
183         System.arraycopy(
184             "3;a=1;b=2\r\nabc\r\n3;a=1;b=2\r\ndef\r\n3;a=1;b=2\r\nghi\r\n1;a=1;b=2\r\nj\r\n0;a=1;b=2\r\na:1\r\nb:2\r\n\r\n".getBytes(),
185             0,
186             expected,
187             0,
188             expectedBytesRead
189         );
190         byte[] actual = copyOf(tmp, expected.length);
191 
192         assertEquals(expectedBytesRead, bytesRead);
193         assertArrayEquals(expected, actual);
194     }
195 
196     @Test
ChunkEncodedInputStream_withAwsParams_returnsAwsSignedAndEncodedChunks()197     public void ChunkEncodedInputStream_withAwsParams_returnsAwsSignedAndEncodedChunks() throws IOException {
198         byte[] data = new byte[65 * 1024];
199         Arrays.fill(data, (byte) 'a');
200         String seedSignature = "106e2a8a18243abcf37539882f36619c00e2dfc72633413f02d3b74544bfeb8e";
201         CredentialScope credentialScope =
202             new CredentialScope("us-east-1", "s3", Instant.parse("2013-05-24T00:00:00Z"));
203         AwsCredentialsIdentity credentials =
204             AwsCredentialsIdentity.create("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY");
205         byte[] signingKey = deriveSigningKey(credentials, credentialScope);
206         InputStream payload = new ByteArrayInputStream(data);
207         int chunkSize = 64 * 1024;
208 
209         RollingSigner signer = new RollingSigner(signingKey, seedSignature);
210 
211         ChunkExtensionProvider ext = chunk -> Pair.of(
212             "chunk-signature".getBytes(StandardCharsets.UTF_8),
213             signer.sign(previousSignature ->
214                             "AWS4-HMAC-SHA256-PAYLOAD" + SignerConstant.LINE_SEPARATOR +
215                             credentialScope.getDatetime() + SignerConstant.LINE_SEPARATOR +
216                             credentialScope.scope() + SignerConstant.LINE_SEPARATOR +
217                             previousSignature + SignerConstant.LINE_SEPARATOR +
218                             toHex(hash("")) + SignerConstant.LINE_SEPARATOR +
219                             toHex(hash(chunk)))
220                   .getBytes(StandardCharsets.UTF_8)
221         );
222 
223         TrailerProvider checksumTrailer = () -> Pair.of(
224             "x-amz-checksum-crc32c",
225             Collections.singletonList("wdBDMA==")
226         );
227 
228         List<Pair<String, List<String>>> trailers = Collections.singletonList(checksumTrailer.get());
229         Function<String, String> template =
230             previousSignature ->
231                 "AWS4-HMAC-SHA256-TRAILER" + SignerConstant.LINE_SEPARATOR +
232                 credentialScope.getDatetime() + SignerConstant.LINE_SEPARATOR +
233                 credentialScope.scope() + SignerConstant.LINE_SEPARATOR +
234                 previousSignature + SignerConstant.LINE_SEPARATOR +
235                 toHex(hash(getCanonicalHeadersString(trailers)));
236 
237         TrailerProvider signatureTrailer = () -> Pair.of(
238             "x-amz-trailer-signature",
239             Collections.singletonList(signer.sign(template))
240         );
241 
242         ChunkedEncodedInputStream inputStream = ChunkedEncodedInputStream
243             .builder()
244             .inputStream(payload)
245             .chunkSize(chunkSize)
246             .header(chunk -> Integer.toHexString(chunk.remaining()).getBytes())
247             .extensions(Collections.singletonList(ext))
248             .trailers(Arrays.asList(checksumTrailer, signatureTrailer))
249             .build();
250 
251         byte[] tmp = new byte[chunkSize * 4];
252         int bytesRead = readAll(inputStream, tmp);
253 
254         int expectedBytesRead = 66946;
255         byte[] actualBytes = copyOf(tmp, expectedBytesRead);
256         ByteArrayOutputStream expected = new ByteArrayOutputStream();
257         expected.write(
258             "10000;chunk-signature=b474d8862b1487a5145d686f57f013e54db672cee1c953b3010fb58501ef5aa2\r\n".getBytes(
259                 StandardCharsets.UTF_8)
260         );
261         expected.write(data, 0, chunkSize);
262         expected.write(
263             "\r\n400;chunk-signature=1c1344b170168f8e65b41376b44b20fe354e373826ccbbe2c1d40a8cae51e5c7\r\n".getBytes(
264                 StandardCharsets.UTF_8)
265         );
266         expected.write(data, chunkSize, 1024);
267         expected.write(
268             "\r\n0;chunk-signature=2ca2aba2005185cf7159c6277faf83795951dd77a3a99e6e65d5c9f85863f992\r\n".getBytes(
269                 StandardCharsets.UTF_8)
270         );
271         expected.write((
272                            "x-amz-checksum-crc32c:wdBDMA==\r\n" +
273                            "x-amz-trailer-signature:ce306fa4cdf73aa89071b78358f0d22ea79c43117314c8ed68017f7d6f91048e\r\n" +
274                            "\r\n").getBytes(StandardCharsets.UTF_8)
275         );
276 
277         assertArrayEquals(expected.toByteArray(), actualBytes);
278         assertEquals(expectedBytesRead, bytesRead);
279     }
280 
281     @ParameterizedTest
282     @ValueSource(ints = {1, 2, 3, 5, 8, 13, 21, 24, 45, 69, 104})
ChunkEncodedInputStream_withVariableChunkSize_shouldCorrectlyChunkData(int chunkSize)283     void ChunkEncodedInputStream_withVariableChunkSize_shouldCorrectlyChunkData(int chunkSize) throws IOException {
284         int size = 100;
285         byte[] data = new byte[size];
286         Arrays.fill(data, (byte) 'a');
287 
288         ChunkedEncodedInputStream inputStream = ChunkedEncodedInputStream
289             .builder()
290             .inputStream(new ByteArrayInputStream(data))
291             .header(chunk -> new byte[] {'0'})
292             .chunkSize(chunkSize)
293             .build();
294 
295         int expectedBytesRead = 0;
296         int numChunks = size / chunkSize;
297 
298         // 0\r\n<data>\r\n
299         expectedBytesRead += numChunks * (5 + chunkSize);
300 
301         if (size % chunkSize != 0) {
302             // 0\r\n\<left-over>\r\n
303             expectedBytesRead += 5 + (size % chunkSize);
304         }
305 
306         // 0\r\n\r\n
307         expectedBytesRead += 5;
308 
309         byte[] tmp = new byte[expectedBytesRead];
310         int bytesRead = readAll(inputStream, tmp);
311 
312         assertEquals(expectedBytesRead, bytesRead);
313     }
314 
readAll(InputStream src, byte[] dst)315     private int readAll(InputStream src, byte[] dst) throws IOException {
316         int read = 0;
317         int offset = 0;
318         while (read >= 0) {
319             read = src.read(dst, offset, dst.length - offset);
320             if (read >= 0) {
321                 offset += read;
322             }
323         }
324         return offset;
325     }
326 }
327