xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/tensor_slice_reader.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_slice_reader.h"
17 
18 #include <utility>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/types.pb.h"
22 #include "tensorflow/core/framework/versions.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/io/iterator.h"
25 #include "tensorflow/core/lib/io/table.h"
26 #include "tensorflow/core/lib/io/table_options.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/protobuf.h"
31 #include "tensorflow/core/public/version.h"
32 #include "tensorflow/core/util/saved_tensor_slice_util.h"
33 #include "tensorflow/core/util/tensor_slice_util.h"
34 
35 namespace tensorflow {
36 
37 namespace checkpoint {
38 
~Table()39 TensorSliceReader::Table::~Table() {}
40 
41 namespace {
42 class TensorSliceReaderTable : public TensorSliceReader::Table {
43  public:
44   // Takes ownership of 'f'.
TensorSliceReaderTable(RandomAccessFile * f,table::Table * t)45   explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t)
46       : file_(f), table_(t) {}
47 
~TensorSliceReaderTable()48   ~TensorSliceReaderTable() override {
49     delete table_;
50     delete file_;
51   }
52 
Get(const string & key,string * value)53   bool Get(const string& key, string* value) override {
54     std::unique_ptr<table::Iterator> iter(table_->NewIterator());
55     iter->Seek(key);
56     if (iter->Valid() && iter->key() == key) {
57       StringPiece v = iter->value();
58       value->assign(v.data(), v.size());
59       return true;
60     } else {
61       return false;
62     }
63   }
64 
65  private:
66   RandomAccessFile* file_;  // Owns.
67   table::Table* table_;
68 };
69 }  // namespace
70 
OpenTableTensorSliceReader(const string & fname,TensorSliceReader::Table ** result)71 Status OpenTableTensorSliceReader(const string& fname,
72                                   TensorSliceReader::Table** result) {
73   *result = nullptr;
74   Env* env = Env::Default();
75   std::unique_ptr<RandomAccessFile> f;
76   Status s = env->NewRandomAccessFile(fname, &f);
77   if (s.ok()) {
78     uint64 file_size;
79     s = env->GetFileSize(fname, &file_size);
80     if (s.ok()) {
81       table::Options options;
82       table::Table* table;
83       s = table::Table::Open(options, f.get(), file_size, &table);
84       if (s.ok()) {
85         *result = new TensorSliceReaderTable(f.release(), table);
86         return OkStatus();
87       } else {
88         s = errors::CreateWithUpdatedMessage(
89             s, strings::StrCat(s.error_message(),
90                                ": perhaps your file is in a different "
91                                "file format and you need to use a "
92                                "different restore operator?"));
93       }
94     }
95   }
96   LOG(WARNING) << "Could not open " << fname << ": " << s;
97   return s;
98 }
99 
TensorSliceReader(const string & filepattern)100 TensorSliceReader::TensorSliceReader(const string& filepattern)
101     : TensorSliceReader(filepattern, OpenTableTensorSliceReader,
102                         kLoadAllShards) {}
103 
TensorSliceReader(const string & filepattern,OpenTableFunction open_function)104 TensorSliceReader::TensorSliceReader(const string& filepattern,
105                                      OpenTableFunction open_function)
106     : TensorSliceReader(filepattern, std::move(open_function), kLoadAllShards) {
107 }
108 
TensorSliceReader(const string & filepattern,OpenTableFunction open_function,int preferred_shard)109 TensorSliceReader::TensorSliceReader(const string& filepattern,
110                                      OpenTableFunction open_function,
111                                      int preferred_shard)
112     : filepattern_(filepattern), open_function_(std::move(open_function)) {
113   VLOG(1) << "TensorSliceReader for " << filepattern;
114   Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_);
115   if (!s.ok()) {
116     status_ = errors::InvalidArgument(
117         "Unsuccessful TensorSliceReader constructor: "
118         "Failed to get matching files on ",
119         filepattern, ": ", s.ToString());
120     return;
121   }
122   if (fnames_.empty()) {
123     status_ = errors::NotFound(
124         "Unsuccessful TensorSliceReader constructor: "
125         "Failed to find any matching files for ",
126         filepattern);
127     return;
128   }
129   sss_.resize(fnames_.size());
130   for (size_t shard = 0; shard < fnames_.size(); ++shard) {
131     fname_to_index_.insert(std::make_pair(fnames_[shard], shard));
132   }
133   if (preferred_shard == kLoadAllShards || fnames_.size() == 1 ||
134       static_cast<size_t>(preferred_shard) >= fnames_.size()) {
135     LoadAllShards();
136   } else {
137     VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_;
138     LoadShard(preferred_shard);
139   }
140 }
141 
LoadShard(int shard) const142 void TensorSliceReader::LoadShard(int shard) const {
143   CHECK_LT(shard, sss_.size());
144   if (sss_[shard] || !status_.ok()) {
145     return;  // Already loaded, or invalid.
146   }
147   string value;
148   SavedTensorSlices sts;
149   const string fname = fnames_[shard];
150   VLOG(1) << "Reading meta data from file " << fname << "...";
151   Table* table;
152   Status s = open_function_(fname, &table);
153   if (!s.ok()) {
154     status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
155                                s.ToString());
156     return;
157   }
158   sss_[shard].reset(table);
159   if (!(table->Get(kSavedTensorSlicesKey, &value) &&
160         ParseProtoUnlimited(&sts, value))) {
161     status_ = errors::Internal(
162         "Failed to find the saved tensor slices at the beginning of the "
163         "checkpoint file: ",
164         fname);
165     return;
166   }
167   status_ = CheckVersions(sts.meta().versions(), TF_CHECKPOINT_VERSION,
168                           TF_CHECKPOINT_VERSION_MIN_PRODUCER, "Checkpoint",
169                           "checkpoint");
170   if (!status_.ok()) return;
171   for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
172     TensorShape ssm_shape;
173     status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
174     if (!status_.ok()) return;
175     for (const TensorSliceProto& tsp : ssm.slice()) {
176       TensorSlice ss_slice;
177       status_ = TensorSlice::BuildTensorSlice(tsp, &ss_slice);
178       if (!status_.ok()) return;
179       status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
180                                     ss_slice, &tensors_);
181       if (!status_.ok()) return;
182     }
183   }
184 }
185 
LoadAllShards() const186 void TensorSliceReader::LoadAllShards() const {
187   VLOG(1) << "Loading all shards for " << filepattern_;
188   for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) {
189     LoadShard(i);
190   }
191   all_shards_loaded_ = true;
192 }
193 
FindTensorSlice(const string & name,const TensorSlice & slice,std::vector<std::pair<TensorSlice,string>> * details) const194 const TensorSliceSet* TensorSliceReader::FindTensorSlice(
195     const string& name, const TensorSlice& slice,
196     std::vector<std::pair<TensorSlice, string>>* details) const {
197   const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
198   if (tss && !tss->QueryMeta(slice, details)) {
199     return nullptr;
200   }
201   return tss;
202 }
203 
~TensorSliceReader()204 TensorSliceReader::~TensorSliceReader() {
205   for (auto& temp : tensors_) {
206     delete temp.second;
207   }
208   tensors_.clear();
209 }
210 
HasTensor(const string & name,TensorShape * shape,DataType * type) const211 bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape,
212                                   DataType* type) const {
213   mutex_lock l(mu_);
214   const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
215   if (!tss && !all_shards_loaded_) {
216     VLOG(1) << "Did not find tensor in preferred shard, loading all shards: "
217             << name;
218     LoadAllShards();
219     tss = gtl::FindPtrOrNull(tensors_, name);
220   }
221   if (tss) {
222     if (shape) {
223       *shape = tss->shape();
224     }
225     if (type) {
226       *type = tss->type();
227     }
228     return true;
229   } else {
230     return false;
231   }
232 }
233 
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor) const234 Status TensorSliceReader::GetTensor(
235     const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor) const {
236   DataType type;
237   TensorShape shape;
238   TensorSlice slice;
239   {
240     mutex_lock l(mu_);
241     const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
242     if (tss == nullptr) {
243       return errors::NotFound(name, " not found in checkpoint file");
244     }
245 
246     if (tss->Slices().size() > 1) {
247       // TODO(sherrym): Support multi-slice checkpoints.
248       return errors::Unimplemented("Sliced checkpoints are not supported");
249     }
250 
251     type = tss->type();
252     shape = tss->shape();
253     slice = tss->Slices().begin()->second.slice;
254   }
255 
256   std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor);
257   Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get());
258   if (!s.ok()) return s;
259   bool success = false;
260 
261 #define READER_COPY(dt)                                                  \
262   case dt:                                                               \
263     success = CopySliceData(name, slice,                                 \
264                             t->flat<EnumToDataType<dt>::Type>().data()); \
265     break;
266 
267   switch (type) {
268     READER_COPY(DT_FLOAT);
269     READER_COPY(DT_DOUBLE);
270     READER_COPY(DT_INT32);
271     READER_COPY(DT_UINT8);
272     READER_COPY(DT_INT16);
273     READER_COPY(DT_INT8);
274     READER_COPY(DT_INT64);
275     READER_COPY(DT_STRING);
276     default:
277       return errors::Unimplemented("Data type not supported");
278   }
279 #undef READER_COPY
280 
281   if (!success) {
282     return errors::NotFound(name, " not found in checkpoint file");
283   }
284   std::swap(*out_tensor, t);
285 
286   return OkStatus();
287 }
288 
GetVariableToShapeMap() const289 TensorSliceReader::VarToShapeMap TensorSliceReader::GetVariableToShapeMap()
290     const {
291   VarToShapeMap name_to_shape;
292   if (status().ok()) {
293     for (auto& e : Tensors()) {
294       name_to_shape[e.first] = e.second->shape();
295     }
296   }
297   return name_to_shape;
298 }
299 
300 TensorSliceReader::VarToDataTypeMap
GetVariableToDataTypeMap() const301 TensorSliceReader::GetVariableToDataTypeMap() const {
302   VarToDataTypeMap name_to_dtype;
303   if (status().ok()) {
304     for (auto& e : Tensors()) {
305       name_to_dtype[e.first] = e.second->type();
306     }
307   }
308   return name_to_dtype;
309 }
310 
DebugString() const311 const string TensorSliceReader::DebugString() const {
312   string shape_str;
313   if (status().ok()) {
314     for (const auto& e : Tensors()) {
315       strings::StrAppend(&shape_str, e.first, " (",
316                          DataType_Name(e.second->type()), ") ",
317                          e.second->shape().DebugString());
318       // Indicates if a tensor has more than 1 slice (i.e., it's partitioned).
319       const int num_slices = e.second->Slices().size();
320       if (num_slices > 1) {
321         strings::StrAppend(&shape_str, ", ", num_slices, " slices");
322       }
323       strings::StrAppend(&shape_str, "\n");
324     }
325   }
326   return shape_str;
327 }
328 
329 }  // namespace checkpoint
330 
331 }  // namespace tensorflow
332