1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ 16 #define TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ 17 18 #include <string> 19 20 #include "absl/container/flat_hash_map.h" 21 #include "tensorflow/core/framework/function.pb.h" 22 #include "tensorflow/core/framework/types.pb.h" 23 #include "tensorflow/core/platform/protobuf.h" 24 #include "tensorflow/core/platform/statusor.h" 25 #include "tensorflow/core/protobuf/saved_model.pb.h" 26 #include "tensorflow/core/protobuf/saved_object_graph.pb.h" 27 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" 28 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" 29 30 namespace tensorflow { 31 namespace libexport { 32 33 // A low-level representation of a SavedModel. 34 // 35 // This class should only ever be a thin wrapper around disk (or other storage) 36 // access for a SavedModel. Higher level functionality should be layered on top 37 // by other functions and classes. 38 // 39 // In the future, this class can also provide a mechanism for automatic version 40 // migration. This will allow the calling code to always work against the most 41 // recent version of SavedModel. 42 class TFPackage { 43 public: 44 // Load a SavedModel, parsing the associated protobuf for later access. 45 static tensorflow::StatusOr<TFPackage> Load(const std::string& path); 46 47 // Reads and returns a checkpoint key associated with a variable. 48 // 49 // The variable is identified by the index in the object graph node list. 50 // 51 // RestoreV2 is the operation that will ultimately be responsible for reading 52 // and restoring the variable(s)' values. Variable values are indexed in the 53 // checkpoint files by "checkpoint keys". These keys along with dtype and 54 // shape / slice information allow RestoreV2 to look up a variable's value in 55 // the SavedModel and restore it into a tensor. 56 tensorflow::StatusOr<std::string> GetVariableCheckpointKey(int index); 57 58 // Retrieves the object graph from the SavedModel. 59 // 60 // For now, we're returning the object graph directly (i.e. the parsed proto) 61 // rather than adding abstraction on top. We may later find we would like an 62 // intermediate abstraction layer to make traversal easier, but for now the 63 // extra complexity doesn't seem justified. Regardless of what we choose, 64 // that logic should live outside this class; this class should continue to 65 // have the clearly-defined, singular responsibility of reading and parsing 66 // the low-level, serialized format. 67 const SavedObjectGraph& GetObjectGraph(); 68 69 // Retrieves a specific GraphDef node by name. 70 // 71 // GraphDef nodes are stored as a repeating list of nodes. At module load 72 // time, a module may have constants that need to be restored. To restore 73 // these constants, they are looked up in the GraphDef's nodes by their name. 74 // Since we may need to load many constants, we create a hash map of these 75 // names to their corresponding nodes at load time in order to look them up 76 // in constant time. 77 tensorflow::StatusOr<const tensorflow::NodeDef*> GetGraphDefNode( 78 std::string name); 79 80 // Returns a list of function defs in the SavedModel. 81 const protobuf::RepeatedPtrField<FunctionDef>& GetFunctionDefs(); 82 83 // Returns a BundleReader for reading variable values. 84 // 85 // This TFPackage retains ownership of the underlying reader. GetVariableReader()86 tensorflow::BundleReader* GetVariableReader() { 87 return variable_reader_.get(); 88 } 89 90 // Returns whether or not we found a valid checkpoint when loading the 91 // package. HasCheckpoint()92 bool HasCheckpoint() { return has_checkpoint_; } 93 94 // Returns the path to the variables file. GetVariablesFilepath()95 const std::string GetVariablesFilepath() { return variables_filepath_; } 96 97 private: 98 SavedModel saved_model_proto_; 99 TrackableObjectGraph trackable_object_graph_; 100 std::unique_ptr<tensorflow::BundleReader> variable_reader_; 101 std::string variables_filepath_; 102 bool has_checkpoint_; 103 absl::flat_hash_map<std::string, const NodeDef*> graph_def_nodes_by_name_; 104 }; 105 106 } // namespace libexport 107 } // namespace tensorflow 108 109 #endif // TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ 110