1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/cc/experimental/libtf/module.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
21 namespace tf {
22 namespace libtf {
23 namespace impl {
24
25 using tensorflow::libexport::TFPackage;
26 using tf::libtf::runtime::Runtime;
27
28 // TODO(danielellis): Fill in with implementations.
29
30 // Builds a vector of runtime representations of `SavedObject`s from a
31 // SavedModel. These are returned as a flat list. The full hierarchy building
32 // and initialization should be done in a later pass.
BuildObjects(TFPackage & tf_package)33 tensorflow::StatusOr<std::vector<Handle>> BuildObjects(TFPackage& tf_package) {
34 std::vector<Handle> objects;
35 const tensorflow::SavedObjectGraph object_graph = tf_package.GetObjectGraph();
36 for (auto& node : object_graph.nodes()) {
37 if (node.kind_case() == tensorflow::SavedObject::kUserObject) {
38 tensorflow::StatusOr<Handle> result = BuildSavedUserObject(node);
39 if (result.ok()) {
40 objects.push_back(*result);
41 } else {
42 return result.status();
43 }
44 }
45 }
46 return objects;
47 }
48
BuildSavedUserObject(tensorflow::SavedObject saved_object_proto)49 tensorflow::StatusOr<Handle> BuildSavedUserObject(
50 tensorflow::SavedObject saved_object_proto) {
51 if (saved_object_proto.kind_case() != tensorflow::SavedObject::kUserObject) {
52 return tensorflow::errors::InvalidArgument("Not a UserObject.");
53 }
54
55 std::string identifier = saved_object_proto.user_object().identifier();
56 if (identifier == "trackable_list_wrapper") {
57 tf::libtf::List user_list;
58 // TODO(b/191267013): Populate with values.
59 return user_list;
60 }
61 if (identifier == "trackable_dict_wrapper") {
62 tf::libtf::Dictionary user_dict;
63 // TODO(b/191267013): Populate with values.
64 return user_dict;
65 }
66 if (identifier == "signature_map") {
67 tf::libtf::Dictionary signature_map;
68 // TODO(b/191267013): Populate with values.
69 return signature_map;
70 }
71 if (identifier == "_generic_user_object") {
72 tf::libtf::Dictionary user_object;
73 // TODO(b/191267013): Populate with values.
74 return user_object;
75 }
76 return tensorflow::errors::Unimplemented(absl::StrCat(
77 "UserObject with identifier '", identifier, "' not implemented."));
78 }
79
80 // Register all available concrete functions from a SavedModel into a runtime.
RegisterConcreteFunctions(Runtime runtime,TFPackage tf_package)81 tensorflow::Status RegisterConcreteFunctions(Runtime runtime,
82 TFPackage tf_package) {
83 return tensorflow::errors::Unimplemented("Not implemented.");
84 }
85
86 // Initialize any variables found in the SavedModel and attach them to the
87 // appropriate object representation in the runtime.
InitializeVariables(Runtime runtime,TFPackage tf_package,std::vector<Handle> objects)88 tensorflow::Status InitializeVariables(Runtime runtime, TFPackage tf_package,
89 std::vector<Handle> objects) {
90 return tensorflow::errors::Unimplemented("Not implemented.");
91 }
92
93 // Register concrete functions with their associated polymorphic functions.
SetupPolymorphicFunctions(Runtime runtime,TFPackage tf_package,std::vector<Handle> objects)94 tensorflow::Status SetupPolymorphicFunctions(Runtime runtime,
95 TFPackage tf_package,
96 std::vector<Handle> objects) {
97 return tensorflow::errors::Unimplemented("Not implemented.");
98 }
99
100 // Register any captures with their associated higher-level functions.
SetupFunctionCaptures(Runtime runtime,TFPackage tf_package,std::vector<Handle> objects)101 tensorflow::Status SetupFunctionCaptures(Runtime runtime, TFPackage tf_package,
102 std::vector<Handle> objects) {
103 return tensorflow::errors::Unimplemented("Not implemented.");
104 }
105
106 // Takes a flat list of Handles and builds them into the hierarchical
107 // representation defined by the SavedModel.
BuildObjectHierarchy(TFPackage tf_package,std::vector<Handle> objects)108 tensorflow::StatusOr<Handle> BuildObjectHierarchy(TFPackage tf_package,
109 std::vector<Handle> objects) {
110 return tensorflow::errors::Unimplemented("Not implemented.");
111 }
112
BuildProgram(Runtime runtime,TFPackage & tf_package)113 tensorflow::StatusOr<Handle> BuildProgram(Runtime runtime,
114 TFPackage& tf_package) {
115 return tensorflow::errors::Unimplemented("Not implemented.");
116 }
117
118 } // namespace impl
119 } // namespace libtf
120 } // namespace tf
121