/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package software.amazon.awssdk.core.async;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.common.jimfs.Configuration;
import com.google.common.jimfs.Jimfs;
import io.reactivex.Flowable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileSystem;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.assertj.core.util.Lists;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import software.amazon.awssdk.core.internal.util.Mimetype;
import software.amazon.awssdk.http.async.SimpleSubscriber;
import software.amazon.awssdk.utils.BinaryUtils;

/**
 * Tests for the static factory methods of {@link AsyncRequestBody}: reported content length,
 * published content, whether the input bytes are copied, how the input {@link ByteBuffer}
 * position is treated, and the reported content type.
 *
 * <p>Most tests read the captured buffer immediately after {@code subscribe(...)} returns, so
 * they rely on the body publishing to a {@link SimpleSubscriber} before {@code subscribe}
 * returns. NOTE(review): confirm against {@code SimpleSubscriber}'s demand behavior.
 */
public class AsyncRequestBodyTest {

    /** Payload used by every content-related test in this class. */
    private static final String testString = "Hello!";

    /** File on an in-memory (Jimfs) file system that contains {@link #testString}. */
    private static final Path path;

    static {
        FileSystem fs = Jimfs.newFileSystem(Configuration.unix());
        path = fs.getPath("./test");
        try {
            // Write with an explicit charset: hasCorrectContent decodes the published bytes as UTF-8.
            Files.write(path, testString.getBytes(StandardCharsets.UTF_8));
        } catch (IOException e) {
            // Fail fast instead of printing the stack trace and continuing; otherwise the
            // file-backed tests would fail later with an error unrelated to the real cause.
            throw new IllegalStateException("Failed to write test fixture file: " + path, e);
        }
    }

    /**
     * Verifies that the body reports a content length equal to the length of {@link #testString}.
     *
     * @param asyncRequestBody body under test, supplied by {@link #contentIntegrityChecks()}
     */
    @ParameterizedTest
    @MethodSource("contentIntegrityChecks")
    void hasCorrectLength(AsyncRequestBody asyncRequestBody) {
        assertEquals(testString.length(), asyncRequestBody.contentLength().get());
    }


    /**
     * Verifies that the bytes published by the body decode (as UTF-8) to {@link #testString}.
     *
     * @param asyncRequestBody body under test, supplied by {@link #contentIntegrityChecks()}
     * @throws InterruptedException if interrupted while waiting for the subscription to finish
     */
    @ParameterizedTest
    @MethodSource("contentIntegrityChecks")
    void hasCorrectContent(AsyncRequestBody asyncRequestBody) throws InterruptedException {
        StringBuilder sb = new StringBuilder();
        CountDownLatch done = new CountDownLatch(1);

        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(buffer -> {
            byte[] bytes = new byte[buffer.remaining()];
            buffer.get(bytes);
            sb.append(new String(bytes, StandardCharsets.UTF_8));
        }) {
            // Release the latch on either terminal signal so the test never waits on a finished stream.
            @Override
            public void onError(Throwable t) {
                super.onError(t);
                done.countDown();
            }

            @Override
            public void onComplete() {
                super.onComplete();
                done.countDown();
            }
        };

        asyncRequestBody.subscribe(subscriber);
        done.await();
        assertEquals(testString, sb.toString());
    }

    /**
     * Supplies the bodies checked by {@link #hasCorrectLength} and {@link #hasCorrectContent}:
     * one built from a string and one built from the fixture file.
     */
    private static AsyncRequestBody[] contentIntegrityChecks() {
        return new AsyncRequestBody[] {
            AsyncRequestBody.fromString(testString),
            AsyncRequestBody.fromFile(path)
        };
    }

    /**
     * Verifies that mutating the source array after calling {@code fromBytes} does not change
     * the published bytes.
     */
    @Test
    void fromBytesCopiesTheProvidedByteArray() {
        byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
        byte[] bytesClone = bytes.clone();

        AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromBytes(bytes);

        // Mutate the caller's array after the body was created.
        for (int i = 0; i < bytes.length; i++) {
            bytes[i] += 1;
        }

        AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);

        asyncRequestBody.subscribe(subscriber);

        byte[] publishedByteArray = BinaryUtils.copyAllBytesFrom(publishedBuffer.get());
        assertArrayEquals(bytesClone, publishedByteArray);
    }

    /**
     * Verifies that mutating the source array after calling {@code fromBytesUnsafe} is visible
     * in the published bytes.
     */
    @Test
    void fromBytesUnsafeDoesNotCopyTheProvidedByteArray() {
        byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);

        AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromBytesUnsafe(bytes);

        // Mutate the caller's array after the body was created.
        for (int i = 0; i < bytes.length; i++) {
            bytes[i] += 1;
        }

        AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);

        asyncRequestBody.subscribe(subscriber);

        byte[] publishedByteArray = BinaryUtils.copyAllBytesFrom(publishedBuffer.get());
        assertArrayEquals(bytes, publishedByteArray);
    }

    /**
     * Verifies that mutating the array backing the input buffer after the body is built does
     * not change the published bytes.
     *
     * @param bodyBuilder factory under test, supplied by {@link #safeByteBufferBodyBuilders()}
     */
    @ParameterizedTest
    @MethodSource("safeByteBufferBodyBuilders")
    void safeByteBufferBuildersCopyTheProvidedBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
        byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
        byte[] bytesClone = bytes.clone();

        AsyncRequestBody asyncRequestBody = bodyBuilder.apply(ByteBuffer.wrap(bytes));

        // Mutate the backing array after the body was created.
        for (int i = 0; i < bytes.length; i++) {
            bytes[i] += 1;
        }

        AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);

        asyncRequestBody.subscribe(subscriber);

        byte[] publishedByteArray = BinaryUtils.copyAllBytesFrom(publishedBuffer.get());
        assertArrayEquals(bytesClone, publishedByteArray);
    }

    /** Supplies the non-{@code Unsafe} {@link ByteBuffer} factories. */
    // Java does not permit generic array creation, so a raw array is created and returned.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Function<ByteBuffer, AsyncRequestBody>[] safeByteBufferBodyBuilders() {
        Function<ByteBuffer, AsyncRequestBody> fromByteBuffer = AsyncRequestBody::fromByteBuffer;
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffer = AsyncRequestBody::fromRemainingByteBuffer;
        Function<ByteBuffer, AsyncRequestBody> fromByteBuffers = AsyncRequestBody::fromByteBuffers;
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffers = AsyncRequestBody::fromRemainingByteBuffers;
        return new Function[] {fromByteBuffer, fromRemainingByteBuffer, fromByteBuffers, fromRemainingByteBuffers};
    }

    /**
     * Verifies that mutating the array backing the input buffer after the body is built is
     * visible in the published bytes.
     *
     * @param bodyBuilder factory under test, supplied by {@link #unsafeByteBufferBodyBuilders()}
     */
    @ParameterizedTest
    @MethodSource("unsafeByteBufferBodyBuilders")
    void unsafeByteBufferBuildersDoNotCopyTheProvidedBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
        byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);

        AsyncRequestBody asyncRequestBody = bodyBuilder.apply(ByteBuffer.wrap(bytes));

        // Mutate the backing array after the body was created.
        for (int i = 0; i < bytes.length; i++) {
            bytes[i] += 1;
        }

        AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);

        asyncRequestBody.subscribe(subscriber);

        byte[] publishedByteArray = BinaryUtils.copyAllBytesFrom(publishedBuffer.get());
        assertArrayEquals(bytes, publishedByteArray);
    }

    /** Supplies the {@code Unsafe} {@link ByteBuffer} factories. */
    // Java does not permit generic array creation, so a raw array is created and returned.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Function<ByteBuffer, AsyncRequestBody>[] unsafeByteBufferBodyBuilders() {
        Function<ByteBuffer, AsyncRequestBody> fromByteBuffer = AsyncRequestBody::fromByteBufferUnsafe;
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffer = AsyncRequestBody::fromRemainingByteBufferUnsafe;
        Function<ByteBuffer, AsyncRequestBody> fromByteBuffers = AsyncRequestBody::fromByteBuffersUnsafe;
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffers = AsyncRequestBody::fromRemainingByteBuffersUnsafe;
        return new Function[] {fromByteBuffer, fromRemainingByteBuffer, fromByteBuffers, fromRemainingByteBuffers};
    }

    /**
     * Verifies that, for an input buffer positioned at its midpoint, the published buffer has
     * the same number of remaining bytes as the input and yields the same bytes.
     *
     * @param bodyBuilder factory under test, supplied by {@link #nonRewindingByteBufferBodyBuilders()}
     */
    @ParameterizedTest
    @MethodSource("nonRewindingByteBufferBodyBuilders")
    void nonRewindingByteBufferBuildersReadFromTheInputBufferPosition(
        Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
        byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
        ByteBuffer bb = ByteBuffer.wrap(bytes);
        int expectedPosition = bytes.length / 2;
        bb.position(expectedPosition);

        AsyncRequestBody asyncRequestBody = bodyBuilder.apply(bb);

        AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);

        asyncRequestBody.subscribe(subscriber);

        int remaining = bb.remaining();
        assertEquals(remaining, publishedBuffer.get().remaining());
        // Relative gets advance both buffers in lock-step from their current positions.
        for (int i = 0; i < remaining; i++) {
            assertEquals(bb.get(), publishedBuffer.get().get());
        }
    }

    /** Supplies the {@code fromRemaining*} factories, both safe and unsafe. */
    // Java does not permit generic array creation, so a raw array is created and returned.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Function<ByteBuffer, AsyncRequestBody>[] nonRewindingByteBufferBodyBuilders() {
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffer = AsyncRequestBody::fromRemainingByteBuffer;
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBufferUnsafe = AsyncRequestBody::fromRemainingByteBufferUnsafe;
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffers = AsyncRequestBody::fromRemainingByteBuffers;
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffersUnsafe = AsyncRequestBody::fromRemainingByteBuffersUnsafe;
        return new Function[] {fromRemainingByteBuffer, fromRemainingByteBufferUnsafe, fromRemainingByteBuffers,
                               fromRemainingByteBuffersUnsafe};
    }

    /**
     * Verifies that, for an input buffer positioned at its midpoint, the published buffer's
     * capacity equals the number of remaining input bytes and it yields the same bytes.
     *
     * @param bodyBuilder factory under test, supplied by
     *                    {@link #safeNonRewindingByteBufferBodyBuilders()}
     */
    @ParameterizedTest
    @MethodSource("safeNonRewindingByteBufferBodyBuilders")
    void safeNonRewindingByteBufferBuildersCopyFromTheInputBufferPosition(
        Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
        byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
        ByteBuffer bb = ByteBuffer.wrap(bytes);
        int expectedPosition = bytes.length / 2;
        bb.position(expectedPosition);

        AsyncRequestBody asyncRequestBody = bodyBuilder.apply(bb);

        AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);

        asyncRequestBody.subscribe(subscriber);

        int remaining = bb.remaining();
        // Capacity (not remaining) is checked: the published buffer must be sized to the remaining bytes only.
        assertEquals(remaining, publishedBuffer.get().capacity());
        for (int i = 0; i < remaining; i++) {
            assertEquals(bb.get(), publishedBuffer.get().get());
        }
    }

    /** Supplies the non-{@code Unsafe} {@code fromRemaining*} factories. */
    // Java does not permit generic array creation, so a raw array is created and returned.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Function<ByteBuffer, AsyncRequestBody>[] safeNonRewindingByteBufferBodyBuilders() {
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffer = AsyncRequestBody::fromRemainingByteBuffer;
        Function<ByteBuffer, AsyncRequestBody> fromRemainingByteBuffers = AsyncRequestBody::fromRemainingByteBuffers;
        return new Function[] {fromRemainingByteBuffer, fromRemainingByteBuffers};
    }

    /**
     * Verifies that building and subscribing to the body leaves the input buffer's position
     * unchanged.
     *
     * @param bodyBuilder factory under test, supplied by {@link #rewindingByteBufferBodyBuilders()}
     */
    @ParameterizedTest
    @MethodSource("rewindingByteBufferBodyBuilders")
    void rewindingByteBufferBuildersDoNotRewindTheInputBuffer(Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
        byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
        ByteBuffer bb = ByteBuffer.wrap(bytes);
        int expectedPosition = bytes.length / 2;
        bb.position(expectedPosition);

        AsyncRequestBody asyncRequestBody = bodyBuilder.apply(bb);

        // The published content is irrelevant here; only the input buffer's position is checked.
        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(buffer -> {
        });

        asyncRequestBody.subscribe(subscriber);

        assertEquals(expectedPosition, bb.position());
    }

    /**
     * Verifies that, for an input buffer positioned at its midpoint, the published buffer starts
     * at position 0 and equals the fully rewound input buffer.
     *
     * @param bodyBuilder factory under test, supplied by {@link #rewindingByteBufferBodyBuilders()}
     */
    @ParameterizedTest
    @MethodSource("rewindingByteBufferBodyBuilders")
    void rewindingByteBufferBuildersReadTheInputBufferFromTheBeginning(
        Function<ByteBuffer, AsyncRequestBody> bodyBuilder) {
        byte[] bytes = testString.getBytes(StandardCharsets.UTF_8);
        ByteBuffer bb = ByteBuffer.wrap(bytes);
        bb.position(bytes.length / 2);

        AsyncRequestBody asyncRequestBody = bodyBuilder.apply(bb);

        AtomicReference<ByteBuffer> publishedBuffer = new AtomicReference<>();
        Subscriber<ByteBuffer> subscriber = new SimpleSubscriber(publishedBuffer::set);

        asyncRequestBody.subscribe(subscriber);

        assertEquals(0, publishedBuffer.get().position());
        // ByteBuffer equality compares remaining content, so rewind both before comparing.
        publishedBuffer.get().rewind();
        bb.rewind();
        assertEquals(bb, publishedBuffer.get());
    }

    /** Supplies the {@code fromByteBuffer*} factories (those without "Remaining"), safe and unsafe. */
    // Java does not permit generic array creation, so a raw array is created and returned.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Function<ByteBuffer, AsyncRequestBody>[] rewindingByteBufferBodyBuilders() {
        Function<ByteBuffer, AsyncRequestBody> fromByteBuffer = AsyncRequestBody::fromByteBuffer;
        Function<ByteBuffer, AsyncRequestBody> fromByteBufferUnsafe = AsyncRequestBody::fromByteBufferUnsafe;
        Function<ByteBuffer, AsyncRequestBody> fromByteBuffers = AsyncRequestBody::fromByteBuffers;
        Function<ByteBuffer, AsyncRequestBody> fromByteBuffersUnsafe = AsyncRequestBody::fromByteBuffersUnsafe;
        return new Function[] {fromByteBuffer, fromByteBufferUnsafe, fromByteBuffers, fromByteBuffersUnsafe};
    }

    /**
     * Verifies that a string body built with an explicit charset reports
     * {@code text/plain; charset=<name>} as its content type.
     *
     * @param charset charset under test, converted by JUnit from the {@code @ValueSource} name
     */
    @ParameterizedTest
    @ValueSource(strings = {"US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16"})
    void charsetsAreConvertedToTheCorrectContentType(Charset charset) {
        AsyncRequestBody requestBody = AsyncRequestBody.fromString("hello world", charset);
        assertEquals("text/plain; charset=" + charset.name(), requestBody.contentType());
    }

    /** Verifies that a string body built without a charset reports a UTF-8 text content type. */
    @Test
    void stringConstructorHasCorrectDefaultContentType() {
        AsyncRequestBody requestBody = AsyncRequestBody.fromString("hello world");
        assertEquals("text/plain; charset=UTF-8", requestBody.contentType());
    }

    /** Verifies that a file-backed body reports the octet-stream content type. */
    @Test
    void fileConstructorHasCorrectContentType() {
        AsyncRequestBody requestBody = AsyncRequestBody.fromFile(path);
        assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
    }

    /** Verifies that a byte-array body reports the octet-stream content type. */
    @Test
    void bytesArrayConstructorHasCorrectContentType() {
        AsyncRequestBody requestBody = AsyncRequestBody.fromBytes("hello world".getBytes(StandardCharsets.UTF_8));
        assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
    }

    /** Verifies that a {@link ByteBuffer} body reports the octet-stream content type. */
    @Test
    void bytesBufferConstructorHasCorrectContentType() {
        ByteBuffer byteBuffer = ByteBuffer.wrap("hello world".getBytes(StandardCharsets.UTF_8));
        AsyncRequestBody requestBody = AsyncRequestBody.fromByteBuffer(byteBuffer);
        assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
    }

    /** Verifies that an empty body reports the octet-stream content type. */
    @Test
    void emptyBytesConstructorHasCorrectContentType() {
        AsyncRequestBody requestBody = AsyncRequestBody.empty();
        assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
    }

    /** Verifies that a publisher-backed body reports the octet-stream content type. */
    @Test
    void publisherConstructorHasCorrectContentType() {
        List<String> requestBodyStrings = Lists.newArrayList("A", "B", "C");
        List<ByteBuffer> bodyBytes = requestBodyStrings.stream()
                                                       .map(s -> ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8)))
                                                       .collect(Collectors.toList());
        Publisher<ByteBuffer> bodyPublisher = Flowable.fromIterable(bodyBytes);
        AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(bodyPublisher);
        assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType());
    }
}