1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/cc/saved_model/fingerprinting.h"

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/btree_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/fingerprint.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
45
46 namespace tensorflow::fingerprinting {
47
// Version of the code that produced the fingerprint. CreateFingerprintDef
// writes this value into the `version.producer` field of the FingerprintDef it
// returns.
const int kFingerprintProducer = 0;
50 namespace {
51
52 // Returns the suffix UID of `function_name`.
GetSuffixUID(absl::string_view function_name)53 StatusOr<int> GetSuffixUID(absl::string_view function_name) {
54 std::vector<std::string> v = absl::StrSplit(function_name, '_');
55 int uid;
56 if (!strings::safe_strto32(v.back(), &uid)) {
57 return errors::InvalidArgument(absl::StrCat(
58 "Function name: `", function_name, "` does not end in an integer."));
59 }
60 return uid;
61 }
62
63 // This function mutates `graph_def`, changing the names and config_proto's
64 // of the Function nodes.
CanonicalizeNodes(GraphDef * graph_def)65 void CanonicalizeNodes(GraphDef* graph_def) {
66 for (NodeDef& node : *graph_def->mutable_node()) {
67 // Check if this is a function call.
68 if (grappler::IsPartitionedCall(node) ||
69 grappler::IsStatefulPartitionedCall(node)) {
70 // Regularize "f" attribute, the function name for PartitionedCall and
71 // and StatefulPartitionedCall ops, by stripping the suffix UID if it
72 // has one.
73 std::string function_name = node.attr().find("f")->second.func().name();
74 StatusOr<int> uid = GetSuffixUID(function_name);
75 if (uid.ok()) {
76 node.mutable_attr()->find("f")->second.mutable_func()->set_name(
77 std::string(
78 absl::StripSuffix(function_name, std::to_string(*uid))));
79 }
80 // Erase the "config_proto" attribute which contains device-specific
81 // information.
82 node.mutable_attr()->find("config_proto")->second.mutable_s()->erase();
83 }
84 // Erase the value of string constants, which can vary based on platform.
85 if (grappler::IsConstant(node)) {
86 if (node.attr().at("dtype").type() == DT_STRING) {
87 node.mutable_attr()->find("value")->second.clear_value();
88 }
89 }
90 }
91 }
92
93 // Returns the hash of the checkpoint .index file, 0 if there is none.
HashCheckpointIndexFile(absl::string_view model_dir)94 uint64 HashCheckpointIndexFile(absl::string_view model_dir) {
95 std::string meta_filename = MetaFilename(io::JoinPath(
96 model_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename));
97 std::string data;
98 Status read_status = ReadFileToString(Env::Default(), meta_filename, &data);
99 if (read_status.ok()) {
100 return tensorflow::Fingerprint64(data);
101 } else {
102 LOG(WARNING) << read_status.error_message();
103 return 0;
104 }
105 }
106
107 } // namespace
108
ComputeHash(const GraphDef & graph_def)109 uint64 ComputeHash(const GraphDef& graph_def) {
110 std::string graph_def_string;
111 SerializeToStringDeterministic(graph_def, &graph_def_string);
112 return tensorflow::Fingerprint64(graph_def_string);
113 }
114
CreateFingerprintDef(const MetaGraphDef & metagraph,absl::string_view export_dir)115 FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph,
116 absl::string_view export_dir) {
117 // Create a copy of `metagraph` which will be used and mutated for fingerprint
118 // computation.
119 MetaGraphDef metagraph_copy = metagraph;
120 FingerprintDef fingerprint_def;
121 // Set fingerprint field #1.
122 fingerprint_def.set_graph_def_checksum(
123 ComputeHash(metagraph_copy.graph_def()));
124 // Set fingerprint field #2.
125 CanonicalizeGraphDef(*metagraph_copy.mutable_graph_def());
126 fingerprint_def.set_graph_def_program_hash(
127 ComputeHash(metagraph_copy.graph_def()));
128 // Set fingerprint field #3.
129 fingerprint_def.set_signature_def_hash(
130 RegularizeAndHashSignatureDefs(metagraph_copy.signature_def()));
131 // Set fingerprint field #4.
132 StatusOr<uint64> object_graph_hash =
133 RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def());
134 fingerprint_def.set_saved_object_graph_hash(
135 RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()));
136 // Set fingerprint field #5.
137 fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir));
138 // Set version of the fingerprint.
139 VersionDef* version = fingerprint_def.mutable_version();
140 version->set_producer(kFingerprintProducer);
141
142 return fingerprint_def;
143 }
144
145 // The GraphDef contains two main sections: a list of nodes and the
146 // FunctionDefLibrary. Canonicalization treats these two sections separately.
CanonicalizeGraphDef(GraphDef & graph_def)147 void CanonicalizeGraphDef(GraphDef& graph_def) {
148 CanonicalizeNodes(&graph_def);
149 // TODO(b/240173815): Complete canonicalization of the FunctionDefLibrary.
150 // For now, we just clear the FunctionDefLibrary.
151 graph_def.mutable_library()->Clear();
152 graph_def.mutable_versions()->Clear();
153 }
154
RegularizeAndHashSignatureDefs(const google::protobuf::Map<std::string,SignatureDef> & signature_def_map)155 uint64 RegularizeAndHashSignatureDefs(
156 const google::protobuf::Map<std::string, SignatureDef>& signature_def_map) {
157 // Sort `signature_def_map`, which is an unordered map from string keys to
158 // SignatureDefs.
159 absl::btree_map<std::string, SignatureDef> sorted_signature_defs;
160 sorted_signature_defs.insert(signature_def_map.begin(),
161 signature_def_map.end());
162 uint64 result_hash = 0;
163 for (const auto& item : sorted_signature_defs) {
164 std::string signature_def_string;
165 SerializeToStringDeterministic(item.second, &signature_def_string);
166 result_hash = FingerprintCat64(
167 result_hash, tensorflow::Fingerprint64(signature_def_string));
168 }
169 return result_hash;
170 }
171
172 // The SavedObjectGraph contains two parts: the list of nodes and the map of
173 // concrete functions. Regularization treats these two parts separately.
RegularizeAndHashSavedObjectGraph(const SavedObjectGraph & object_graph_def)174 uint64 RegularizeAndHashSavedObjectGraph(
175 const SavedObjectGraph& object_graph_def) {
176 // Sort `concrete_functions`, which is an unordered map from function names to
177 // SavedConcreteFunction, using the suffix UID of the function name. Assumes
178 // that the trackable children are listed in a deterministic order during
179 // serialization.
180 absl::btree_map<int, std::string> uid_to_function_names;
181 for (const auto& [name, concrete_function] :
182 object_graph_def.concrete_functions()) {
183 StatusOr<int> uid = GetSuffixUID(name);
184 // All valid function names should end in an UID.
185 if (uid.ok()) {
186 uid_to_function_names.insert({*uid, name});
187 } else {
188 LOG(ERROR) << uid.status().error_message();
189 }
190 }
191 uint64 result_hash = 0;
192 for (const auto& [uid, function_name] : uid_to_function_names) {
193 // Hash the function name (with the UID stripped).
194 result_hash = FingerprintCat64(result_hash,
195 tensorflow::Fingerprint64(absl::StripSuffix(
196 function_name, std::to_string(uid))));
197 // Hash the serialized concrete function.
198 std::string concrete_function_string;
199 SerializeToStringDeterministic(
200 object_graph_def.concrete_functions().at(function_name),
201 &concrete_function_string);
202 result_hash = FingerprintCat64(
203 result_hash, tensorflow::Fingerprint64(concrete_function_string));
204 }
205 // TODO(b/241294832): Complete canonicalization of `object_graph_def.nodes`.
206 return result_hash;
207 }
208 } // namespace tensorflow::fingerprinting
209