xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/saved_model/loader.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/loader.h"
17 
18 #include <unordered_set>
19 
20 #include "tensorflow/cc/saved_model/constants.h"
21 #include "tensorflow/cc/saved_model/loader_util.h"
22 #include "tensorflow/cc/saved_model/metrics.h"
23 #include "tensorflow/cc/saved_model/reader.h"
24 #include "tensorflow/cc/saved_model/util.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op_def.pb.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/lib/io/path.h"
31 #include "tensorflow/core/lib/monitoring/counter.h"
32 #include "tensorflow/core/lib/monitoring/sampler.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/file_system_helper.h"
38 #include "tensorflow/core/platform/statusor.h"
39 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
40 #include "tensorflow/core/protobuf/meta_graph.pb.h"
41 #include "tensorflow/core/protobuf/saver.pb.h"
42 #include "tensorflow/core/public/session.h"
43 #include "tensorflow/core/public/session_options.h"
44 #include "tensorflow/core/util/tensor_bundle/naming.h"
45 
46 namespace tensorflow {
47 namespace {
48 
49 auto* load_attempt_count = monitoring::Counter<2>::New(
50     "/tensorflow/cc/saved_model/load_attempt_count",
51     "The number of times a SavedModel was successfully loaded.", "model_path",
52     "status");
53 auto* load_latency = monitoring::Counter<1>::New(
54     "/tensorflow/cc/saved_model/load_latency",
55     "Latency in microseconds for SavedModels that were successfully loaded.",
56     "model_path");
57 auto* load_latency_by_stage = monitoring::Sampler<2>::New(
58     {
59         "/tensorflow/cc/saved_model/load_latency_by_stage",  // metric name
60         "Distribution of wall time spent (in microseconds) in each stage "
61         "(restore graph from disk, run init graph op, etc) when loading the "
62         "model",
63         "model_path",
64         "stage",
65     },
66     // Scale of 10, power of 1.8 with bucket count 33 (~20 minutes).
67     monitoring::Buckets::Exponential(10, 1.8, 33));
68 
69 constexpr char kLoadAttemptFail[] = "fail";
70 constexpr char kLoadAttemptSuccess[] = "success";
71 // `tensorflow::LoadSavedModel` API label.
72 constexpr char kCCLoadLabel[] = "cc_load";
73 
// Returns the number of microseconds elapsed since `start_microseconds`.
// If the clock reads earlier than the start time (clock skew), returns 0
// instead of letting the unsigned subtraction wrap around.
uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
  const uint64 now_microseconds = EnvTime::NowMicros();
  return now_microseconds >= start_microseconds
             ? now_microseconds - start_microseconds
             : 0;
}
80 
81 // Ensure that constant tensors loaded from the saved model have valid shape.
82 // Also ensure that constant nodes have a value assigned to them.
83 // TODO(b/154763635): this is temporary and will be replaced with a better audit
ValidateNode(const NodeDef & node)84 static Status ValidateNode(const NodeDef& node) {
85   const auto node_iterator = node.attr().find("value");
86   if (node_iterator != node.attr().end()) {
87     AttrValue node_value = node_iterator->second;
88     if (node_value.has_tensor()) {
89       const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
90       if (node_shape.num_elements() < 0) {
91         return errors::FailedPrecondition(
92             "Saved model contains node \"", node.name(), "\" (op \"", node.op(),
93             "\") which initializes from a tensor with ",
94             node_shape.num_elements(), " elements");
95       }
96     }
97   } else if (node.op() == "Const") {
98     return errors::FailedPrecondition(
99         "Saved model contains node \"", node.name(),
100         "\" which is a constant tensor but no value has been provided");
101   }
102   return OkStatus();
103 }
104 
ValidateFunctionNotRecursive(const FunctionDef & function)105 static Status ValidateFunctionNotRecursive(const FunctionDef& function) {
106   const auto& function_name = function.signature().name();
107   for (const auto& node : function.node_def()) {
108     if (node.op() == function_name) {
109       return errors::FailedPrecondition(
110           "Function ", function_name,
111           " is self recursive and TensorFlow does not support this scenario.");
112     }
113   }
114 
115   return OkStatus();
116 }
117 
ValidateSavedTensors(const GraphDef & graph_def)118 static Status ValidateSavedTensors(const GraphDef& graph_def) {
119   for (const auto& node : graph_def.node()) {
120     TF_RETURN_IF_ERROR(ValidateNode(node));
121   }
122 
123   if (graph_def.has_library()) {
124     const FunctionDefLibrary& library = graph_def.library();
125     for (const auto& function : library.function()) {
126       for (const auto& node : function.node_def()) {
127         TF_RETURN_IF_ERROR(ValidateNode(node));
128       }
129 
130       // Also check that there is no recursivity in the library
131       TF_RETURN_IF_ERROR(ValidateFunctionNotRecursive(function));
132     }
133   }
134 
135   return OkStatus();
136 }
137 
CreateStringTensor(const string & value)138 Tensor CreateStringTensor(const string& value) {
139   Tensor tensor(DT_STRING, TensorShape({}));
140   tensor.scalar<tstring>()() = value;
141   return tensor;
142 }
143 
AddAssetsTensorsToInputs(const StringPiece export_dir,const std::vector<AssetFileDef> & asset_file_defs,std::vector<std::pair<string,Tensor>> * inputs)144 void AddAssetsTensorsToInputs(const StringPiece export_dir,
145                               const std::vector<AssetFileDef>& asset_file_defs,
146                               std::vector<std::pair<string, Tensor>>* inputs) {
147   if (asset_file_defs.empty()) {
148     return;
149   }
150   for (auto& asset_file_def : asset_file_defs) {
151     Tensor assets_file_path_tensor = CreateStringTensor(io::JoinPath(
152         export_dir, kSavedModelAssetsDirectory, asset_file_def.filename()));
153     inputs->push_back(
154         {asset_file_def.tensor_info().name(), assets_file_path_tensor});
155   }
156 }
157 
158 // Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid
159 // leaving behind non-GC'ed state.
160 //
161 // Detailed motivation behind this approach, from ashankar@:
162 //
163 // Each call to Session::Run() that identifies a new subgraph (based on feeds
164 // and fetches) creates some datastructures that live as long as the session
165 // (the partitioned graph, associated executors etc.).
166 //
167 // A pathological case of this would be if say the initialization op
168 // (main_op/legacy_init_op) involves the use of a large constant. Then we
169 // allocate memory for that large constant that will just stick around till the
170 // session dies. With this Callable mechanism, that memory will be released
171 // right after ReleaseCallable returns.
172 //
173 // However, the resource manager state remains.
RunOnce(const RunOptions & run_options,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs,RunMetadata * run_metadata,Session * session)174 Status RunOnce(const RunOptions& run_options,
175                const std::vector<std::pair<string, Tensor>>& inputs,
176                const std::vector<string>& output_tensor_names,
177                const std::vector<string>& target_node_names,
178                std::vector<Tensor>* outputs, RunMetadata* run_metadata,
179                Session* session) {
180   CallableOptions callable_options;
181   std::vector<Tensor> feed_tensors;
182   *callable_options.mutable_run_options() = run_options;
183   for (const auto& input : inputs) {
184     const string& name = input.first;
185     const Tensor& tensor = input.second;
186     callable_options.add_feed(name);
187     feed_tensors.push_back(tensor);
188   }
189   for (const string& output_tensor_name : output_tensor_names) {
190     callable_options.add_fetch(output_tensor_name);
191   }
192   for (const string& target_node_name : target_node_names) {
193     callable_options.add_target(target_node_name);
194   }
195 
196   Session::CallableHandle callable_handle;
197   TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle));
198   const Status run_status = session->RunCallable(callable_handle, feed_tensors,
199                                                  outputs, run_metadata);
200   // Be sure to call ReleaseCallable() regardless of the outcome of
201   // RunCallable().
202   session->ReleaseCallable(callable_handle).IgnoreError();
203   return run_status;
204 }
205 
206 // RunInitOp will return OK if the initialization op was run successfully.
207 // An empty init_op_name indicates that there are no init ops to run.
RunInitOp(const RunOptions & run_options,const string & export_dir,const MetaGraphDef & meta_graph_def,const std::vector<AssetFileDef> & asset_file_defs,Session * session,const string & init_op_name)208 Status RunInitOp(const RunOptions& run_options, const string& export_dir,
209                  const MetaGraphDef& meta_graph_def,
210                  const std::vector<AssetFileDef>& asset_file_defs,
211                  Session* session, const string& init_op_name) {
212   if (!init_op_name.empty()) {
213     LOG(INFO) << "Running initialization op on SavedModel bundle at path: "
214               << export_dir;
215     std::vector<std::pair<string, Tensor>> inputs;
216     AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
217     RunMetadata run_metadata;
218     return RunOnce(run_options, inputs, {}, {init_op_name},
219                    nullptr /* outputs */, &run_metadata, session);
220   }
221   return OkStatus();
222 }
223 
RunRestore(const RunOptions & run_options,const string & export_dir,const StringPiece restore_op_name,const StringPiece variable_filename_const_op_name,const std::vector<AssetFileDef> & asset_file_defs,Session * session)224 Status RunRestore(const RunOptions& run_options, const string& export_dir,
225                   const StringPiece restore_op_name,
226                   const StringPiece variable_filename_const_op_name,
227                   const std::vector<AssetFileDef>& asset_file_defs,
228                   Session* session) {
229   LOG(INFO) << "Restoring SavedModel bundle.";
230   // Find path to variables to be restored in export directory.
231   const string variables_directory =
232       io::JoinPath(export_dir, kSavedModelVariablesDirectory);
233   // Check for saver checkpoints in v2 format. Models exported in the checkpoint
234   // v2 format will have a variables.index file. The corresponding
235   // variables are stored in the variables.data-?????-of-????? files.
236   const string variables_index_path = io::JoinPath(
237       variables_directory, MetaFilename(kSavedModelVariablesFilename));
238   TF_ASSIGN_OR_RETURN(
239       bool variables_index_exists,
240       internal::FileExists(Env::Default(), variables_index_path));
241   if (!variables_index_exists) {
242     LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
243                  "were restored. File does not exist: "
244               << variables_index_path;
245     return OkStatus();
246   }
247   const string variables_path =
248       io::JoinPath(variables_directory, kSavedModelVariablesFilename);
249 
250   // Add variables to the graph.
251   Tensor variables_path_tensor(DT_STRING, TensorShape({}));
252   variables_path_tensor.scalar<tstring>()() = variables_path;
253 
254   std::vector<std::pair<string, Tensor>> inputs = {
255       {string(variable_filename_const_op_name), variables_path_tensor}};
256 
257   AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
258 
259   RunMetadata run_metadata;
260   return RunOnce(run_options, inputs, {}, {string(restore_op_name)},
261                  nullptr /* outputs */, &run_metadata, session);
262 }
263 
264 }  // namespace
265 
~SavedModelBundleInterface()266 SavedModelBundleInterface::~SavedModelBundleInterface() {}
267 
LoadMetagraphIntoSession(const SessionOptions & session_options,const MetaGraphDef & meta_graph,std::unique_ptr<Session> * session)268 Status LoadMetagraphIntoSession(const SessionOptions& session_options,
269                                 const MetaGraphDef& meta_graph,
270                                 std::unique_ptr<Session>* session) {
271   Session* session_p = nullptr;
272   TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
273   session->reset(session_p);
274   TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def()));
275   return (*session)->Create(meta_graph.graph_def());
276 }
277 
LoadSavedModelInternal(const SessionOptions & session_options,const RunOptions & run_options,const string & export_dir,const std::unordered_set<string> & tags,SavedModelBundle * const bundle)278 Status LoadSavedModelInternal(const SessionOptions& session_options,
279                               const RunOptions& run_options,
280                               const string& export_dir,
281                               const std::unordered_set<string>& tags,
282                               SavedModelBundle* const bundle) {
283   TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
284                                                     &bundle->meta_graph_def));
285   TF_RETURN_IF_ERROR(
286       ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info));
287   TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
288       session_options, bundle->meta_graph_def, &bundle->session));
289   TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def,
290                                     export_dir, &bundle->session));
291   return OkStatus();
292 }
293 
LoadSavedModel(const SessionOptions & session_options,const RunOptions & run_options,const string & export_dir,const std::unordered_set<string> & tags,SavedModelBundle * const bundle)294 Status LoadSavedModel(const SessionOptions& session_options,
295                       const RunOptions& run_options, const string& export_dir,
296                       const std::unordered_set<string>& tags,
297                       SavedModelBundle* const bundle) {
298   metrics::SavedModelReadApi(kCCLoadLabel).IncrementBy(1);
299 
300   // TODO(robson): Add tests for the counters.
301   const uint64 start_microseconds = Env::Default()->NowMicros();
302   const Status status = LoadSavedModelInternal(session_options, run_options,
303                                                export_dir, tags, bundle);
304   auto log_and_count = [&](const string& status_str) {
305     LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
306               << " }; Status: " << status_str << ": " << status << ". Took "
307               << GetLatencyMicroseconds(start_microseconds) << " microseconds.";
308     load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
309   };
310   if (status.ok()) {
311     log_and_count(kLoadAttemptSuccess);
312   } else {
313     log_and_count(kLoadAttemptFail);
314   }
315   load_latency->GetCell(export_dir)
316       ->IncrementBy(GetLatencyMicroseconds(start_microseconds));
317   return status;
318 }
319 
320 namespace {
321 // Session wrapper that prevents calls to Session::Create(), Session::Extend(),
322 // and the deprecated partial-run methods.
323 //
324 // Limiting the available methods on a returned Session gives us the option
325 // to replace the Session with a cut-down implementation, without breaking any
326 // users.
327 class LiteSessionWrapper : public Session {
328  public:
LiteSessionWrapper(std::unique_ptr<Session> wrapped)329   explicit LiteSessionWrapper(std::unique_ptr<Session> wrapped)
330       : wrapped_(std::move(wrapped)) {}
331 
Create(const GraphDef & graph)332   Status Create(const GraphDef& graph) override {
333     return errors::Unimplemented("Session::Create()");
334   }
Create(GraphDef && graph)335   Status Create(GraphDef&& graph) override {
336     return errors::Unimplemented("Session::Create()");
337   }
338 
Extend(const GraphDef & graph)339   Status Extend(const GraphDef& graph) override {
340     return errors::Unimplemented("Session::Extend()");
341   }
Extend(GraphDef && graph)342   Status Extend(GraphDef&& graph) override {
343     return errors::Unimplemented("Session::Extend()");
344   }
345 
Run(const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs)346   Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
347              const std::vector<string>& output_tensor_names,
348              const std::vector<string>& target_node_names,
349              std::vector<Tensor>* outputs) override {
350     return wrapped_->Run(inputs, output_tensor_names, target_node_names,
351                          outputs);
352   }
353 
Create(const RunOptions & run_options,const GraphDef & graph)354   Status Create(const RunOptions& run_options, const GraphDef& graph) override {
355     return errors::Unimplemented("Session::Create()");
356   }
Extend(const RunOptions & run_options,const GraphDef & graph)357   Status Extend(const RunOptions& run_options, const GraphDef& graph) override {
358     return errors::Unimplemented("Session::Extend()");
359   }
Create(const RunOptions & run_options,GraphDef && graph)360   Status Create(const RunOptions& run_options, GraphDef&& graph) override {
361     return errors::Unimplemented("Session::Create()");
362   }
Extend(const RunOptions & run_options,GraphDef && graph)363   Status Extend(const RunOptions& run_options, GraphDef&& graph) override {
364     return errors::Unimplemented("Session::Extend()");
365   }
Close(const RunOptions & run_options)366   Status Close(const RunOptions& run_options) override {
367     return wrapped_->Close(run_options);
368   }
369 
Run(const RunOptions & run_options,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs,RunMetadata * run_metadata)370   Status Run(const RunOptions& run_options,
371              const std::vector<std::pair<string, Tensor>>& inputs,
372              const std::vector<string>& output_tensor_names,
373              const std::vector<string>& target_node_names,
374              std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
375     return wrapped_->Run(run_options, inputs, output_tensor_names,
376                          target_node_names, outputs, run_metadata);
377   }
378 
PRunSetup(const std::vector<string> & input_names,const std::vector<string> & output_names,const std::vector<string> & target_nodes,string * handle)379   Status PRunSetup(const std::vector<string>& input_names,
380                    const std::vector<string>& output_names,
381                    const std::vector<string>& target_nodes,
382                    string* handle) override {
383     return errors::Unimplemented("Session::PRunSetup()");
384   }
385 
PRun(const string & handle,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_names,std::vector<Tensor> * outputs)386   Status PRun(const string& handle,
387               const std::vector<std::pair<string, Tensor>>& inputs,
388               const std::vector<string>& output_names,
389               std::vector<Tensor>* outputs) override {
390     return errors::Unimplemented("Session::PRun()");
391   }
392 
ListDevices(std::vector<DeviceAttributes> * response)393   Status ListDevices(std::vector<DeviceAttributes>* response) override {
394     return wrapped_->ListDevices(response);
395   }
396 
Close()397   Status Close() override { return wrapped_->Close(); }
398 
MakeCallable(const CallableOptions & callable_options,CallableHandle * out_handle)399   Status MakeCallable(const CallableOptions& callable_options,
400                       CallableHandle* out_handle) override {
401     return wrapped_->MakeCallable(callable_options, out_handle);
402   }
403 
RunCallable(CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata)404   Status RunCallable(CallableHandle handle,
405                      const std::vector<Tensor>& feed_tensors,
406                      std::vector<Tensor>* fetch_tensors,
407                      RunMetadata* run_metadata) override {
408     return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
409                                  run_metadata);
410   }
411 
RunCallable(CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata,const thread::ThreadPoolOptions & threadpool_options)412   Status RunCallable(
413       CallableHandle handle, const std::vector<Tensor>& feed_tensors,
414       std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
415       const thread::ThreadPoolOptions& threadpool_options) override {
416     return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
417                                  run_metadata, threadpool_options);
418   }
419 
ReleaseCallable(CallableHandle handle)420   Status ReleaseCallable(CallableHandle handle) override {
421     return wrapped_->ReleaseCallable(handle);
422   }
423 
424  private:
425   const std::unique_ptr<Session> wrapped_;
426 };
427 }  // namespace
428 
RestoreSession(const RunOptions & run_options,const MetaGraphDef & meta_graph,const string & export_dir,std::unique_ptr<Session> * session)429 Status RestoreSession(const RunOptions& run_options,
430                       const MetaGraphDef& meta_graph, const string& export_dir,
431                       std::unique_ptr<Session>* session) {
432   const uint64 read_start_microseconds = Env::Default()->NowMicros();
433   std::vector<AssetFileDef> asset_file_defs;
434   TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
435   if (meta_graph.has_saver_def()) {
436     TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
437                                   meta_graph.saver_def().restore_op_name(),
438                                   meta_graph.saver_def().filename_tensor_name(),
439                                   asset_file_defs, session->get()));
440   }
441   // Record walltime spent in restoring graph from disk, but postpone metric
442   // increments until graph init finishes.
443   const uint64 restore_graph_walltime =
444       GetLatencyMicroseconds(read_start_microseconds);
445 
446   const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
447   string init_op_name;
448   TF_RETURN_IF_ERROR(
449       internal::GetInitOp(export_dir, meta_graph, &init_op_name));
450   TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph,
451                                asset_file_defs, session->get(), init_op_name));
452   load_latency_by_stage->GetCell(export_dir, "restore_graph")
453       ->Add(restore_graph_walltime);
454   // Record wall time spent in init op.
455   load_latency_by_stage->GetCell(export_dir, "init_graph")
456       ->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
457   return OkStatus();
458 }
459 
LoadSavedModel(const SessionOptions & session_options,const RunOptions & run_options,const string & export_dir,const std::unordered_set<string> & tags,SavedModelBundleLite * const bundle)460 Status LoadSavedModel(const SessionOptions& session_options,
461                       const RunOptions& run_options, const string& export_dir,
462                       const std::unordered_set<string>& tags,
463                       SavedModelBundleLite* const bundle) {
464   SavedModelBundle legacy_bundle;
465   SessionOptions rewritten_options(session_options);
466   // We disallow calls to Session::Extend() on the returned session, so we can
467   // reduce memory consumption by not storing the original GraphDef.
468   rewritten_options.config.mutable_experimental()
469       ->set_optimize_for_static_graph(true);
470   // Disallowing the `RunOptions.output_partition_graphs` option (typically used
471   // in debugging and tests) allows us to reduce memory consumption further by
472   // not storing the rewritten subgraph for each signature.
473   rewritten_options.config.mutable_experimental()
474       ->set_disable_output_partition_graphs(true);
475   // TODO(mrry): Consider specializing the session creation to reduce peak
476   // RAM consumption by using `Session::Create(GraphDef&&)`.
477   TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
478                                     tags, &legacy_bundle));
479   *bundle = SavedModelBundleLite(
480       absl::make_unique<LiteSessionWrapper>(std::move(legacy_bundle.session)),
481       std::move(*legacy_bundle.meta_graph_def.mutable_signature_def()));
482   return OkStatus();
483 }
484 
MaybeSavedModelDirectory(const string & export_dir)485 bool MaybeSavedModelDirectory(const string& export_dir) {
486   const string saved_model_pb_path =
487       io::JoinPath(export_dir, kSavedModelFilenamePb);
488   const string saved_model_pbtxt_path =
489       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
490   return Env::Default()->FileExists(saved_model_pb_path).ok() ||
491          Env::Default()->FileExists(saved_model_pbtxt_path).ok();
492 }
493 
494 }  // namespace tensorflow
495