xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/saved_model/reader.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/reader.h"
17 
18 #include <unordered_set>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/cc/saved_model/constants.h"
22 #include "tensorflow/cc/saved_model/metrics.h"
23 #include "tensorflow/cc/saved_model/util.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/lib/io/path.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/file_system_helper.h"
34 #include "tensorflow/core/platform/statusor.h"
35 #include "tensorflow/core/protobuf/saved_model.pb.h"
36 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
37 
38 namespace tensorflow {
39 namespace {
40 
41 // Reads the SavedModel proto from saved_model.pb in `export_dir`.
42 // Returns a failure status when the SavedModel file does not exist.
ReadSavedModel(absl::string_view export_dir,SavedModel * saved_model_proto)43 Status ReadSavedModel(absl::string_view export_dir,
44                       SavedModel* saved_model_proto) {
45   LOG(INFO) << "Reading SavedModel from: " << export_dir;
46 
47   const std::string saved_model_pb_path =
48       io::JoinPath(export_dir, kSavedModelFilenamePb);
49 
50   TF_ASSIGN_OR_RETURN(
51       bool saved_model_pb_exists,
52       internal::FileExists(Env::Default(), saved_model_pb_path));
53   if (saved_model_pb_exists) {
54     Status result =
55         ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
56     if (result.ok()) {
57       metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
58           .IncrementBy(1);
59     }
60     return result;
61   }
62   const std::string saved_model_pbtxt_path =
63       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
64   TF_ASSIGN_OR_RETURN(
65       bool saved_model_pbtxt_exists,
66       internal::FileExists(Env::Default(), saved_model_pbtxt_path));
67   if (saved_model_pbtxt_exists) {
68     Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
69                                   saved_model_proto);
70     if (result.ok()) {
71       metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
72           .IncrementBy(1);
73     }
74     return result;
75   }
76   return Status(
77       error::Code::NOT_FOUND,
78       strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied "
79                       "export directory path: ",
80                       export_dir,
81                       ". Check that "
82                       "the directory exists and that you have the right "
83                       "permissions for accessing it."));
84 }
85 
FindMetaGraphDef(const std::unordered_set<string> & tags,SavedModel * saved_model_proto,MetaGraphDef * meta_graph_def)86 Status FindMetaGraphDef(const std::unordered_set<string>& tags,
87                         SavedModel* saved_model_proto,
88                         MetaGraphDef* meta_graph_def) {
89   LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
90             << " }";
91   for (MetaGraphDef& graph_def : *saved_model_proto->mutable_meta_graphs()) {
92     // Get tags from the graph_def.
93     std::unordered_set<string> graph_tags;
94     for (const string& tag : graph_def.meta_info_def().tags()) {
95       graph_tags.insert(tag);
96     }
97     // Match with the set of tags provided.
98     if (graph_tags == tags) {
99       *meta_graph_def = std::move(graph_def);
100       // Correct the endiness of Tensor content on big-endian system
101       if (!port::kLittleEndian) {
102         TF_RETURN_IF_ERROR(ByteSwapTensorContent(meta_graph_def));
103       }
104       return OkStatus();
105     }
106   }
107   return Status(
108       error::Code::NOT_FOUND,
109       strings::StrCat(
110           "Could not find meta graph def matching supplied tags: { ",
111           absl::StrJoin(tags, " "),
112           " }. To inspect available tag-sets in the SavedModel, please "
113           "use the SavedModel CLI: `saved_model_cli`"));
114 }
115 }  // namespace
116 
ReadMetaGraphDefFromSavedModel(const string & export_dir,const std::unordered_set<string> & tags,MetaGraphDef * const meta_graph_def)117 Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
118                                       const std::unordered_set<string>& tags,
119                                       MetaGraphDef* const meta_graph_def) {
120   SavedModel saved_model_proto;
121   TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
122   TF_RETURN_IF_ERROR(
123       FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def));
124   return OkStatus();
125 }
126 
ReadSavedModelDebugInfoIfPresent(const string & export_dir,std::unique_ptr<GraphDebugInfo> * debug_info_proto)127 Status ReadSavedModelDebugInfoIfPresent(
128     const string& export_dir,
129     std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
130   LOG(INFO) << "Reading SavedModel debug info (if present) from: "
131             << export_dir;
132 
133   const string debug_info_pb_path =
134       io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
135   TF_ASSIGN_OR_RETURN(bool debug_info_pb_exists,
136                       internal::FileExists(Env::Default(), debug_info_pb_path));
137   if (debug_info_pb_exists) {
138     GraphDebugInfo debug_info;
139     TF_RETURN_IF_ERROR(
140         ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
141     *debug_info_proto =
142         absl::make_unique<GraphDebugInfo>(std::move(debug_info));
143   }
144   return OkStatus();
145 }
146 
147 }  // namespace tensorflow
148