xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/reader_base.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/reader_base.h"
17 
18 #include "tensorflow/core/framework/reader_base.pb.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/lib/core/coding.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/notification.h"
23 #include "tensorflow/core/lib/core/stringpiece.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 
27 namespace tensorflow {
28 
29 // ReaderBase ------------------------------------------------------
30 
ReaderBase(const string & name)31 ReaderBase::ReaderBase(const string& name) : name_(name) {}
32 
NumRecordsProduced()33 int64_t ReaderBase::NumRecordsProduced() {
34   mutex_lock lock(mu_);
35   return num_records_produced_;
36 }
37 
NumWorkUnitsCompleted()38 int64_t ReaderBase::NumWorkUnitsCompleted() {
39   mutex_lock lock(mu_);
40   return work_finished_;
41 }
42 
Reset()43 Status ReaderBase::Reset() {
44   mutex_lock lock(mu_);
45   return ResetLocked();
46 }
47 
ResetLocked()48 Status ReaderBase::ResetLocked() {
49   work_started_ = 0;
50   work_finished_ = 0;
51   num_records_produced_ = 0;
52   work_.clear();
53   return OkStatus();
54 }
55 
SerializeState(tstring * state)56 Status ReaderBase::SerializeState(tstring* state) {
57   mutex_lock lock(mu_);
58   return SerializeStateLocked(state);
59 }
60 
SerializeStateLocked(tstring * state)61 Status ReaderBase::SerializeStateLocked(tstring* state) {
62   return errors::Unimplemented("Reader SerializeState");
63 }
64 
RestoreState(const tstring & state)65 Status ReaderBase::RestoreState(const tstring& state) {
66   mutex_lock lock(mu_);
67   Status status = RestoreStateLocked(state);
68   if (!status.ok()) {
69     ResetLocked().IgnoreError();
70   }
71   return status;
72 }
73 
RestoreStateLocked(const tstring & state)74 Status ReaderBase::RestoreStateLocked(const tstring& state) {
75   return errors::Unimplemented("Reader RestoreState");
76 }
77 
ReadUpTo(const int64_t num_records,QueueInterface * queue,std::vector<tstring> * keys,std::vector<tstring> * values,OpKernelContext * context)78 int64_t ReaderBase::ReadUpTo(const int64_t num_records, QueueInterface* queue,
79                              std::vector<tstring>* keys,
80                              std::vector<tstring>* values,
81                              OpKernelContext* context) {
82   mutex_lock lock(mu_);
83   int64_t records_produced_this_call = 0;
84   while (true) {
85     // Records produced by this iteration of the ReadUpToLocked call.
86     int64_t num_records_produced = 0;
87     int64_t remaining = num_records - records_produced_this_call;
88     if (remaining == 0) {
89       return records_produced_this_call;
90     }
91     if (!work_in_progress()) {
92       work_ = GetNextWorkLocked(queue, context);
93       if (!context->status().ok()) {
94         return records_produced_this_call;
95       }
96       Status status = OnWorkStartedLocked();
97       if (status.ok()) {
98         work_started_++;
99       } else {
100         context->SetStatus(status);
101         return records_produced_this_call;
102       }
103     }
104     bool at_end = false;
105 
106     Status status =
107         ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end);
108     // This call so far.
109     records_produced_this_call += num_records_produced;
110 
111     // In total, over the lifetime of the ReaderBase.
112     num_records_produced_ += num_records_produced;
113 
114     if (!at_end && status.ok() && num_records_produced == 0) {
115       status = errors::Internal(
116           "ReadManyLocked() for ", name(),
117           " must set *at_end=true, *num_produced > 0 or return an error.");
118       context->SetStatus(status);
119       return records_produced_this_call;
120     }
121     if (status.ok() && at_end) {
122       status = OnWorkFinishedLocked();
123       work_finished_ = work_started_;
124       if (records_produced_this_call > 0) {
125         return records_produced_this_call;
126       }
127     }
128     if (!status.ok()) {
129       context->SetStatus(status);
130       return records_produced_this_call;
131     }
132   }
133 }
134 
135 // Default implementation just reads one record at a time.
ReadUpToLocked(int64_t num_records,std::vector<tstring> * keys,std::vector<tstring> * values,int64_t * num_read,bool * at_end)136 Status ReaderBase::ReadUpToLocked(int64_t num_records,
137                                   std::vector<tstring>* keys,
138                                   std::vector<tstring>* values,
139                                   int64_t* num_read, bool* at_end) {
140   bool produced = false;
141   tstring key;
142   tstring value;
143   Status status = ReadLocked(&key, &value, &produced, at_end);
144   if (produced) {
145     keys->push_back(std::move(key));
146     values->push_back(std::move(value));
147     *num_read = 1;
148   } else {
149     *num_read = 0;
150   }
151   return status;
152 }
153 
Read(QueueInterface * queue,tstring * key,tstring * value,OpKernelContext * context)154 void ReaderBase::Read(QueueInterface* queue, tstring* key, tstring* value,
155                       OpKernelContext* context) {
156   mutex_lock lock(mu_);
157   while (true) {
158     if (!work_in_progress()) {
159       work_ = GetNextWorkLocked(queue, context);
160       if (!context->status().ok()) {
161         return;
162       }
163       Status status = OnWorkStartedLocked();
164       if (status.ok()) {
165         work_started_++;
166       } else {
167         context->SetStatus(status);
168         return;
169       }
170     }
171 
172     bool produced = false;
173     bool at_end = false;
174     Status status = ReadLocked(key, value, &produced, &at_end);
175 
176     if (!at_end && status.ok() && !produced) {
177       status = errors::Internal(
178           "ReadLocked() for ", name(),
179           " must set *at_end=true, *produced=true, or return an error.");
180     }
181     if (!status.ok() && produced) {
182       status = errors::Internal("ReadLocked() for ", name(),
183                                 " set *produced=true *and* returned an error: ",
184                                 status.error_message());
185     }
186     if (status.ok() && at_end) {
187       status = OnWorkFinishedLocked();
188       work_finished_ = work_started_;
189     }
190     if (!status.ok()) {
191       context->SetStatus(status);
192       return;
193     }
194     if (produced) {
195       ++num_records_produced_;
196       return;
197     }
198   }
199 }
200 
GetNextWorkLocked(QueueInterface * queue,OpKernelContext * context) const201 string ReaderBase::GetNextWorkLocked(QueueInterface* queue,
202                                      OpKernelContext* context) const {
203   string work;
204   Notification n;
205   queue->TryDequeue(
206       context, [context, &n, &work](const QueueInterface::Tuple& tuple) {
207         if (context->status().ok()) {
208           if (tuple.size() != 1) {
209             context->SetStatus(
210                 errors::InvalidArgument("Expected single component queue"));
211           } else if (tuple[0].dtype() != DT_STRING) {
212             context->SetStatus(errors::InvalidArgument(
213                 "Expected queue with single string component"));
214           } else if (tuple[0].NumElements() != 1) {
215             context->SetStatus(errors::InvalidArgument(
216                 "Expected to dequeue a one-element string tensor"));
217           } else {
218             work = tuple[0].flat<tstring>()(0);
219           }
220         }
221         n.Notify();
222       });
223   n.WaitForNotification();
224   return work;
225 }
226 
SaveBaseState(ReaderBaseState * state) const227 void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
228   state->Clear();
229   state->set_work_started(work_started_);
230   state->set_work_finished(work_finished_);
231   state->set_num_records_produced(num_records_produced_);
232   state->set_current_work(work_.data(), work_.size());
233 }
234 
// Builds a record key of the form "<current work>:<key>".
tstring ReaderBase::KeyName(const tstring& key) const {
  tstring qualified_key = strings::StrCat(current_work(), ":", key);
  return qualified_key;
}
238 
RestoreBaseState(const ReaderBaseState & state)239 Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) {
240   work_started_ = state.work_started();
241   work_finished_ = state.work_finished();
242   num_records_produced_ = state.num_records_produced();
243   work_ = state.current_work();
244   if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) {
245 #if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
246     const string debug_string = "<debug state not available>";
247 #else
248     const string debug_string = state.DebugString();
249 #endif
250     return errors::InvalidArgument(
251         "Unexpected negative value when restoring in ", name(), ": ",
252         debug_string);
253   }
254   if (work_started_ > work_finished_) {
255 #if defined(__ANDROID__) || (__EMSCRIPTEN__)
256     const string debug_string = "<debug state not available>";
257 #else
258     const string debug_string = state.DebugString();
259 #endif
260     return errors::InvalidArgument(
261         "Inconsistent work started vs. finished when restoring in ", name(),
262         ": ", debug_string);
263   }
264   return OkStatus();
265 }
266 
267 }  // namespace tensorflow
268