xref: /aosp_15_r20/external/tensorflow/tensorflow/core/lib/io/buffered_inputstream.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/lib/io/buffered_inputstream.h"
17 
18 #include "tensorflow/core/lib/io/random_inputstream.h"
19 
20 namespace tensorflow {
21 namespace io {
22 
BufferedInputStream(InputStreamInterface * input_stream,size_t buffer_bytes,bool owns_input_stream)23 BufferedInputStream::BufferedInputStream(InputStreamInterface* input_stream,
24                                          size_t buffer_bytes,
25                                          bool owns_input_stream)
26     : input_stream_(input_stream),
27       size_(buffer_bytes),
28       owns_input_stream_(owns_input_stream) {
29   buf_.reserve(size_);
30 }
31 
BufferedInputStream(RandomAccessFile * file,size_t buffer_bytes)32 BufferedInputStream::BufferedInputStream(RandomAccessFile* file,
33                                          size_t buffer_bytes)
34     : BufferedInputStream(new RandomAccessInputStream(file), buffer_bytes,
35                           true) {}
36 
~BufferedInputStream()37 BufferedInputStream::~BufferedInputStream() {
38   if (owns_input_stream_) {
39     delete input_stream_;
40   }
41 }
42 
FillBuffer()43 Status BufferedInputStream::FillBuffer() {
44   if (!file_status_.ok()) {
45     pos_ = 0;
46     limit_ = 0;
47     return file_status_;
48   }
49   Status s = input_stream_->ReadNBytes(size_, &buf_);
50   pos_ = 0;
51   limit_ = buf_.size();
52   if (!s.ok()) {
53     file_status_ = s;
54   }
55   return s;
56 }
57 
58 template <typename StringType>
ReadLineHelper(StringType * result,bool include_eol)59 Status BufferedInputStream::ReadLineHelper(StringType* result,
60                                            bool include_eol) {
61   result->clear();
62   Status s;
63   size_t start_pos = pos_;
64   while (true) {
65     if (pos_ == limit_) {
66       result->append(buf_.data() + start_pos, pos_ - start_pos);
67       // Get more data into buffer
68       s = FillBuffer();
69       if (limit_ == 0) {
70         break;
71       }
72       start_pos = pos_;
73     }
74     char c = buf_[pos_];
75     if (c == '\n') {
76       result->append(buf_.data() + start_pos, pos_ - start_pos);
77       if (include_eol) {
78         result->append(1, c);
79       }
80       pos_++;
81       return OkStatus();
82     }
83     // We don't append '\r' to *result
84     if (c == '\r') {
85       result->append(buf_.data() + start_pos, pos_ - start_pos);
86       start_pos = pos_ + 1;
87     }
88     pos_++;
89   }
90   if (errors::IsOutOfRange(s) && !result->empty()) {
91     return OkStatus();
92   }
93   return s;
94 }
95 
ReadNBytes(int64_t bytes_to_read,tstring * result)96 Status BufferedInputStream::ReadNBytes(int64_t bytes_to_read, tstring* result) {
97   if (bytes_to_read < 0) {
98     return errors::InvalidArgument("Can't read a negative number of bytes: ",
99                                    bytes_to_read);
100   }
101   result->clear();
102   if (pos_ == limit_ && !file_status_.ok() && bytes_to_read > 0) {
103     return file_status_;
104   }
105   result->reserve(bytes_to_read);
106 
107   Status s;
108   while (result->size() < static_cast<size_t>(bytes_to_read)) {
109     // Check whether the buffer is fully read or not.
110     if (pos_ == limit_) {
111       s = FillBuffer();
112       // If we didn't read any bytes, we're at the end of the file; break out.
113       if (limit_ == 0) {
114         DCHECK(!s.ok());
115         file_status_ = s;
116         break;
117       }
118     }
119     const int64_t bytes_to_copy =
120         std::min<int64_t>(limit_ - pos_, bytes_to_read - result->size());
121     result->insert(result->size(), buf_, pos_, bytes_to_copy);
122     pos_ += bytes_to_copy;
123   }
124   // Filling the buffer might lead to a situation when we go past the end of
125   // the file leading to an OutOfRange() status return. But we might have
126   // obtained enough data to satisfy the function call. Returning OK then.
127   if (errors::IsOutOfRange(s) &&
128       (result->size() == static_cast<size_t>(bytes_to_read))) {
129     return OkStatus();
130   }
131   return s;
132 }
133 
SkipNBytes(int64_t bytes_to_skip)134 Status BufferedInputStream::SkipNBytes(int64_t bytes_to_skip) {
135   if (bytes_to_skip < 0) {
136     return errors::InvalidArgument("Can only skip forward, not ",
137                                    bytes_to_skip);
138   }
139   if (pos_ + bytes_to_skip < limit_) {
140     // If we aren't skipping too much, then we can just move pos_;
141     pos_ += bytes_to_skip;
142   } else {
143     // Otherwise, we already have read limit_ - pos_, so skip the rest. At this
144     // point we need to get fresh data into the buffer, so reset pos_ and
145     // limit_.
146     Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_));
147     pos_ = 0;
148     limit_ = 0;
149     if (errors::IsOutOfRange(s)) {
150       file_status_ = s;
151     }
152     return s;
153   }
154   return OkStatus();
155 }
156 
Tell() const157 int64_t BufferedInputStream::Tell() const {
158   return input_stream_->Tell() - (limit_ - pos_);
159 }
160 
Seek(int64_t position)161 Status BufferedInputStream::Seek(int64_t position) {
162   if (position < 0) {
163     return errors::InvalidArgument("Seeking to a negative position: ",
164                                    position);
165   }
166 
167   // Position of the buffer's lower limit within file.
168   const int64_t buf_lower_limit = input_stream_->Tell() - limit_;
169   if (position < buf_lower_limit) {
170     // Seek before buffer, reset input stream and skip 'position' bytes.
171     TF_RETURN_IF_ERROR(Reset());
172     return SkipNBytes(position);
173   }
174 
175   if (position < Tell()) {
176     // Seek within buffer before 'pos_'
177     pos_ -= Tell() - position;
178     return OkStatus();
179   }
180 
181   // Seek after 'pos_'
182   return SkipNBytes(position - Tell());
183 }
184 
185 template <typename T>
ReadAll(T * result)186 Status BufferedInputStream::ReadAll(T* result) {
187   result->clear();
188   Status status;
189   while (status.ok()) {
190     status = FillBuffer();
191     if (limit_ == 0) {
192       break;
193     }
194     result->append(buf_);
195     pos_ = limit_;
196   }
197 
198   if (errors::IsOutOfRange(status)) {
199     file_status_ = status;
200     return OkStatus();
201   }
202   return status;
203 }
204 
205 template Status BufferedInputStream::ReadAll<std::string>(std::string* result);
206 template Status BufferedInputStream::ReadAll<tstring>(tstring* result);
207 
Reset()208 Status BufferedInputStream::Reset() {
209   TF_RETURN_IF_ERROR(input_stream_->Reset());
210   pos_ = 0;
211   limit_ = 0;
212   file_status_ = OkStatus();
213   return OkStatus();
214 }
215 
ReadLine(std::string * result)216 Status BufferedInputStream::ReadLine(std::string* result) {
217   return ReadLineHelper(result, false);
218 }
219 
ReadLine(tstring * result)220 Status BufferedInputStream::ReadLine(tstring* result) {
221   return ReadLineHelper(result, false);
222 }
223 
ReadLineAsString()224 std::string BufferedInputStream::ReadLineAsString() {
225   std::string result;
226   ReadLineHelper(&result, true).IgnoreError();
227   return result;
228 }
229 
SkipLine()230 Status BufferedInputStream::SkipLine() {
231   Status s;
232   bool skipped = false;
233   while (true) {
234     if (pos_ == limit_) {
235       // Get more data into buffer
236       s = FillBuffer();
237       if (limit_ == 0) {
238         break;
239       }
240     }
241     char c = buf_[pos_++];
242     skipped = true;
243     if (c == '\n') {
244       return OkStatus();
245     }
246   }
247   if (errors::IsOutOfRange(s) && skipped) {
248     return OkStatus();
249   }
250   return s;
251 }
252 
253 }  // namespace io
254 }  // namespace tensorflow
255