1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/aggregation/tensorflow/checkpoint_reader.h"
18
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/strings/str_format.h"
25 #include "absl/strings/string_view.h"
26 #include "fcp/aggregation/core/datatype.h"
27 #include "fcp/aggregation/core/tensor.h"
28 #include "fcp/aggregation/tensorflow/converters.h"
29 #include "fcp/base/monitoring.h"
30 #include "tensorflow/c/checkpoint_reader.h"
31 #include "tensorflow/c/tf_status.h"
32 #include "tensorflow/c/tf_status_helper.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/types.pb.h"
36
37 namespace fcp::aggregation::tensorflow {
38
39 namespace tf = ::tensorflow;
40
Create(const std::string & filename)41 absl::StatusOr<std::unique_ptr<CheckpointReader>> CheckpointReader::Create(
42 const std::string& filename) {
43 tf::TF_StatusPtr tf_status(TF_NewStatus());
44 auto tf_checkpoint_reader =
45 std::make_unique<tf::checkpoint::CheckpointReader>(filename,
46 tf_status.get());
47 if (TF_GetCode(tf_status.get()) != TF_OK) {
48 return absl::InternalError(
49 absl::StrFormat("Couldn't read checkpoint: %s : %s", filename,
50 TF_Message(tf_status.get())));
51 }
52
53 // Populate the DataType map.
54 DataTypeMap data_type_map;
55 for (const auto& [name, tf_dtype] :
56 tf_checkpoint_reader->GetVariableToDataTypeMap()) {
57 FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(tf_dtype));
58 data_type_map.emplace(name, dtype);
59 }
60
61 // Populate the TensorShape map.
62 TensorShapeMap shape_map;
63 for (const auto& [name, tf_shape] :
64 tf_checkpoint_reader->GetVariableToShapeMap()) {
65 shape_map.emplace(name, ConvertShape(tf_shape));
66 }
67
68 return std::unique_ptr<CheckpointReader>(
69 new CheckpointReader(std::move(tf_checkpoint_reader),
70 std::move(data_type_map), std::move(shape_map)));
71 }
72
CheckpointReader(std::unique_ptr<tf::checkpoint::CheckpointReader> tensorflow_checkpoint_reader,DataTypeMap data_type_map,TensorShapeMap shape_map)73 CheckpointReader::CheckpointReader(
74 std::unique_ptr<tf::checkpoint::CheckpointReader>
75 tensorflow_checkpoint_reader,
76 DataTypeMap data_type_map, TensorShapeMap shape_map)
77 : tf_checkpoint_reader_(std::move(tensorflow_checkpoint_reader)),
78 data_type_map_(std::move(data_type_map)),
79 shape_map_(std::move(shape_map)) {}
80
GetTensor(const std::string & name) const81 StatusOr<Tensor> CheckpointReader::GetTensor(const std::string& name) const {
82 std::unique_ptr<tf::Tensor> tensor;
83 const tf::TF_StatusPtr read_status(TF_NewStatus());
84 tf_checkpoint_reader_->GetTensor(name, &tensor, read_status.get());
85 if (TF_GetCode(read_status.get()) != TF_OK) {
86 return absl::NotFoundError(
87 absl::StrFormat("Checkpoint doesn't have tensor %s", name));
88 }
89 return ConvertTensor(std::move(tensor));
90 }
91
92 } // namespace fcp::aggregation::tensorflow
93