xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/saved_model/fingerprinting.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/fingerprinting.h"
17 
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
21 
22 #include "absl/container/btree_map.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/strip.h"
26 #include "tensorflow/cc/saved_model/constants.h"
27 #include "tensorflow/core/framework/attr_value.pb.h"
28 #include "tensorflow/core/framework/function.pb.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/lib/strings/numbers.h"
34 #include "tensorflow/core/lib/strings/proto_serialization.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/fingerprint.h"
38 #include "tensorflow/core/platform/path.h"
39 #include "tensorflow/core/platform/statusor.h"
40 #include "tensorflow/core/protobuf/fingerprint.pb.h"
41 #include "tensorflow/core/protobuf/meta_graph.pb.h"
42 #include "tensorflow/core/protobuf/saved_model.pb.h"
43 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
44 #include "tensorflow/core/util/tensor_bundle/naming.h"
45 
46 namespace tensorflow::fingerprinting {
47 
// Producer version of the fingerprinting code; CreateFingerprintDef stamps it
// into the `version.producer` field of every FingerprintDef it builds.
const int kFingerprintProducer{0};
50 namespace {
51 
52 // Returns the suffix UID of `function_name`.
GetSuffixUID(absl::string_view function_name)53 StatusOr<int> GetSuffixUID(absl::string_view function_name) {
54   std::vector<std::string> v = absl::StrSplit(function_name, '_');
55   int uid;
56   if (!strings::safe_strto32(v.back(), &uid)) {
57     return errors::InvalidArgument(absl::StrCat(
58         "Function name: `", function_name, "` does not end in an integer."));
59   }
60   return uid;
61 }
62 
63 // This function mutates `graph_def`, changing the names and config_proto's
64 // of the Function nodes.
CanonicalizeNodes(GraphDef * graph_def)65 void CanonicalizeNodes(GraphDef* graph_def) {
66   for (NodeDef& node : *graph_def->mutable_node()) {
67     // Check if this is a function call.
68     if (grappler::IsPartitionedCall(node) ||
69         grappler::IsStatefulPartitionedCall(node)) {
70       // Regularize "f" attribute, the function name for PartitionedCall and
71       // and StatefulPartitionedCall ops, by stripping the suffix UID if it
72       // has one.
73       std::string function_name = node.attr().find("f")->second.func().name();
74       StatusOr<int> uid = GetSuffixUID(function_name);
75       if (uid.ok()) {
76         node.mutable_attr()->find("f")->second.mutable_func()->set_name(
77             std::string(
78                 absl::StripSuffix(function_name, std::to_string(*uid))));
79       }
80       // Erase the "config_proto" attribute which contains device-specific
81       // information.
82       node.mutable_attr()->find("config_proto")->second.mutable_s()->erase();
83     }
84     // Erase the value of string constants, which can vary based on platform.
85     if (grappler::IsConstant(node)) {
86       if (node.attr().at("dtype").type() == DT_STRING) {
87         node.mutable_attr()->find("value")->second.clear_value();
88       }
89     }
90   }
91 }
92 
93 // Returns the hash of the checkpoint .index file, 0 if there is none.
HashCheckpointIndexFile(absl::string_view model_dir)94 uint64 HashCheckpointIndexFile(absl::string_view model_dir) {
95   std::string meta_filename = MetaFilename(io::JoinPath(
96       model_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename));
97   std::string data;
98   Status read_status = ReadFileToString(Env::Default(), meta_filename, &data);
99   if (read_status.ok()) {
100     return tensorflow::Fingerprint64(data);
101   } else {
102     LOG(WARNING) << read_status.error_message();
103     return 0;
104   }
105 }
106 
107 }  // namespace
108 
ComputeHash(const GraphDef & graph_def)109 uint64 ComputeHash(const GraphDef& graph_def) {
110   std::string graph_def_string;
111   SerializeToStringDeterministic(graph_def, &graph_def_string);
112   return tensorflow::Fingerprint64(graph_def_string);
113 }
114 
CreateFingerprintDef(const MetaGraphDef & metagraph,absl::string_view export_dir)115 FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph,
116                                     absl::string_view export_dir) {
117   // Create a copy of `metagraph` which will be used and mutated for fingerprint
118   // computation.
119   MetaGraphDef metagraph_copy = metagraph;
120   FingerprintDef fingerprint_def;
121   // Set fingerprint field #1.
122   fingerprint_def.set_graph_def_checksum(
123       ComputeHash(metagraph_copy.graph_def()));
124   // Set fingerprint field #2.
125   CanonicalizeGraphDef(*metagraph_copy.mutable_graph_def());
126   fingerprint_def.set_graph_def_program_hash(
127       ComputeHash(metagraph_copy.graph_def()));
128   // Set fingerprint field #3.
129   fingerprint_def.set_signature_def_hash(
130       RegularizeAndHashSignatureDefs(metagraph_copy.signature_def()));
131   // Set fingerprint field #4.
132   StatusOr<uint64> object_graph_hash =
133       RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def());
134   fingerprint_def.set_saved_object_graph_hash(
135       RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()));
136   // Set fingerprint field #5.
137   fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir));
138   // Set version of the fingerprint.
139   VersionDef* version = fingerprint_def.mutable_version();
140   version->set_producer(kFingerprintProducer);
141 
142   return fingerprint_def;
143 }
144 
145 // The GraphDef contains two main sections: a list of nodes and the
146 // FunctionDefLibrary. Canonicalization treats these two sections separately.
CanonicalizeGraphDef(GraphDef & graph_def)147 void CanonicalizeGraphDef(GraphDef& graph_def) {
148   CanonicalizeNodes(&graph_def);
149   // TODO(b/240173815): Complete canonicalization of the FunctionDefLibrary.
150   // For now, we just clear the FunctionDefLibrary.
151   graph_def.mutable_library()->Clear();
152   graph_def.mutable_versions()->Clear();
153 }
154 
RegularizeAndHashSignatureDefs(const google::protobuf::Map<std::string,SignatureDef> & signature_def_map)155 uint64 RegularizeAndHashSignatureDefs(
156     const google::protobuf::Map<std::string, SignatureDef>& signature_def_map) {
157   // Sort `signature_def_map`, which is an unordered map from string keys to
158   // SignatureDefs.
159   absl::btree_map<std::string, SignatureDef> sorted_signature_defs;
160   sorted_signature_defs.insert(signature_def_map.begin(),
161                                signature_def_map.end());
162   uint64 result_hash = 0;
163   for (const auto& item : sorted_signature_defs) {
164     std::string signature_def_string;
165     SerializeToStringDeterministic(item.second, &signature_def_string);
166     result_hash = FingerprintCat64(
167         result_hash, tensorflow::Fingerprint64(signature_def_string));
168   }
169   return result_hash;
170 }
171 
172 // The SavedObjectGraph contains two parts: the list of nodes and the map of
173 // concrete functions. Regularization treats these two parts separately.
RegularizeAndHashSavedObjectGraph(const SavedObjectGraph & object_graph_def)174 uint64 RegularizeAndHashSavedObjectGraph(
175     const SavedObjectGraph& object_graph_def) {
176   // Sort `concrete_functions`, which is an unordered map from function names to
177   // SavedConcreteFunction, using the suffix UID of the function name. Assumes
178   // that the trackable children are listed in a deterministic order during
179   // serialization.
180   absl::btree_map<int, std::string> uid_to_function_names;
181   for (const auto& [name, concrete_function] :
182        object_graph_def.concrete_functions()) {
183     StatusOr<int> uid = GetSuffixUID(name);
184     // All valid function names should end in an UID.
185     if (uid.ok()) {
186       uid_to_function_names.insert({*uid, name});
187     } else {
188       LOG(ERROR) << uid.status().error_message();
189     }
190   }
191   uint64 result_hash = 0;
192   for (const auto& [uid, function_name] : uid_to_function_names) {
193     // Hash the function name (with the UID stripped).
194     result_hash = FingerprintCat64(result_hash,
195                                    tensorflow::Fingerprint64(absl::StripSuffix(
196                                        function_name, std::to_string(uid))));
197     // Hash the serialized concrete function.
198     std::string concrete_function_string;
199     SerializeToStringDeterministic(
200         object_graph_def.concrete_functions().at(function_name),
201         &concrete_function_string);
202     result_hash = FingerprintCat64(
203         result_hash, tensorflow::Fingerprint64(concrete_function_string));
204   }
205   // TODO(b/241294832): Complete canonicalization of `object_graph_def.nodes`.
206   return result_hash;
207 }
208 }  // namespace tensorflow::fingerprinting
209