xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/tensorflow/checkpoint_reader.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/tensorflow/checkpoint_reader.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/strings/str_format.h"
25 #include "absl/strings/string_view.h"
26 #include "fcp/aggregation/core/datatype.h"
27 #include "fcp/aggregation/core/tensor.h"
28 #include "fcp/aggregation/tensorflow/converters.h"
29 #include "fcp/base/monitoring.h"
30 #include "tensorflow/c/checkpoint_reader.h"
31 #include "tensorflow/c/tf_status.h"
32 #include "tensorflow/c/tf_status_helper.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 
37 namespace fcp::aggregation::tensorflow {
38 
39 namespace tf = ::tensorflow;
40 
Create(const std::string & filename)41 absl::StatusOr<std::unique_ptr<CheckpointReader>> CheckpointReader::Create(
42     const std::string& filename) {
43   tf::TF_StatusPtr tf_status(TF_NewStatus());
44   auto tf_checkpoint_reader =
45       std::make_unique<tf::checkpoint::CheckpointReader>(filename,
46                                                          tf_status.get());
47   if (TF_GetCode(tf_status.get()) != TF_OK) {
48     return absl::InternalError(
49         absl::StrFormat("Couldn't read checkpoint: %s : %s", filename,
50                         TF_Message(tf_status.get())));
51   }
52 
53   // Populate the DataType map.
54   DataTypeMap data_type_map;
55   for (const auto& [name, tf_dtype] :
56        tf_checkpoint_reader->GetVariableToDataTypeMap()) {
57     FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(tf_dtype));
58     data_type_map.emplace(name, dtype);
59   }
60 
61   // Populate the TensorShape map.
62   TensorShapeMap shape_map;
63   for (const auto& [name, tf_shape] :
64        tf_checkpoint_reader->GetVariableToShapeMap()) {
65     shape_map.emplace(name, ConvertShape(tf_shape));
66   }
67 
68   return std::unique_ptr<CheckpointReader>(
69       new CheckpointReader(std::move(tf_checkpoint_reader),
70                            std::move(data_type_map), std::move(shape_map)));
71 }
72 
CheckpointReader(std::unique_ptr<tf::checkpoint::CheckpointReader> tensorflow_checkpoint_reader,DataTypeMap data_type_map,TensorShapeMap shape_map)73 CheckpointReader::CheckpointReader(
74     std::unique_ptr<tf::checkpoint::CheckpointReader>
75         tensorflow_checkpoint_reader,
76     DataTypeMap data_type_map, TensorShapeMap shape_map)
77     : tf_checkpoint_reader_(std::move(tensorflow_checkpoint_reader)),
78       data_type_map_(std::move(data_type_map)),
79       shape_map_(std::move(shape_map)) {}
80 
GetTensor(const std::string & name) const81 StatusOr<Tensor> CheckpointReader::GetTensor(const std::string& name) const {
82   std::unique_ptr<tf::Tensor> tensor;
83   const tf::TF_StatusPtr read_status(TF_NewStatus());
84   tf_checkpoint_reader_->GetTensor(name, &tensor, read_status.get());
85   if (TF_GetCode(read_status.get()) != TF_OK) {
86     return absl::NotFoundError(
87         absl::StrFormat("Checkpoint doesn't have tensor %s", name));
88   }
89   return ConvertTensor(std::move(tensor));
90 }
91 
92 }  // namespace fcp::aggregation::tensorflow
93