xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/experimental/libexport/load.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_
16 #define TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_
17 
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
29 
30 namespace tensorflow {
31 namespace libexport {
32 
33 // A low-level representation of a SavedModel.
34 //
35 // This class should only ever be a thin wrapper around disk (or other storage)
36 // access for a SavedModel.  Higher level functionality should be layered on top
37 // by other functions and classes.
38 //
39 // In the future, this class can also provide a mechanism for automatic version
40 // migration.  This will allow the calling code to always work against the most
41 // recent version of SavedModel.
42 class TFPackage {
43  public:
44   // Load a SavedModel, parsing the associated protobuf for later access.
45   static tensorflow::StatusOr<TFPackage> Load(const std::string& path);
46 
47   // Reads and returns a checkpoint key associated with a variable.
48   //
49   // The variable is identified by the index in the object graph node list.
50   //
51   // RestoreV2 is the operation that will ultimately be responsible for reading
52   // and restoring the variable(s)' values.  Variable values are indexed in the
53   // checkpoint files by "checkpoint keys".  These keys along with dtype and
54   // shape / slice information allow RestoreV2 to look up a variable's value in
55   // the SavedModel and restore it into a tensor.
56   tensorflow::StatusOr<std::string> GetVariableCheckpointKey(int index);
57 
58   // Retrieves the object graph from the SavedModel.
59   //
60   // For now, we're returning the object graph directly (i.e. the parsed proto)
61   // rather than adding abstraction on top.  We may later find we would like an
62   // intermediate abstraction layer to make traversal easier, but for now the
63   // extra complexity doesn't seem justified.  Regardless of what we choose,
64   // that logic should live outside this class; this class should continue to
65   // have the clearly-defined, singular responsibility of reading and parsing
66   // the low-level, serialized format.
67   const SavedObjectGraph& GetObjectGraph();
68 
69   // Retrieves a specific GraphDef node by name.
70   //
71   // GraphDef nodes are stored as a repeating list of nodes.  At module load
72   // time, a module may have constants that need to be restored.  To restore
73   // these constants, they are looked up in the GraphDef's nodes by their name.
74   // Since we may need to load many constants, we create a hash map of these
75   // names to their corresponding nodes at load time in order to look them up
76   // in constant time.
77   tensorflow::StatusOr<const tensorflow::NodeDef*> GetGraphDefNode(
78       std::string name);
79 
80   // Returns a list of function defs in the SavedModel.
81   const protobuf::RepeatedPtrField<FunctionDef>& GetFunctionDefs();
82 
83   // Returns a BundleReader for reading variable values.
84   //
85   // This TFPackage retains ownership of the underlying reader.
GetVariableReader()86   tensorflow::BundleReader* GetVariableReader() {
87     return variable_reader_.get();
88   }
89 
90   // Returns whether or not we found a valid checkpoint when loading the
91   // package.
HasCheckpoint()92   bool HasCheckpoint() { return has_checkpoint_; }
93 
94   // Returns the path to the variables file.
GetVariablesFilepath()95   const std::string GetVariablesFilepath() { return variables_filepath_; }
96 
97  private:
98   SavedModel saved_model_proto_;
99   TrackableObjectGraph trackable_object_graph_;
100   std::unique_ptr<tensorflow::BundleReader> variable_reader_;
101   std::string variables_filepath_;
102   bool has_checkpoint_;
103   absl::flat_hash_map<std::string, const NodeDef*> graph_def_nodes_by_name_;
104 };
105 
106 }  // namespace libexport
107 }  // namespace tensorflow
108 
109 #endif  // TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_
110