1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.awssdk.http.nio.netty.internal;
17 
18 import static java.util.stream.Collectors.groupingBy;
19 import static java.util.stream.Collectors.mapping;
20 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.CHANNEL_DIAGNOSTICS;
21 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.EXECUTE_FUTURE_KEY;
22 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.KEEP_ALIVE;
23 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.REQUEST_CONTEXT_KEY;
24 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_COMPLETE_KEY;
25 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_CONTENT_LENGTH;
26 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_DATA_READ;
27 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_STATUS_CODE;
28 import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.STREAMING_COMPLETE_KEY;
29 import static software.amazon.awssdk.http.nio.netty.internal.utils.ExceptionHandlingUtils.tryCatch;
30 import static software.amazon.awssdk.http.nio.netty.internal.utils.ExceptionHandlingUtils.tryCatchFinally;
31 
32 import io.netty.buffer.ByteBuf;
33 import io.netty.channel.Channel;
34 import io.netty.channel.ChannelHandler.Sharable;
35 import io.netty.channel.ChannelHandlerContext;
36 import io.netty.channel.SimpleChannelInboundHandler;
37 import io.netty.handler.codec.http.FullHttpResponse;
38 import io.netty.handler.codec.http.HttpContent;
39 import io.netty.handler.codec.http.HttpHeaderNames;
40 import io.netty.handler.codec.http.HttpHeaders;
41 import io.netty.handler.codec.http.HttpObject;
42 import io.netty.handler.codec.http.HttpResponse;
43 import io.netty.handler.codec.http.HttpResponseStatus;
44 import io.netty.handler.codec.http.HttpUtil;
45 import io.netty.util.ReferenceCountUtil;
46 import java.io.IOException;
47 import java.nio.ByteBuffer;
48 import java.util.List;
49 import java.util.Map;
50 import java.util.Optional;
51 import java.util.concurrent.CompletableFuture;
52 import java.util.concurrent.atomic.AtomicBoolean;
53 import java.util.function.Supplier;
54 import java.util.stream.Collectors;
55 import org.reactivestreams.Publisher;
56 import org.reactivestreams.Subscriber;
57 import org.reactivestreams.Subscription;
58 import software.amazon.awssdk.annotations.SdkInternalApi;
59 import software.amazon.awssdk.http.HttpStatusFamily;
60 import software.amazon.awssdk.http.Protocol;
61 import software.amazon.awssdk.http.SdkCancellationException;
62 import software.amazon.awssdk.http.SdkHttpFullResponse;
63 import software.amazon.awssdk.http.SdkHttpMethod;
64 import software.amazon.awssdk.http.SdkHttpResponse;
65 import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
66 import software.amazon.awssdk.http.nio.netty.internal.http2.Http2ResetSendingSubscription;
67 import software.amazon.awssdk.http.nio.netty.internal.nrs.HttpStreamsClientHandler;
68 import software.amazon.awssdk.http.nio.netty.internal.nrs.StreamedHttpResponse;
69 import software.amazon.awssdk.http.nio.netty.internal.utils.ChannelUtils;
70 import software.amazon.awssdk.http.nio.netty.internal.utils.NettyClientLogger;
71 import software.amazon.awssdk.http.nio.netty.internal.utils.NettyUtils;
72 import software.amazon.awssdk.utils.FunctionalUtils.UnsafeRunnable;
73 import software.amazon.awssdk.utils.async.DelegatingSubscription;
74 
75 @Sharable
76 @SdkInternalApi
77 public class ResponseHandler extends SimpleChannelInboundHandler<HttpObject> {
78 
79     private static final NettyClientLogger log = NettyClientLogger.getLogger(ResponseHandler.class);
80 
81     private static final ResponseHandler INSTANCE = new ResponseHandler();
82 
ResponseHandler()83     private ResponseHandler() {
84     }
85 
86     @Override
channelRead0(ChannelHandlerContext channelContext, HttpObject msg)87     protected void channelRead0(ChannelHandlerContext channelContext, HttpObject msg) throws Exception {
88         RequestContext requestContext = channelContext.channel().attr(REQUEST_CONTEXT_KEY).get();
89 
90         if (msg instanceof HttpResponse) {
91             HttpResponse response = (HttpResponse) msg;
92             SdkHttpResponse sdkResponse = SdkHttpFullResponse.builder()
93                                                              .headers(fromNettyHeaders(response.headers()))
94                                                              .statusCode(response.status().code())
95                                                              .statusText(response.status().reasonPhrase())
96                                                              .build();
97             channelContext.channel().attr(RESPONSE_STATUS_CODE).set(response.status().code());
98             channelContext.channel().attr(RESPONSE_CONTENT_LENGTH).set(responseContentLength(response));
99             channelContext.channel().attr(KEEP_ALIVE).set(shouldKeepAlive(response));
100             ChannelUtils.getAttribute(channelContext.channel(), CHANNEL_DIAGNOSTICS)
101                         .ifPresent(ChannelDiagnostics::incrementResponseCount);
102             requestContext.handler().onHeaders(sdkResponse);
103         }
104 
105         CompletableFuture<Void> ef = executeFuture(channelContext);
106         if (msg instanceof StreamedHttpResponse) {
107             requestContext.handler().onStream(
108                     new DataCountingPublisher(channelContext,
109                                               new PublisherAdapter((StreamedHttpResponse) msg, channelContext,
110                                                                    requestContext, ef)));
111         } else if (msg instanceof FullHttpResponse) {
112             ByteBuf fullContent = null;
113             try {
114                 // Be prepared to take care of (ignore) a trailing LastHttpResponse
115                 // from the HttpClientCodec if there is one.
116                 channelContext.pipeline().replace(HttpStreamsClientHandler.class,
117                                                   channelContext.name() + "-LastHttpContentSwallower",
118                                                   LastHttpContentSwallower.getInstance());
119 
120                 fullContent = ((FullHttpResponse) msg).content();
121                 ByteBuffer bb = copyToByteBuffer(fullContent);
122                 requestContext.handler().onStream(new DataCountingPublisher(channelContext,
123                                                                             new FullResponseContentPublisher(channelContext,
124                                                                                                              bb, ef)));
125 
126                 try {
127                     validateResponseContentLength(channelContext);
128                     finalizeResponse(requestContext, channelContext);
129                 } catch (IOException e) {
130                     exceptionCaught(channelContext, e);
131                 }
132             } finally {
133                 Optional.ofNullable(fullContent).ifPresent(ByteBuf::release);
134             }
135         }
136     }
137 
responseContentLength(HttpResponse response)138     private Long responseContentLength(HttpResponse response) {
139         String length = response.headers().get(HttpHeaderNames.CONTENT_LENGTH);
140         if (length == null) {
141             return null;
142         }
143 
144         return Long.parseLong(length);
145     }
146 
validateResponseContentLength(ChannelHandlerContext ctx)147     private static void validateResponseContentLength(ChannelHandlerContext ctx) throws IOException {
148         if (!shouldValidateResponseContentLength(ctx)) {
149             return;
150         }
151 
152         Long contentLengthHeader = ctx.channel().attr(RESPONSE_CONTENT_LENGTH).get();
153         Long actualContentLength = ctx.channel().attr(RESPONSE_DATA_READ).get();
154 
155         if (contentLengthHeader == null) {
156             return;
157         }
158 
159         if (actualContentLength == null) {
160             actualContentLength = 0L;
161         }
162 
163         if (actualContentLength.equals(contentLengthHeader)) {
164             return;
165         }
166 
167         throw new IOException("Response had content-length of " + contentLengthHeader + " bytes, but only received "
168                               + actualContentLength + " bytes before the connection was closed.");
169     }
170 
shouldValidateResponseContentLength(ChannelHandlerContext ctx)171     private static boolean shouldValidateResponseContentLength(ChannelHandlerContext ctx) {
172         RequestContext requestContext = ctx.channel().attr(REQUEST_CONTEXT_KEY).get();
173 
174         // HEAD requests may return Content-Length without a payload
175         if (requestContext.executeRequest().request().method() == SdkHttpMethod.HEAD) {
176             return false;
177         }
178 
179         // 304 responses may contain Content-Length without a payload
180         Integer responseStatusCode = ctx.channel().attr(RESPONSE_STATUS_CODE).get();
181         if (responseStatusCode != null && responseStatusCode == HttpResponseStatus.NOT_MODIFIED.code()) {
182             return false;
183         }
184 
185         return true;
186     }
187 
188 
189     /**
190      * Finalize the response by completing the execute future and release the channel pool being used.
191      *
192      * @param requestContext the request context
193      * @param channelContext the channel context
194      */
finalizeResponse(RequestContext requestContext, ChannelHandlerContext channelContext)195     private static void finalizeResponse(RequestContext requestContext, ChannelHandlerContext channelContext) {
196         channelContext.channel().attr(RESPONSE_COMPLETE_KEY).set(true);
197 
198         executeFuture(channelContext).complete(null);
199         if (!channelContext.channel().attr(KEEP_ALIVE).get()) {
200             closeAndRelease(channelContext);
201         } else {
202             requestContext.channelPool().release(channelContext.channel());
203         }
204     }
205 
shouldKeepAlive(HttpResponse response)206     private boolean shouldKeepAlive(HttpResponse response) {
207         if (HttpStatusFamily.of(response.status().code()) == HttpStatusFamily.SERVER_ERROR) {
208             return false;
209         }
210         return HttpUtil.isKeepAlive(response);
211     }
212 
213     @Override
exceptionCaught(ChannelHandlerContext ctx, Throwable cause)214     public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
215         RequestContext requestContext = ctx.channel().attr(REQUEST_CONTEXT_KEY).get();
216         log.debug(ctx.channel(), () -> "Exception processing request: " + requestContext.executeRequest().request(), cause);
217         Throwable throwable = NettyUtils.decorateException(ctx.channel(), cause);
218         executeFuture(ctx).completeExceptionally(throwable);
219         runAndLogError(ctx.channel(), () -> "Fail to execute SdkAsyncHttpResponseHandler#onError",
220                        () -> requestContext.handler().onError(throwable));
221         runAndLogError(ctx.channel(), () -> "Could not release channel back to the pool", () -> closeAndRelease(ctx));
222     }
223 
224     @Override
channelInactive(ChannelHandlerContext handlerCtx)225     public void channelInactive(ChannelHandlerContext handlerCtx) {
226         notifyIfResponseNotCompleted(handlerCtx);
227     }
228 
getInstance()229     public static ResponseHandler getInstance() {
230         return INSTANCE;
231     }
232 
233     /**
234      * Close the channel and release it back into the pool.
235      *
236      * @param ctx Context for channel
237      */
closeAndRelease(ChannelHandlerContext ctx)238     private static void closeAndRelease(ChannelHandlerContext ctx) {
239         Channel channel = ctx.channel();
240         channel.attr(KEEP_ALIVE).set(false);
241         RequestContext requestContext = channel.attr(REQUEST_CONTEXT_KEY).get();
242         ctx.close();
243         requestContext.channelPool().release(channel);
244     }
245 
246     /**
247      * Runs a given {@link UnsafeRunnable} and logs an error without throwing.
248      *
249      * @param errorMsg Message to log with exception thrown.
250      * @param runnable Action to perform.
251      */
runAndLogError(Channel ch, Supplier<String> errorMsg, UnsafeRunnable runnable)252     private static void runAndLogError(Channel ch, Supplier<String> errorMsg, UnsafeRunnable runnable) {
253         try {
254             runnable.run();
255         } catch (Exception e) {
256             log.error(ch, errorMsg, e);
257         }
258     }
259 
fromNettyHeaders(HttpHeaders headers)260     private static Map<String, List<String>> fromNettyHeaders(HttpHeaders headers) {
261         return headers.entries().stream()
262                 .collect(groupingBy(Map.Entry::getKey,
263                         mapping(Map.Entry::getValue, Collectors.toList())));
264     }
265 
copyToByteBuffer(ByteBuf byteBuf)266     private static ByteBuffer copyToByteBuffer(ByteBuf byteBuf) {
267         ByteBuffer bb = ByteBuffer.allocate(byteBuf.readableBytes());
268         byteBuf.getBytes(byteBuf.readerIndex(), bb);
269         bb.flip();
270         return bb;
271     }
272 
executeFuture(ChannelHandlerContext ctx)273     private static CompletableFuture<Void> executeFuture(ChannelHandlerContext ctx) {
274         return ctx.channel().attr(EXECUTE_FUTURE_KEY).get();
275     }
276 
277     static class PublisherAdapter implements Publisher<ByteBuffer> {
278         private final StreamedHttpResponse response;
279         private final ChannelHandlerContext channelContext;
280         private final RequestContext requestContext;
281         private final CompletableFuture<Void> executeFuture;
282         private final AtomicBoolean isDone = new AtomicBoolean(false);
283 
PublisherAdapter(StreamedHttpResponse response, ChannelHandlerContext channelContext, RequestContext requestContext, CompletableFuture<Void> executeFuture)284         PublisherAdapter(StreamedHttpResponse response, ChannelHandlerContext channelContext,
285                          RequestContext requestContext, CompletableFuture<Void> executeFuture) {
286             this.response = response;
287             this.channelContext = channelContext;
288             this.requestContext = requestContext;
289             this.executeFuture = executeFuture;
290         }
291 
292         @Override
subscribe(Subscriber<? super ByteBuffer> subscriber)293         public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
294             response.subscribe(new Subscriber<HttpContent>() {
295                 @Override
296                 public void onSubscribe(Subscription subscription) {
297                     subscriber.onSubscribe(new OnCancelSubscription(resolveSubscription(subscription),
298                                                                     this::onCancel));
299                 }
300 
301                 private Subscription resolveSubscription(Subscription subscription) {
302                     // For HTTP2 we send a RST_STREAM frame on cancel to stop the service from sending more data
303                     if (ChannelAttributeKey.getProtocolNow(channelContext.channel()) == Protocol.HTTP2) {
304                         return new Http2ResetSendingSubscription(channelContext, subscription);
305                     } else {
306                         return subscription;
307                     }
308                 }
309 
310                 private void onCancel() {
311                     if (!isDone.compareAndSet(false, true)) {
312                         return;
313                     }
314                     try {
315                         SdkCancellationException e = new SdkCancellationException(
316                                 "Subscriber cancelled before all events were published");
317                         log.warn(channelContext.channel(), () -> "Subscriber cancelled before all events were published");
318                         executeFuture.completeExceptionally(e);
319                     } finally {
320                         runAndLogError(channelContext.channel(), () -> "Could not release channel back to the pool",
321                             () -> closeAndRelease(channelContext));
322                     }
323                 }
324 
325                 @Override
326                 public void onNext(HttpContent httpContent) {
327                     // isDone may be true if the subscriber cancelled
328                     if (isDone.get()) {
329                         ReferenceCountUtil.release(httpContent);
330                         return;
331                     }
332 
333                     // Needed to prevent use-after-free bug if the subscriber's onNext is asynchronous
334                     ByteBuffer byteBuffer =
335                         tryCatchFinally(() -> copyToByteBuffer(httpContent.content()),
336                                         this::onError,
337                                         httpContent::release);
338 
339 
340                     //As per reactive-streams rule 2.13, we should not call subscriber#onError when
341                     //exception is thrown from subscriber#onNext
342                     if (byteBuffer != null) {
343                         tryCatch(() -> subscriber.onNext(byteBuffer),
344                                  this::notifyError);
345                     }
346                 }
347 
348                 @Override
349                 public void onError(Throwable t) {
350                     if (!isDone.compareAndSet(false, true)) {
351                         return;
352                     }
353                     try {
354                         runAndLogError(channelContext.channel(),
355                                        () -> String.format("Subscriber %s threw an exception in onError.", subscriber),
356                                        () -> subscriber.onError(t));
357                         notifyError(t);
358                     } finally {
359                         runAndLogError(channelContext.channel(), () -> "Could not release channel back to the pool",
360                             () -> closeAndRelease(channelContext));
361                     }
362                 }
363 
364                 @Override
365                 public void onComplete() {
366                     // For HTTP/2 it's possible to get an onComplete after we cancel due to the channel becoming
367                     // inactive. We guard against that here and just ignore the signal (see HandlerPublisher)
368                     if (!isDone.compareAndSet(false, true)) {
369                         return;
370                     }
371 
372                     try {
373                         validateResponseContentLength(channelContext);
374                         try {
375                             runAndLogError(channelContext.channel(),
376                                            () -> String.format("Subscriber %s threw an exception in onComplete.", subscriber),
377                                            subscriber::onComplete);
378                         } finally {
379                             finalizeResponse(requestContext, channelContext);
380                         }
381                     } catch (IOException e) {
382                         notifyError(e);
383                         runAndLogError(channelContext.channel(), () -> "Could not release channel back to the pool",
384                                        () -> closeAndRelease(channelContext));
385                     }
386                 }
387 
388                 private void notifyError(Throwable throwable) {
389                     SdkAsyncHttpResponseHandler handler = requestContext.handler();
390                     runAndLogError(channelContext.channel(),
391                                    () -> String.format("SdkAsyncHttpResponseHandler %s threw an exception in onError.", handler),
392                                    () -> handler.onError(throwable));
393                     executeFuture.completeExceptionally(throwable);
394                 }
395 
396             });
397         }
398     }
399 
400     /**
401      * Decorator around a {@link Subscription} to notify if a cancellation occurs.
402      */
403     private static class OnCancelSubscription extends DelegatingSubscription {
404 
405         private final Runnable onCancel;
406 
OnCancelSubscription(Subscription subscription, Runnable onCancel)407         private OnCancelSubscription(Subscription subscription, Runnable onCancel) {
408             super(subscription);
409             this.onCancel = onCancel;
410         }
411 
412         @Override
cancel()413         public void cancel() {
414             onCancel.run();
415             super.cancel();
416         }
417     }
418 
419     static class FullResponseContentPublisher implements Publisher<ByteBuffer> {
420         private final ChannelHandlerContext channelContext;
421         private final ByteBuffer fullContent;
422         private final CompletableFuture<Void> executeFuture;
423         private boolean running = true;
424         private Subscriber<? super ByteBuffer> subscriber;
425 
FullResponseContentPublisher(ChannelHandlerContext channelContext, ByteBuffer fullContent, CompletableFuture<Void> executeFuture)426         FullResponseContentPublisher(ChannelHandlerContext channelContext, ByteBuffer fullContent,
427                                      CompletableFuture<Void> executeFuture) {
428             this.channelContext = channelContext;
429             this.fullContent = fullContent;
430             this.executeFuture = executeFuture;
431         }
432 
433         @Override
subscribe(Subscriber<? super ByteBuffer> subscriber)434         public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
435             if (this.subscriber != null) {
436                 subscriber.onComplete();
437                 return;
438             }
439             this.subscriber = subscriber;
440             channelContext.channel().attr(ChannelAttributeKey.SUBSCRIBER_KEY)
441                     .set(subscriber);
442 
443             subscriber.onSubscribe(new Subscription() {
444                 @Override
445                 public void request(long l) {
446                     if (running) {
447                         running = false;
448                         if (l <= 0) {
449                             subscriber.onError(new IllegalArgumentException("Demand must be positive!"));
450                         } else {
451                             if (fullContent.hasRemaining()) {
452                                 subscriber.onNext(fullContent);
453                             }
454                             subscriber.onComplete();
455                             executeFuture.complete(null);
456                         }
457                     }
458                 }
459 
460                 @Override
461                 public void cancel() {
462                     running = false;
463                 }
464             });
465 
466         }
467     }
468 
notifyIfResponseNotCompleted(ChannelHandlerContext handlerCtx)469     private void notifyIfResponseNotCompleted(ChannelHandlerContext handlerCtx) {
470         RequestContext requestCtx = handlerCtx.channel().attr(REQUEST_CONTEXT_KEY).get();
471         Boolean responseCompleted = handlerCtx.channel().attr(RESPONSE_COMPLETE_KEY).get();
472         Boolean isStreamingComplete = handlerCtx.channel().attr(STREAMING_COMPLETE_KEY).get();
473         handlerCtx.channel().attr(KEEP_ALIVE).set(false);
474 
475         if (!Boolean.TRUE.equals(responseCompleted) && !Boolean.TRUE.equals(isStreamingComplete)) {
476             IOException err = new IOException(NettyUtils.closedChannelMessage(handlerCtx.channel()));
477             runAndLogError(handlerCtx.channel(), () -> "Fail to execute SdkAsyncHttpResponseHandler#onError",
478                            () -> requestCtx.handler().onError(err));
479             executeFuture(handlerCtx).completeExceptionally(err);
480             runAndLogError(handlerCtx.channel(), () -> "Could not release channel", () -> closeAndRelease(handlerCtx));
481         }
482     }
483 
484     private static final class DataCountingPublisher implements Publisher<ByteBuffer> {
485         private final ChannelHandlerContext ctx;
486         private final Publisher<ByteBuffer> delegate;
487 
DataCountingPublisher(ChannelHandlerContext ctx, Publisher<ByteBuffer> delegate)488         private DataCountingPublisher(ChannelHandlerContext ctx, Publisher<ByteBuffer> delegate) {
489             this.ctx = ctx;
490             this.delegate = delegate;
491         }
492 
493         @Override
subscribe(Subscriber<? super ByteBuffer> subscriber)494         public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
495             delegate.subscribe(new Subscriber<ByteBuffer>() {
496                 @Override
497                 public void onSubscribe(Subscription subscription) {
498                     subscriber.onSubscribe(subscription);
499                 }
500 
501                 @Override
502                 public void onNext(ByteBuffer byteBuffer) {
503                     Long responseDataSoFar = ctx.channel().attr(RESPONSE_DATA_READ).get();
504                     if (responseDataSoFar == null) {
505                         responseDataSoFar = 0L;
506                     }
507 
508                     ctx.channel().attr(RESPONSE_DATA_READ).set(responseDataSoFar + byteBuffer.remaining());
509                     subscriber.onNext(byteBuffer);
510                 }
511 
512                 @Override
513                 public void onError(Throwable throwable) {
514                     subscriber.onError(throwable);
515                 }
516 
517                 @Override
518                 public void onComplete() {
519                     subscriber.onComplete();
520                 }
521             });
522         }
523     }
524 }