1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.awssdk.core.async;
17 
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.common.jimfs.Configuration;
import com.google.common.jimfs.Jimfs;
import io.reactivex.Flowable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileSystem;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.assertj.core.util.Lists;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import software.amazon.awssdk.core.internal.util.Mimetype;
import software.amazon.awssdk.http.async.SimpleSubscriber;
import software.amazon.awssdk.utils.BinaryUtils;
48 
49 public class AsyncRequestBodyTest {
50 
51     private static final String testString = "Hello!";
52     private static final Path path;
53 
54     static {
55         FileSystem fs = Jimfs.newFileSystem(Configuration.unix());
56         path = fs.getPath("./test");
57         try {
Files.write(path, testString.getBytes())58             Files.write(path, testString.getBytes());
59         } catch (IOException e) {
60             e.printStackTrace();
61         }
62     }
63 
64     @ParameterizedTest
65     @MethodSource("contentIntegrityChecks")
hasCorrectLength(AsyncRequestBody asyncRequestBody)66     void hasCorrectLength(AsyncRequestBody asyncRequestBody) {
67         assertEquals(testString.length(), asyncRequestBody.contentLength().get());
68     }
69 
70 
71     @ParameterizedTest
72     @MethodSource("contentIntegrityChecks")
hasCorrectContent(AsyncRequestBody asyncRequestBody)73     void hasCorrectContent(AsyncRequestBody asyncRequestBody) throws InterruptedException {
74         StringBuilder sb = new StringBuilder();
75         CountDownLatch done = new CountDownLatch(1);
76 
77         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(buffer -> {
78             byte[] bytes = new byte[buffer.remaining()];
79             buffer.get(bytes);
80             sb.append(new String(bytes, StandardCharsets.UTF_8));
81         }) {
82             @Override
83             public void onError(Throwable t) {
84                 super.onError(t);
85                 done.countDown();
86             }
87 
88             @Override
89             public void onComplete() {
90                 super.onComplete();
91                 done.countDown();
92             }
93         };
94 
95         asyncRequestBody.subscribe(subscriber);
96         done.await();
97         assertEquals(testString, sb.toString());
98     }
99 
contentIntegrityChecks()100     private static AsyncRequestBody[] contentIntegrityChecks() {
101         return new AsyncRequestBody[] {
102             AsyncRequestBody.fromString(testString),
103             AsyncRequestBody.fromFile(path)
104         };
105     }
106 
107     @Test
fromBytesCopiesTheProvidedByteArray()108     void fromBytesCopiesTheProvidedByteArray() {
109         byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
110         byte[] bytesClone = bytes.clone();
111 
112         AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromBytes(bytes);
113 
114         for (int i = 0; i < bytes.length; i++) {
115             bytes[i] += 1;
116         }
117 
118         AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
119         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);
120 
121         asyncRequestBody.subscribe(subscriber);
122 
123         byte[] publishedByteArray = BinaryUtils.copyAllBytesFrom(publishedBuffer.get());
124         assertArrayEquals(bytesClone, publishedByteArray);
125     }
126 
127     @Test
fromBytesUnsafeDoesNotCopyTheProvidedByteArray()128     void fromBytesUnsafeDoesNotCopyTheProvidedByteArray() {
129         byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
130 
131         AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromBytesUnsafe(bytes);
132 
133         for (int i = 0; i < bytes.length; i++) {
134             bytes[i] += 1;
135         }
136 
137         AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
138         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);
139 
140         asyncRequestBody.subscribe(subscriber);
141 
142         byte[] publishedByteArray = BinaryUtils.copyAllBytesFrom(publishedBuffer.get());
143         assertArrayEquals(bytes, publishedByteArray);
144     }
145 
146     @ParameterizedTest
147     @MethodSource("safeByteBufferBodyBuilders")
safeByteBufferBuildersCopyTheProvidedBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder)148     void safeByteBufferBuildersCopyTheProvidedBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
149         byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
150         byte[] bytesClone = bytes.clone();
151 
152         AsyncRequestBody asyncRequestBody = bodyBuilder.apply(ByteBuffer.wrap(bytes));
153 
154         for (int i = 0; i < bytes.length; i++) {
155             bytes[i] += 1;
156         }
157 
158         AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
159         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);
160 
161         asyncRequestBody.subscribe(subscriber);
162 
163         byte[] publishedByteArray = BinaryUtils.copyAllBytesFrom(publishedBuffer.get());
164         assertArrayEquals(bytesClone, publishedByteArray);
165     }
166 
safeByteBufferBodyBuilders()167     private static Function<ByteBuffer, AsyncRequestBody>[] safeByteBufferBodyBuilders() {
168         Function<ByteBuffer, AsyncRequestBody> fromByteBuffer = AsyncRequestBody::fromByteBuffer;
169         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffer = AsyncRequestBody::fromRemainingByteBuffer;
170         Function<ByteBuffer, AsyncRequestBody> fromByteBuffers = AsyncRequestBody::fromByteBuffers;
171         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffers = AsyncRequestBody::fromRemainingByteBuffers;
172         return new Function[] {fromByteBuffer, fromRemainingByteBuffer, fromByteBuffers, fromRemainingByteBuffers};
173     }
174 
175     @ParameterizedTest
176     @MethodSource("unsafeByteBufferBodyBuilders")
unsafeByteBufferBuildersDoNotCopyTheProvidedBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder)177     void unsafeByteBufferBuildersDoNotCopyTheProvidedBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
178         byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
179 
180         AsyncRequestBody asyncRequestBody = bodyBuilder.apply(ByteBuffer.wrap(bytes));
181 
182         for (int i = 0; i < bytes.length; i++) {
183             bytes[i] += 1;
184         }
185 
186         AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
187         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);
188 
189         asyncRequestBody.subscribe(subscriber);
190 
191         byte[] publishedByteArray = BinaryUtils.copyAllBytesFrom(publishedBuffer.get());
192         assertArrayEquals(bytes, publishedByteArray);
193     }
194 
unsafeByteBufferBodyBuilders()195     private static Function<ByteBuffer, AsyncRequestBody>[] unsafeByteBufferBodyBuilders() {
196         Function<ByteBuffer, AsyncRequestBody> fromByteBuffer = AsyncRequestBody::fromByteBufferUnsafe;
197         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffer = AsyncRequestBody::fromRemainingByteBufferUnsafe;
198         Function<ByteBuffer, AsyncRequestBody> fromByteBuffers = AsyncRequestBody::fromByteBuffersUnsafe;
199         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffers = AsyncRequestBody::fromRemainingByteBuffersUnsafe;
200         return new Function[] {fromByteBuffer, fromRemainingByteBuffer, fromByteBuffers, fromRemainingByteBuffers};
201     }
202 
203     @ParameterizedTest
204     @MethodSource("nonRewindingByteBufferBodyBuilders")
nonRewindingByteBufferBuildersReadFromTheInputBufferPosition( Function<ByteBuffer, AsyncRequestBody> bodyBuilder)205     void nonRewindingByteBufferBuildersReadFromTheInputBufferPosition(
206         Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
207         byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
208         ByteBuffer bb = ByteBuffer.wrap(bytes);
209         int expectedPosition = bytes.length / 2;
210         bb.position(expectedPosition);
211 
212         AsyncRequestBody asyncRequestBody = bodyBuilder.apply(bb);
213 
214         AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
215         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);
216 
217         asyncRequestBody.subscribe(subscriber);
218 
219         int remaining = bb.remaining();
220         assertEquals(remaining, publishedBuffer.get().remaining());
221         for (int i = 0; i < remaining; i++) {
222             assertEquals(bb.get(), publishedBuffer.get().get());
223         }
224     }
225 
nonRewindingByteBufferBodyBuilders()226     private static Function<ByteBuffer, AsyncRequestBody>[] nonRewindingByteBufferBodyBuilders() {
227         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffer = AsyncRequestBody::fromRemainingByteBuffer;
228         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBufferUnsafe = AsyncRequestBody::fromRemainingByteBufferUnsafe;
229         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffers = AsyncRequestBody::fromRemainingByteBuffers;
230         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffersUnsafe = AsyncRequestBody::fromRemainingByteBuffersUnsafe;
231         return new Function[] {fromRemainingByteBuffer, fromRemainingByteBufferUnsafe, fromRemainingByteBuffers,
232                                fromRemainingByteBuffersUnsafe};
233     }
234 
235     @ParameterizedTest
236     @MethodSource("safeNonRewindingByteBufferBodyBuilders")
safeNonRewindingByteBufferBuildersCopyFromTheInputBufferPosition( Function<ByteBuffer, AsyncRequestBody> bodyBuilder)237     void safeNonRewindingByteBufferBuildersCopyFromTheInputBufferPosition(
238         Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
239         byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
240         ByteBuffer bb = ByteBuffer.wrap(bytes);
241         int expectedPosition = bytes.length / 2;
242         bb.position(expectedPosition);
243 
244         AsyncRequestBody asyncRequestBody = bodyBuilder.apply(bb);
245 
246         AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
247         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);
248 
249         asyncRequestBody.subscribe(subscriber);
250 
251         int remaining = bb.remaining();
252         assertEquals(remaining, publishedBuffer.get().capacity());
253         for (int i = 0; i < remaining; i++) {
254             assertEquals(bb.get(), publishedBuffer.get().get());
255         }
256     }
257 
safeNonRewindingByteBufferBodyBuilders()258     private static Function<ByteBuffer, AsyncRequestBody>[] safeNonRewindingByteBufferBodyBuilders() {
259         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffer = AsyncRequestBody::fromRemainingByteBuffer;
260         Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffers = AsyncRequestBody::fromRemainingByteBuffers;
261         return new Function[] {fromRemainingByteBuffer, fromRemainingByteBuffers};
262     }
263 
264     @ParameterizedTest
265     @MethodSource("rewindingByteBufferBodyBuilders")
rewindingByteBufferBuildersDoNotRewindTheInputBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder)266     void rewindingByteBufferBuildersDoNotRewindTheInputBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
267         byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
268         ByteBuffer bb = ByteBuffer.wrap(bytes);
269         int expectedPosition = bytes.length / 2;
270         bb.position(expectedPosition);
271 
272         AsyncRequestBody asyncRequestBody = bodyBuilder.apply(bb);
273 
274         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(buffer -> {
275         });
276 
277         asyncRequestBody.subscribe(subscriber);
278 
279         assertEquals(expectedPosition, bb.position());
280     }
281 
282     @ParameterizedTest
283     @MethodSource("rewindingByteBufferBodyBuilders")
rewindingByteBufferBuildersReadTheInputBufferFromTheBeginning( Function<ByteBuffer, AsyncRequestBody> bodyBuilder)284     void rewindingByteBufferBuildersReadTheInputBufferFromTheBeginning(
285         Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
286         byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
287         ByteBuffer bb = ByteBuffer.wrap(bytes);
288         bb.position(bytes.length / 2);
289 
290         AsyncRequestBody asyncRequestBody = bodyBuilder.apply(bb);
291 
292         AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
293         Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);
294 
295         asyncRequestBody.subscribe(subscriber);
296 
297         assertEquals(0, publishedBuffer.get().position());
298         publishedBuffer.get().rewind();
299         bb.rewind();
300         assertEquals(bb, publishedBuffer.get());
301     }
302 
rewindingByteBufferBodyBuilders()303     private static Function<ByteBuffer, AsyncRequestBody>[] rewindingByteBufferBodyBuilders() {
304         Function<ByteBuffer, AsyncRequestBody> fromByteBuffer = AsyncRequestBody::fromByteBuffer;
305         Function<ByteBuffer, AsyncRequestBody> fromByteBufferUnsafe = AsyncRequestBody::fromByteBufferUnsafe;
306         Function<ByteBuffer, AsyncRequestBody> fromByteBuffers = AsyncRequestBody::fromByteBuffers;
307         Function<ByteBuffer, AsyncRequestBody> fromByteBuffersUnsafe = AsyncRequestBody::fromByteBuffersUnsafe;
308         return new Function[] {fromByteBuffer, fromByteBufferUnsafe, fromByteBuffers, fromByteBuffersUnsafe};
309     }
310 
311     @ParameterizedTest
312     @ValueSource(strings = {"US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16"})
charsetsAreConvertedToTheCorrectContentType(Charset charset)313     void charsetsAreConvertedToTheCorrectContentType(Charset charset) {
314         AsyncRequestBody requestBody = AsyncRequestBody.fromString("hello world", charset);
315         assertEquals("text/plain; charset=" + charset.name(), requestBody.contentType());
316     }
317 
318     @Test
stringConstructorHasCorrectDefaultContentType()319     void stringConstructorHasCorrectDefaultContentType() {
320         AsyncRequestBody requestBody = AsyncRequestBody.fromString("hello world");
321         assertEquals("text/plain; charset=UTF-8", requestBody.contentType());
322     }
323 
324     @Test
fileConstructorHasCorrectContentType()325     void fileConstructorHasCorrectContentType() {
326         AsyncRequestBody requestBody = AsyncRequestBody.fromFile(path);
327         assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
328     }
329 
330     @Test
bytesArrayConstructorHasCorrectContentType()331     void bytesArrayConstructorHasCorrectContentType() {
332         AsyncRequestBody requestBody = AsyncRequestBody.fromBytes("hello world".getBytes());
333         assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
334     }
335 
336     @Test
bytesBufferConstructorHasCorrectContentType()337     void bytesBufferConstructorHasCorrectContentType() {
338         ByteBuffer byteBuffer = ByteBuffer.wrap("hello world".getBytes());
339         AsyncRequestBody requestBody = AsyncRequestBody.fromByteBuffer(byteBuffer);
340         assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
341     }
342 
343     @Test
emptyBytesConstructorHasCorrectContentType()344     void emptyBytesConstructorHasCorrectContentType() {
345         AsyncRequestBody requestBody = AsyncRequestBody.empty();
346         assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
347     }
348 
349     @Test
publisherConstructorHasCorrectContentType()350     void publisherConstructorHasCorrectContentType() {
351         List<String> requestBodyStrings = Lists.newArrayList("A", "B", "C");
352         List<ByteBuffer> bodyBytes = requestBodyStrings.stream()
353                                                        .map(s -> ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8)))
354                                                        .collect(Collectors.toList());
355         Publisher<ByteBuffer> bodyPublisher = Flowable.fromIterable(bodyBytes);
356         AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(bodyPublisher);
357         assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
358     }
359 }
360