1 // Copyright 2020 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/read_buffering_stream_socket.h"
6
7 #include <algorithm>
8
9 #include "base/check_op.h"
10 #include "base/notreached.h"
11 #include "net/base/io_buffer.h"
12
13 namespace net {
14
ReadBufferingStreamSocket(std::unique_ptr<StreamSocket> transport)15 ReadBufferingStreamSocket::ReadBufferingStreamSocket(
16 std::unique_ptr<StreamSocket> transport)
17 : WrappedStreamSocket(std::move(transport)) {}
18
19 ReadBufferingStreamSocket::~ReadBufferingStreamSocket() = default;
20
BufferNextRead(int size)21 void ReadBufferingStreamSocket::BufferNextRead(int size) {
22 DCHECK(!user_read_buf_);
23 read_buffer_ = base::MakeRefCounted<GrowableIOBuffer>();
24 read_buffer_->SetCapacity(size);
25 buffer_full_ = false;
26 }
27
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)28 int ReadBufferingStreamSocket::Read(IOBuffer* buf,
29 int buf_len,
30 CompletionOnceCallback callback) {
31 DCHECK(!user_read_buf_);
32 if (!read_buffer_)
33 return transport_->Read(buf, buf_len, std::move(callback));
34 int rv = ReadIfReady(buf, buf_len, std::move(callback));
35 if (rv == ERR_IO_PENDING) {
36 user_read_buf_ = buf;
37 user_read_buf_len_ = buf_len;
38 }
39 return rv;
40 }
41
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)42 int ReadBufferingStreamSocket::ReadIfReady(IOBuffer* buf,
43 int buf_len,
44 CompletionOnceCallback callback) {
45 DCHECK(!user_read_buf_);
46 if (!read_buffer_)
47 return transport_->ReadIfReady(buf, buf_len, std::move(callback));
48
49 if (buffer_full_)
50 return CopyToCaller(buf, buf_len);
51
52 state_ = STATE_READ;
53 int rv = DoLoop(OK);
54 if (rv == OK) {
55 rv = CopyToCaller(buf, buf_len);
56 } else if (rv == ERR_IO_PENDING) {
57 user_read_callback_ = std::move(callback);
58 }
59 return rv;
60 }
61
DoLoop(int result)62 int ReadBufferingStreamSocket::DoLoop(int result) {
63 int rv = result;
64 do {
65 State current_state = state_;
66 state_ = STATE_NONE;
67 switch (current_state) {
68 case STATE_READ:
69 rv = DoRead();
70 break;
71 case STATE_READ_COMPLETE:
72 rv = DoReadComplete(rv);
73 break;
74 case STATE_NONE:
75 default:
76 NOTREACHED() << "Unexpected state: " << current_state;
77 rv = ERR_UNEXPECTED;
78 break;
79 }
80 } while (rv != ERR_IO_PENDING && state_ != STATE_NONE);
81 return rv;
82 }
83
DoRead()84 int ReadBufferingStreamSocket::DoRead() {
85 DCHECK(read_buffer_);
86 DCHECK(!buffer_full_);
87
88 state_ = STATE_READ_COMPLETE;
89 return transport_->Read(
90 read_buffer_.get(), read_buffer_->RemainingCapacity(),
91 base::BindOnce(&ReadBufferingStreamSocket::OnReadCompleted,
92 base::Unretained(this)));
93 }
94
DoReadComplete(int result)95 int ReadBufferingStreamSocket::DoReadComplete(int result) {
96 state_ = STATE_NONE;
97
98 if (result <= 0)
99 return result;
100
101 read_buffer_->set_offset(read_buffer_->offset() + result);
102 if (read_buffer_->RemainingCapacity() > 0) {
103 // Keep reading until |read_buffer_| is full.
104 state_ = STATE_READ;
105 } else {
106 read_buffer_->set_offset(0);
107 buffer_full_ = true;
108 }
109 return OK;
110 }
111
OnReadCompleted(int result)112 void ReadBufferingStreamSocket::OnReadCompleted(int result) {
113 DCHECK_NE(ERR_IO_PENDING, result);
114 DCHECK(user_read_callback_);
115
116 result = DoLoop(result);
117 if (result == ERR_IO_PENDING)
118 return;
119 if (result == OK && user_read_buf_) {
120 // If the user called Read(), return the data to the caller.
121 result = CopyToCaller(user_read_buf_.get(), user_read_buf_len_);
122 user_read_buf_ = nullptr;
123 user_read_buf_len_ = 0;
124 }
125 std::move(user_read_callback_).Run(result);
126 }
127
CopyToCaller(IOBuffer * buf,int buf_len)128 int ReadBufferingStreamSocket::CopyToCaller(IOBuffer* buf, int buf_len) {
129 DCHECK(read_buffer_);
130 DCHECK(buffer_full_);
131
132 buf_len = std::min(buf_len, read_buffer_->RemainingCapacity());
133 memcpy(buf->data(), read_buffer_->data(), buf_len);
134 read_buffer_->set_offset(read_buffer_->offset() + buf_len);
135 if (read_buffer_->RemainingCapacity() == 0) {
136 read_buffer_ = nullptr;
137 buffer_full_ = false;
138 }
139 return buf_len;
140 }
141
142 } // namespace net
143