/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/saved_model/saved_model.h"

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
#include "tensorflow/core/tfrt/mla/mla_utils.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_import_input.h"
#include "tensorflow/core/tfrt/tpu/tpu_resources.h"  // NOLINT(unused-includes): For tfrt::tpu::TpuModelResource
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
#include "tfrt/metrics/common_metrics.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_stub {
namespace {

constexpr absl::string_view kSignatureJoiningDelimiter = "+";
constexpr absl::string_view kTensorNameJoiningDelimiter = "-";
constexpr absl::string_view kArgumentTypeJoiningDelimiter = "^";
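// For example, the names of signatures that are loaded together, e.g. {"a",
// "b"}, are joined with kSignatureJoiningDelimiter into the single key "a+b"
// (see JoinSignatures() and GetOrCreateLoadingResult() below).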

using SignatureMap = absl::flat_hash_map<std::string, internal::Signature>;
using ::tensorflow::SessionMetadata;
using ::tensorflow::StatusOr;

struct InitializersAndSignatures {
  llvm::SmallVector<std::string, 4> initializers;
  SignatureMap signature_map;
};

auto* saved_model_read_meta_graph_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/read_meta_graph_time",
        "Record the time of reading meta_graph from disk.", "model_name");

auto* saved_model_functionalization_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/functionalization_time",
        "Record the functionalization time for the savedmodel.", "model_name");

auto* saved_model_grappler_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/grappler_time",
        "Record the grappler time for the savedmodel.", "model_name");

auto* saved_model_mla_check_time_milli_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/mla_check_time",
        "Record the MLA check time for the savedmodel.", "model_name");

auto* saved_model_import_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/import_time",
        "Record the MLIR import time for the savedmodel.", "model_name");

auto* saved_model_compile_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/compile_time",
        "Record the compilation time for the savedmodel.", "model_name");

auto* saved_model_init_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/init_time",
        "Record the initialization time for the savedmodel.", "model_name");

tensorflow::Tensor CreateScalarStringTensor(absl::string_view str) {
  return tensorflow::Tensor(tensorflow::tstring(str));
}

// Create the tensor for the bound input, which can be a variable or an asset.
//
// TODO(chky): For V2 models, the bound input can also be a resource.
StatusOr<tensorflow::Tensor> CreateTensorFromBoundInput(
    mlir::Operation* bound_input, absl::string_view saved_model_dir,
    absl::flat_hash_map<std::string, tensorflow::Tensor>* variables) {
  // Assets are files in the saved model directory. We pass their filenames to
  // functions so that they can be used.
  if (auto asset = llvm::dyn_cast<mlir::tf_saved_model::AssetOp>(bound_input)) {
    // The filename in the asset is a relative path. So we prefix it with the
    // directory path.
    return CreateScalarStringTensor(
        tensorflow::io::JoinPath(saved_model_dir, asset.filename().str()));
  }

  return tensorflow::errors::Internal(
      "Failed to create captured tensors: unknown bound input type.");
}

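// Builds a SignatureMap, keyed by exported function name, from the
// tf_saved_model dialect `module`: it records each signature's input/output
// names, devices, and tensor specs, and materializes a captured tensor for
// each bound input (e.g. an asset file).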
StatusOr<SignatureMap> GetFunctionSignaturesFromTFSavedModelMLIR(
    absl::string_view saved_model_dir, mlir::ModuleOp module) {
  absl::flat_hash_map<std::string, tensorflow::Tensor> variables;
  SignatureMap signatures;

  tensorflow::StatusGroup status_group;
  TF_RETURN_IF_ERROR(tensorflow::MapFunctionSignaturesFromTFSavedModelMLIR(
      module, [&status_group, &variables, &signatures, saved_model_dir](
                  const tensorflow::TFRTSavedModelSignatureInfo& sig_info) {
        auto& signature = signatures[std::string(sig_info.func_name)];

        auto copy = [](llvm::ArrayRef<llvm::StringRef> src,
                       std::vector<std::string>* dst) {
          transform(src, std::back_inserter(*dst),
                    [](llvm::StringRef x) { return x.str(); });
        };
        copy(sig_info.input_names, &signature.input_names);
        copy(sig_info.output_names, &signature.output_names);
        copy(sig_info.input_devices, &signature.input_devices);

        DCHECK(signature.input_specs.empty());
        signature.input_specs.reserve(sig_info.input_specs.size());
        for (auto& spec : sig_info.input_specs) {
          signature.input_specs.push_back(TensorSpec(spec.first, spec.second));
        }

        DCHECK(signature.output_specs.empty());
        signature.output_specs.reserve(sig_info.output_specs.size());
        for (auto& spec : sig_info.output_specs) {
          signature.output_specs.push_back(TensorSpec(spec.first, spec.second));
        }

        for (auto* bound_input : sig_info.bound_inputs) {
          auto capture = CreateTensorFromBoundInput(
              bound_input, saved_model_dir, &variables);
          if (!capture.ok()) {
            status_group.Update(capture.status());
            // Insert a placeholder (empty) tensor in case of errors.
            signature.captures.push_back(tensorflow::Tensor());
          } else {
            signature.captures.push_back(*std::move(capture));
          }
        }
      }));

  if (!status_group.ok()) return status_group.as_concatenated_status();

  return signatures;
}

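// Runs the initialization sequence for a loaded BEF: first the
// compiler-generated "_tfrt_fallback_init" function, then each session
// initializer (with its captured assets as arguments), and finally
// "_tfrt_resource_init" to populate the runtime's resource states.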
tensorflow::Status RunInitializers(
    const InitializersAndSignatures& initializers_and_signatures,
    const SessionMetadata& model_metadata, tfrt::BEFFile* bef_file,
    const Runtime& runtime, tfrt::ResourceContext* resource_context,
    const FallbackState& fallback_state) {
  auto* host = runtime.core_runtime()->GetHostContext();
  TF_ASSIGN_OR_RETURN(auto request_info,
                      SetUpRequestContext(/*run_options=*/{}, model_metadata,
                                          host, runtime.work_queue(),
                                          resource_context, fallback_state));

  tfrt::ExecutionContext exec_ctx(request_info->tfrt_request_context);

  // Run "_tfrt_fallback_init" first to initialize fallback-specific states.
  // It is a special function created by the compiler, which calls a sequence
  // of tfrt_fallback_async.createop to create all the fallback ops used in
  // this BEF.
  TF_RETURN_IF_ERROR(
      RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_fallback_init"));

  for (const auto& init : initializers_and_signatures.initializers) {
    // TODO(b/184771263): Consider using `GraphExecutionRunOnFunction()`
    // instead.

    auto* func = bef_file->GetFunction(init);
    assert(func);

    const auto& signature = initializers_and_signatures.signature_map.at(init);

    auto ready_chain = tfrt::GetReadyChain();

    // The actual arguments are the concatenation of the side-effect chain and
    // the captured assets.
    llvm::SmallVector<tfrt::AsyncValue*, 1> arguments;
    auto cleanup = tensorflow::gtl::MakeCleanup([&]() {
      for (auto* argument : arguments) argument->DropRef();
    });

    arguments.push_back(ready_chain.release());

    for (const auto& capture : signature.captures) {
      arguments.push_back(
          tfrt::MakeAvailableAsyncValueRef<FallbackTensor>(capture).release());
    }

    assert(arguments.size() == func->argument_types().size());

    llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 1> results;
    results.resize(func->result_types().size());
    assert(results.size() == 1);

    func->Execute(exec_ctx, arguments, results);

    // Wait for the function execution to finish, as well as the side-effects.
    host->Await(results);

    if (auto* error = results[0]->GetErrorIfPresent()) {
      return CreateTfErrorStatus(*error);
    }
  }

  // After all the resources in the original graph have been initialized, run
  // the "_tfrt_resource_init" function to store these resources in the
  // runtime state, so that they can later be retrieved efficiently, without
  // any locking.
  TF_RETURN_IF_ERROR(
      RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_resource_init"));

  return OkStatus();
}

std::vector<std::string> FindNamesForValidSignatures(
    const tensorflow::MetaGraphDef& meta_graph_def) {
  std::vector<std::string> valid_signature_names;

  auto is_dense_tensor_info = [](const auto& named_tensor_info) {
    return !named_tensor_info.second.name().empty();
  };

  auto is_ref_type_tensor_info = [](const auto& named_tensor_info) {
    return tensorflow::IsRefType(named_tensor_info.second.dtype());
  };

  for (const auto& iter : meta_graph_def.signature_def()) {
    const auto& sig_key = iter.first;
    const auto& signature = iter.second;
    if (!std::all_of(signature.inputs().begin(), signature.inputs().end(),
                     is_dense_tensor_info) ||
        !std::all_of(signature.outputs().begin(), signature.outputs().end(),
                     is_dense_tensor_info)) {
      LOG(WARNING) << "Unsupported signature with non-dense tensors as "
                      "input/output. Name: "
                   << sig_key << "; Signature: " << signature.DebugString();
      continue;
    }
    if (std::any_of(signature.inputs().begin(), signature.inputs().end(),
                    is_ref_type_tensor_info) ||
        std::any_of(signature.outputs().begin(), signature.outputs().end(),
                    is_ref_type_tensor_info)) {
      LOG(WARNING) << "Unsupported signature with ref type tensors as "
                      "input/output. Name: "
                   << sig_key << "; Signature: " << signature.DebugString();
      continue;
    }
    valid_signature_names.push_back(sig_key);
  }
  return valid_signature_names;
}

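// Imports `meta_graph_def` into an MLIR module in the tf_saved_model dialect,
// running the DirectSession-style graph processing via
// TfrtSavedModelMLIRImportInput. If `import_user_signatures` is true, the
// valid user signatures are imported as exported functions.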
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
    mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def,
    const FallbackState& fallback_state, std::string saved_model_dir,
    bool import_user_signatures, bool run_placer_grappler_on_functions,
    bool enable_tfrt_gpu) {
  std::vector<std::string> signature_names;
  if (import_user_signatures) {
    signature_names = FindNamesForValidSignatures(meta_graph_def);
    if (signature_names.empty())
      LOG(WARNING) << "No valid signature found for model: " << saved_model_dir;
  }

  // TfrtSavedModelMLIRImportInput implements the graph processing logic
  // (e.g. Placer and Grappler) used in DirectSession, which applies graph
  // transformations on each subgraph (i.e. signature). It reuses the code
  // path in DirectSession to avoid problems caused by different behavior in a
  // different code path. It is injected into the MLIR importer so that the
  // importer can import the transformed graph instead of the original graph.
  TF_ASSIGN_OR_RETURN(auto import_input,
                      TfrtSavedModelMLIRImportInput::Create(
                          fallback_state, &meta_graph_def, /*debug_info=*/{},
                          run_placer_grappler_on_functions, enable_tfrt_gpu));

  TF_ASSIGN_OR_RETURN(
      auto module,
      tensorflow::ConvertSavedModelV1ToMlirLite(
          import_input,
          /*exported_names=*/absl::MakeSpan(signature_names), context));

  LOG(INFO) << "TFRT ImportSavedModel: Functionalization took "
            << absl::ToInt64Milliseconds(
                   import_input.GetFunctionalizationDuration())
            << " ms.";
  LOG(INFO) << "TFRT ImportSavedModel: Grappler took "
            << absl::ToInt64Milliseconds(import_input.GetGrapplerDuration())
            << " ms.";

  saved_model_functionalization_time_seconds->GetCell(saved_model_dir)
      ->Set(absl::ToInt64Seconds(import_input.GetFunctionalizationDuration()));

  saved_model_grappler_time_seconds->GetCell(saved_model_dir)
      ->Set(absl::ToInt64Seconds(import_input.GetGrapplerDuration()));

  return module;
}

StatusOr<InitializersAndSignatures> GetInitializersAndSignatures(
    mlir::ModuleOp module, absl::string_view saved_model_dir) {
  InitializersAndSignatures result;
  TF_ASSIGN_OR_RETURN(
      result.signature_map,
      GetFunctionSignaturesFromTFSavedModelMLIR(saved_model_dir, module));
  for (auto session_initializer_name :
       mlir::tf_saved_model::GetSessionInitializerExportedName(module)) {
    result.initializers.push_back(session_initializer_name.str());
  }
  return result;
}

tensorflow::Status InitSavedModel(
    const InitializersAndSignatures& initializers_and_signatures,
    tfrt::BEFFile* bef_file, const SavedModel::Options& options,
    tfrt::ResourceContext* resource_context,
    const FallbackState& fallback_state) {
  TF_RETURN_IF_ERROR(
      RunInitializers(initializers_and_signatures,
                      options.graph_execution_options.model_metadata, bef_file,
                      *options.graph_execution_options.runtime,
                      resource_context, fallback_state));

  return OkStatus();
}

}  // namespace

SavedModel::~SavedModel() {}

tfrt::HostContext* SavedModel::GetHostContext() const {
  return runtime_->core_runtime()->GetHostContext();
}

namespace {

// Gets the signatures from `signature_defs` and inserts them into
// `signatures`.
void GetSignaturesFromSignatureDef(
    SignatureMap& signatures,
    const google::protobuf::Map<std::string, tensorflow::SignatureDef>& signature_defs,
    const SavedModel::Options& options) {
  for (const auto& p : signature_defs) {
    const std::string& signature_name = p.first;
    const tensorflow::SignatureDef& signature_def = p.second;
    DCHECK(signatures.find(signature_name) == signatures.end());
    auto& signature = signatures[signature_name];

    signature.input_names.reserve(signature_def.inputs().size());
    signature.input_specs.reserve(signature_def.inputs().size());
    for (const auto& p : signature_def.inputs()) {
      const std::string& input_tensor_name = p.first;
      const tensorflow::TensorInfo& tensor_info = p.second;
      signature.input_names.push_back(input_tensor_name);
      signature.input_specs.push_back(
          TensorSpec(tensor_info.dtype(), tensor_info.tensor_shape()));
    }

    signature.input_devices = std::vector<std::string>(
        signature_def.inputs().size(),
        options.graph_execution_options.compile_options.default_device);

    signature.output_names.reserve(signature_def.outputs().size());
    signature.output_specs.reserve(signature_def.outputs().size());
    for (const auto& p : signature_def.outputs()) {
      const std::string& output_tensor_name = p.first;
      const tensorflow::TensorInfo& tensor_info = p.second;
      signature.output_names.push_back(output_tensor_name);
      signature.output_specs.push_back(
          TensorSpec(tensor_info.dtype(), tensor_info.tensor_shape()));
    }
  }
}

void UpdateCompileOptions(SavedModel::Options& options) {
  // Disable DecomposeResourceOpsPass for now, as DecomposeResourceGather does
  // not work well with GPU (b/232819415).
  if (options.graph_execution_options.enable_tfrt_gpu) {
    options.graph_execution_options.compile_options.decompose_resource_ops =
        false;
  }
}

StatusOr<tensorflow::MetaGraphDef> ReadSavedModel(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags) {
  LOG(INFO) << "TFRT reading v1 savedmodel: " << saved_model_dir;
  const auto read_start_time = absl::Now();

  tensorflow::MetaGraphDef meta_graph_def;
  TF_RETURN_IF_ERROR(tensorflow::ReadMetaGraphDefFromSavedModel(
      std::string(saved_model_dir), tags, &meta_graph_def));

  const auto read_meta_graph_duration = absl::Now() - read_start_time;
  saved_model_read_meta_graph_time_seconds
      ->GetCell(std::string(saved_model_dir))
      ->Set(absl::ToInt64Seconds(read_meta_graph_duration));
  LOG(INFO) << "TFRT finished reading meta graph. Took "
            << absl::ToInt64Milliseconds(read_meta_graph_duration) << " ms.";
  return std::move(meta_graph_def);
}

}  // namespace

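// Example usage (a sketch only; the directory, tags, and `options` setup are
// illustrative and depend on the caller's environment):
//
//   tensorflow::Status status;
//   std::unique_ptr<SavedModel> saved_model = SavedModelImpl::LoadSavedModel(
//       std::move(options), "/path/to/saved_model", /*tags=*/{"serve"},
//       &status);
//   if (!saved_model) { /* inspect `status` for the failure reason */ }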
std::unique_ptr<SavedModel> SavedModelImpl::LoadSavedModel(
    Options options, absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags, tensorflow::Status* status) {
  std::string saved_model_dir_str = "unused";
  if (options.maybe_load_from_mla) {
    const auto mla_check_start_time = absl::Now();
    const bool is_mla = IsMlarchive(saved_model_dir);
    const auto mla_check_duration = absl::Now() - mla_check_start_time;
    saved_model_mla_check_time_milli_seconds
        ->GetCell(std::string(saved_model_dir))
        ->Set(absl::ToInt64Milliseconds(mla_check_duration));
    LOG(INFO) << "TFRT finished checking MLA. Took "
              << absl::ToInt64Milliseconds(mla_check_duration) << " ms.";
    if (is_mla) {
      LOG(INFO) << "TFRT got an MLArchive dir: " << saved_model_dir
                << ". Continuing to find the actual saved_model_dir in it.";
      const auto statusor_saved_model_dir =
          GetSavedModelDirFromMlaDir(saved_model_dir);
      if (!statusor_saved_model_dir.ok()) {
        *status = statusor_saved_model_dir.status();
        return nullptr;
      }
      saved_model_dir_str = *statusor_saved_model_dir;
      saved_model_dir = saved_model_dir_str;
      LOG(INFO) << "TFRT found from MLArchive a saved model: "
                << saved_model_dir;
    }  // Not an MLA; `saved_model_dir` is ready to use.
  }

  auto meta_graph_def = ReadSavedModel(saved_model_dir, tags);
  if (!meta_graph_def.ok()) {
    *status = meta_graph_def.status();
    return nullptr;
  }

  LOG(INFO) << "TFRT loading v1 savedmodel: " << saved_model_dir;
  tfrt::metrics::AddTFRTVersionMetric();

  UpdateTpuTargetByBridgeCompatibility(options.graph_execution_options,
                                       meta_graph_def->graph_def());
  UpdateCompileOptions(options);

  auto saved_model =
      [&]() -> tensorflow::StatusOr<std::unique_ptr<SavedModel>> {
    mlir::MLIRContext context;

    const bool lazy_loading_enabled =
        meta_graph_def->signature_def_size() > options.lazy_loading_threshold;

    // Step 1: Import saved model from a proto to an MLIR module.
    const auto import_start_time = absl::Now();
    auto session_options =
        CreateDefaultSessionOptions(options.graph_execution_options);
    // Set optimize_for_static_graph to true since we won't extend the graph
    // later. If optimize_for_static_graph is set to false, FallbackState will
    // keep an extra unused copy of the graph, which unnecessarily consumes
    // memory.
    session_options.config.mutable_experimental()
        ->set_optimize_for_static_graph(true);

    // Create the fallback_state using the original function def library,
    // without applying Placer or Grappler. This is OK for now because it is
    // only used for captured functions in certain tf.data ops.
    const auto& fdef_lib = meta_graph_def->graph_def().library();
    ASSIGN_OR_RETURN_IN_IMPORT(
        auto fallback_state, FallbackState::Create(session_options, fdef_lib));
    ASSIGN_OR_RETURN_IN_IMPORT(
        auto mlir_module,
        ImportSavedModel(
            &context, *meta_graph_def, *fallback_state,
            std::string(saved_model_dir),
            /*import_user_signatures=*/!lazy_loading_enabled,
            options.graph_execution_options.run_placer_grappler_on_functions,
            options.graph_execution_options.enable_tfrt_gpu));

    const auto import_duration = absl::Now() - import_start_time;
    saved_model_import_time_seconds->GetCell(std::string(saved_model_dir))
        ->Set(absl::ToInt64Seconds(import_duration));
    LOG(INFO) << "TFRT finished importing savedmodel. Took "
              << absl::ToInt64Milliseconds(import_duration) << " ms.";

    // Step 2: Compile the MLIR module from TF dialect to TFRT dialect (in BEF).
    const auto compile_start_time = absl::Now();
    ASSIGN_OR_RETURN_IN_COMPILE(
        auto initializers_and_signatures,
        GetInitializersAndSignatures(mlir_module.get(), saved_model_dir));
    // If lazy loading is enabled, the user signatures are not exported via
    // the MLIR module, so we need to get them from the proto.
    // TODO(b/187228559): Unify the code paths for populating the signature map.
    if (lazy_loading_enabled) {
      GetSignaturesFromSignatureDef(initializers_and_signatures.signature_map,
                                    meta_graph_def->signature_def(), options);
    }
    tfrt::BefBuffer bef;
    RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef(
        options.graph_execution_options.compile_options, mlir_module.get(),
        &bef));

    const auto compile_duration = absl::Now() - compile_start_time;
    saved_model_compile_time_seconds->GetCell(std::string(saved_model_dir))
        ->Set(absl::ToInt64Seconds(compile_duration));
    LOG(INFO) << "TFRT finished compiling savedmodel. Took "
              << absl::ToInt64Milliseconds(compile_duration) << " ms.";

    // Step 3: Initialize runtime states using special BEF functions.
    const auto init_start_time = absl::Now();
    ASSIGN_OR_RETURN_IN_INIT(
        auto bef_file, tfrt::CreateBefFileFromBefBuffer(
                           *options.graph_execution_options.runtime, bef));

    auto tpu_model_resource = std::make_unique<tfrt::tpu::TpuModelResource>();
    auto resource_context = CreateResourceContext(
        *options.graph_execution_options.runtime, tpu_model_resource.get(),
        options.graph_execution_options.compile_options.tpu_target);
    RETURN_IF_ERROR_IN_INIT(
        InitSavedModel(initializers_and_signatures, bef_file.get(), options,
                       resource_context.get(), *fallback_state));

    const auto init_duration = absl::Now() - init_start_time;
    saved_model_init_time_seconds->GetCell(std::string(saved_model_dir))
        ->Set(absl::ToInt64Seconds(init_duration));
    LOG(INFO) << "TFRT finished initializing savedmodel. Took "
              << absl::ToInt64Milliseconds(init_duration) << " ms.";

    ASSIGN_OR_RETURN_WITH_STAGE_INFO(
        "graph_executor creation", auto graph_executor,
        GraphExecutor::Create(options.graph_execution_options, *fallback_state,
                              tpu_model_resource.get(),
                              std::move(*meta_graph_def->mutable_graph_def())));

    // Finally, create the saved model.
    return {std::make_unique<SavedModelImpl>(
        std::move(options), *std::move(meta_graph_def), std::move(bef),
        std::move(bef_file),
        std::move(initializers_and_signatures.signature_map),
        std::move(fallback_state), std::move(tpu_model_resource),
        std::move(resource_context), std::move(graph_executor))};
  }();

  if (!saved_model.ok()) {
    *status = saved_model.status();
    return nullptr;
  }
  *status = OkStatus();
  return *std::move(saved_model);
}

SavedModelImpl::SavedModelImpl(
    Options options, tensorflow::MetaGraphDef meta_graph_def,
    tfrt::BefBuffer bef, tfrt::RCReference<tfrt::BEFFile> bef_file,
    SignatureMap signatures, std::unique_ptr<FallbackState> fallback_state,
    std::unique_ptr<tfrt::tpu::TpuModelResource> tpu_model_resource,
    std::unique_ptr<tfrt::ResourceContext> resource_context,
    std::unique_ptr<GraphExecutor> graph_executor)
    : SavedModel(options.graph_execution_options.runtime),
      options_(std::move(options)),
      meta_graph_def_(std::move(meta_graph_def)),
      bef_(std::move(bef)),
      bef_file_(std::move(bef_file)),
      req_deadline_tracker_(
          options.graph_execution_options.runtime->core_runtime()
              ->GetHostContext()),
      signatures_(std::move(signatures)),
      fallback_state_(std::move(fallback_state)),
      tpu_model_resource_(std::move(tpu_model_resource)),
      resource_context_(std::move(resource_context)),
      graph_executor_(std::move(graph_executor)),
      lazy_loading_enabled_(meta_graph_def_.signature_def_size() >
                            options.lazy_loading_threshold) {}

SavedModelImpl::~SavedModelImpl() = default;

std::vector<std::string> SavedModelImpl::GetFunctionNames() const {
  std::vector<std::string> result;
  for (const auto& entry : signatures_) {
    result.push_back(entry.first);
  }
  return result;
}

const tensorflow::MetaGraphDef& SavedModelImpl::GetMetaGraphDef() const {
  return meta_graph_def_;
}

std::optional<FunctionMetadata> SavedModelImpl::GetFunctionMetadata(
    absl::string_view func_name) const {
  auto iter = signatures_.find(func_name);
  if (iter == signatures_.end()) return std::nullopt;
  return FunctionMetadata(&iter->second);
}

namespace {
tensorflow::Status IsInputSpecsCorrect(
    absl::string_view name, const internal::Signature& signature,
    absl::Span<const tensorflow::Tensor> inputs) {
  TF_RET_CHECK(signature.input_specs.size() == inputs.size())
      << "signature " << name
      << " input size is wrong, expected: " << signature.input_specs.size()
      << ", actual: " << inputs.size();
  for (size_t i = 0; i < inputs.size(); ++i) {
    const auto& expected_input_spec = signature.input_specs[i];
    TF_RET_CHECK(expected_input_spec.dtype == inputs[i].dtype())
        << "signature " << name
        << " input dtype is wrong, expected: " << expected_input_spec.dtype
        << ", actual: " << inputs[i].dtype();
    TF_RET_CHECK(expected_input_spec.shape.IsCompatibleWith(inputs[i].shape()))
        << "signature " << name
        << " input shape is wrong, expected: " << expected_input_spec.shape
        << ", actual: " << inputs[i].shape();
  }
  return OkStatus();
}
}  // namespace

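// Example usage (a sketch only; the signature name and input tensors are
// illustrative):
//
//   std::vector<tensorflow::Tensor> outputs;
//   TF_RETURN_IF_ERROR(saved_model->Run(/*run_options=*/{}, "serving_default",
//                                       inputs, &outputs));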
tensorflow::Status SavedModelImpl::Run(
    const RunOptions& run_options, absl::string_view name,
    absl::Span<const tensorflow::Tensor> inputs,
    std::vector<tensorflow::Tensor>* outputs) {
  TF_RET_CHECK(outputs) << "outputs must be provided";
  outputs->clear();

  auto sig_iter = signatures_.find(name);
  TF_RET_CHECK(sig_iter != signatures_.end())
      << "failed to find signature " << name << " in the graph";
  if (run_options.validate_input_specs) {
    TF_RETURN_IF_ERROR(IsInputSpecsCorrect(name, sig_iter->second, inputs));
  }
  if (run_options.validate_input_specs_dry_run) {
    const auto status = IsInputSpecsCorrect(name, sig_iter->second, inputs);
    if (!status.ok()) {
      LOG(ERROR) << "TFRT input specs validation failed: "
                 << status.error_message();
    }
  }
  std::vector<tensorflow::Tensor> captures;
  for (const auto& capture : sig_iter->second.captures) {
    captures.push_back(capture);
  }

  const tfrt::Function* func;
  tfrt::ResourceContext* resource_context;
  if (lazy_loading_enabled_) {
    // If lazy loading is enabled, no signature is loaded into `bef_file_`, so
    // we need to find the BEF from the cache or create one.
    TF_ASSIGN_OR_RETURN(const LoadingResult& loading_result,
                        GetOrCreateLoadingResult({std::string(name)}));
    func = loading_result.bef_file->GetFunction(
        tensorflow::kImportModelDefaultGraphFuncName);
    resource_context = loading_result.resource_context.get();
  } else {
    func = bef_file_->GetFunction({name.data(), name.size()});
    resource_context = resource_context_.get();
  }
  DCHECK(func);

  return GraphExecutionRunOnFunction(options_.graph_execution_options,
                                     run_options, name, *func, inputs, captures,
                                     outputs, resource_context, runtime(),
                                     *fallback_state_, req_deadline_tracker_);
}

struct SavedModelImpl::JoinedSignature {
  // A unique name for the joined signature.
  std::string name;
  // The feed nodes for the corresponding inputs. They might not be in the
  // original order, and if more than one original input is mapped to the same
  // feed node, only one is picked here.
  tensorflow::GraphImportConfig::InputArrays input_nodes;
  // The fetch nodes for the outputs, which should be in the original order.
  std::vector<std::string> output_nodes;
  // The target nodes that should be run but not returned as outputs.
  std::vector<std::string> target_nodes;
};

tensorflow::Status SavedModelImpl::RunMultipleSignatures(
    const RunOptions& run_options, absl::Span<const std::string> names,
    absl::Span<const std::vector<tensorflow::Tensor>> multi_inputs,
    std::vector<std::vector<tensorflow::Tensor>>* multi_outputs) {
  TF_RET_CHECK(names.size() == multi_inputs.size())
      << "the sizes of names and inputs should be the same";
  TF_RET_CHECK(multi_outputs) << "outputs must be provided";
  multi_outputs->clear();

  // Due to possible overlapping of feed nodes among user-specified inputs,
  // `JoinSignatures()` will deduplicate against feed tensor names and produce
  // the desired inputs in a new order. The same dedup logic is used here to
  // generate the flattened input values in the same order.
  //
  // Note that we don't need to do any deduplicating nor reordering for the
  // fetch nodes.
  //
  // TODO(tfrt-devs): Consider refactoring JoinSignatures so that we don't have
  // the implicit requirement that the same dedup logic must be used here and in
  // JoinSignatures().
  std::vector<std::pair<std::string /*tensor_name*/, tensorflow::Tensor>>
      flat_inputs;
  std::vector<std::string> flat_output_names;
  absl::flat_hash_set<std::string> visited_feed_tensor_names;

  const auto& signature_defs = meta_graph_def_.signature_def();
  for (int i = 0; i < names.size(); ++i) {
    const auto& signature_name = names[i];
    const auto& input_tensors = multi_inputs[i];
    auto sig_iter = signature_defs.find(signature_name);

    // Early out if any signature can't be found.
    TF_RET_CHECK(sig_iter != signature_defs.end())
        << "failed to find signature in the graph";
    const auto& signature_def = sig_iter->second;

    // `signatures_` keeps the user-specified input names, which are in the
    // same order as `input_tensors`.
    const auto& signature = signatures_.at(signature_name);
    const auto& input_names = signature.input_names;
    if (run_options.validate_input_specs) {
      TF_RETURN_IF_ERROR(
          IsInputSpecsCorrect(signature_name, signature, input_tensors));
    }
    if (run_options.validate_input_specs_dry_run) {
      const auto status =
          IsInputSpecsCorrect(signature_name, signature, input_tensors);
      if (!status.ok()) {
        LOG(ERROR) << "TFRT input specs validation failed: "
                   << status.error_message();
      }
    }
    DCHECK(signature.captures.empty());

    TF_RET_CHECK(input_tensors.size() == signature_def.inputs().size())
        << "Incorrect input size for signature: " << signature_name
        << ": expected " << signature_def.inputs().size() << ", but got "
        << input_tensors.size();
    DCHECK_EQ(input_names.size(), signature_def.inputs().size());

    // Then we find out the corresponding tensor names (i.e.
    // node_name:output_idx) for the inputs using the SignatureDef proto.
    //
    // TODO(tfrt-devs): Consider including tensor names in `signatures_` as
    // well, so that only `signatures_` is used here.
    for (int j = 0; j < input_tensors.size(); ++j) {
      const auto& tensor_info = signature_def.inputs().at(input_names[j]);

      // TODO(b/184675681): Support other encoding cases.
      //
      // TODO(b/184679394): Add unit test for this check.
      TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
          << "Only dense tensor is supported, but got encoding case "
          << tensor_info.encoding_case();

      const auto& tensor_name = tensor_info.name();

      // Skip if we have visited the feed tensor. Otherwise, mark it as
      // visited and put it in `flat_inputs`. Note that the following code
      // uses the same logic as in JoinSignatures() to deduplicate inputs with
      // the feed tensor names, and generates the flat inputs in the same order.
      if (visited_feed_tensor_names.contains(tensor_name)) continue;
      visited_feed_tensor_names.insert(tensor_name);
      flat_inputs.push_back(std::make_pair(tensor_name, input_tensors[j]));
    }

    for (const auto& output_key : signature.output_names) {
      const auto& tensor_info = signature_def.outputs().at(output_key);

      VLOG(1) << "Importing Signature Output: output_key = " << output_key
              << ", tensor_info = " << tensor_info.DebugString();

      TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
          << "Only dense tensor is supported, but got encoding case "
          << tensor_info.encoding_case();

      flat_output_names.push_back(tensor_info.name());
    }
  }

  std::vector<tensorflow::Tensor> flat_outputs;

  TF_RETURN_IF_ERROR(
      graph_executor_->Run(run_options, flat_inputs, flat_output_names,
                           /*target_tensor_names=*/{}, &flat_outputs));

  // The outputs of the compiled function are in the user-specified order,
  // though they are flattened. So we just need to regroup the outputs for
  // each signature using its number of outputs.
  multi_outputs->resize(names.size());
  auto cur = flat_outputs.begin();
  for (size_t i = 0; i < names.size(); ++i) {
    const auto& signature_name = names[i];
    const size_t len = signature_defs.at(signature_name).outputs().size();
    std::move(cur, cur + len, std::back_inserter(multi_outputs->at(i)));
    cur += len;
    DCHECK_LE(std::distance(flat_outputs.begin(), cur), flat_outputs.size());
  }
  return OkStatus();
}

tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelImpl::ImportSubgraph(
    mlir::MLIRContext* context,
    const tensorflow::GraphImportConfig::InputArrays& input_nodes,
    const std::vector<std::string>& output_nodes,
    const std::vector<std::string>& target_nodes) {
  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  graph_import_config.inputs = input_nodes;
  graph_import_config.outputs = output_nodes;
  graph_import_config.control_outputs = target_nodes;

  // Optimize the graph.
  TF_ASSIGN_OR_RETURN(
      auto optimization_result,
      graph_executor_->graph_execution_state().CreateOptimizedGraph(
          graph_import_config));

  // Convert the optimized graph to an MLIR module.
  return tensorflow::ConvertGraphToMlir(
      *optimization_result.graph, /*debug_info=*/{},
      optimization_result.graph->flib_def(), graph_import_config, context);
}

tensorflow::Status SavedModelImpl::RunByTensorNames(
    const RunOptions& run_options,
    absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
    absl::Span<const std::string> output_tensor_names,
    absl::Span<const std::string> target_node_names,
    std::vector<tensorflow::Tensor>* outputs) {
  // TODO(b/192498110): Validate input type.

  return graph_executor_->Run(run_options, inputs, output_tensor_names,
                              target_node_names, outputs);
}

namespace {

using JoinedSignature = SavedModelImpl::JoinedSignature;

// Returns a joined signature for the signatures in `names`. For inputs, as
// their corresponding nodes may overlap, we deduplicate them by node, so the
// order of inputs in the joined signature may differ from the original order.
// For outputs, overlapping is fine, so we simply flatten them in the original
// order.
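// For example, joining signatures "a" and "b" that both feed the tensor
// "x:0" produces a joined signature named "a+b" with a single input entry for
// "x:0", while the outputs of "a" and "b" are all kept in the original order.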
StatusOr<JoinedSignature> JoinSignatures(
    absl::Span<const std::string> names, const SignatureMap& signature_map,
    const tensorflow::protobuf::Map<std::string, tensorflow::SignatureDef>&
        signature_def_map) {
  // Join all the names, all the inputs, and all the outputs.
  JoinedSignature joined_signature;
  joined_signature.name = absl::StrJoin(names, kSignatureJoiningDelimiter);

  // Keep the feed tensor names visited.
  absl::flat_hash_set<std::string> visited_feed_tensor_names;

  for (const auto& name : names) {
    const auto& signature_def = signature_def_map.at(name);

    // For inputs, we deduplicate possible overlapping feed nodes and create the
    // new input array.
    for (const auto& iter : signature_def.inputs()) {
      const auto& tensor_info = iter.second;

      // Skip if this feed node is already visited.
      if (visited_feed_tensor_names.contains(tensor_info.name())) continue;

      // Otherwise, we parse its tensor info and collect it for later
      // compilation.
      visited_feed_tensor_names.insert(tensor_info.name());

      // TODO(b/184675681): Support other encoding cases.
      //
      // TODO(b/184679394): Add unit test for this check.
      TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
          << "Only dense tensor is supported, but got encoding case "
          << tensor_info.encoding_case();

      VLOG(1) << "Importing Signature Input: input_key = " << iter.first
              << ", tensor_info = " << tensor_info.DebugString();

      tensorflow::ArrayInfo array_info;
      array_info.imported_dtype = tensor_info.dtype();

      if (tensor_info.has_tensor_shape()) {
        array_info.shape = tensor_info.tensor_shape();
      } else {
        // If there is no tensor shape in the tensor info, conservatively set
        // unknown_rank to true.
        array_info.shape.set_unknown_rank(true);
      }

      joined_signature.input_nodes.insert(
          std::pair<std::string, tensorflow::ArrayInfo>(tensor_info.name(),
                                                        std::move(array_info)));
    }

    // For outputs, we simply flatten them in the original order, as it is fine
    // to have duplicated fetch nodes.
    const internal::Signature& signature = signature_map.at(name);
    for (const auto& output_key : signature.output_names) {
      const auto& tensor_info = signature_def.outputs().at(output_key);

      VLOG(1) << "Importing Signature Output: output_key = " << output_key
              << ", tensor_info = " << tensor_info.DebugString();

      TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
          << "Only dense tensor is supported, but got encoding case "
          << tensor_info.encoding_case();

      joined_signature.output_nodes.push_back(tensor_info.name());
    }
  }

  return joined_signature;
}

}  // namespace

// TODO(b/216379787): Reuse `GraphExecutor::LoadClientGraph()`.
StatusOr<std::reference_wrapper<const SavedModelImpl::LoadingResult>>
SavedModelImpl::LoadJoinedSignature(const JoinedSignature& joined_signature) {
  // Step 1: Import the combined subgraph from proto to an MLIR module.
  mlir::MLIRContext context;
  ASSIGN_OR_RETURN_IN_IMPORT(
      auto module, ImportSubgraph(&context, joined_signature.input_nodes,
                                  joined_signature.output_nodes,
                                  joined_signature.target_nodes));

  // Step 2: Compile the MLIR module from TF dialect to TFRT dialect (in BEF).
  auto loading_result = std::make_unique<LoadingResult>();
  loading_result->name = joined_signature.name;
  loading_result->resource_context = CreateResourceContext(
      runtime(), tpu_model_resource_.get(),
      options_.graph_execution_options.compile_options.tpu_target);

  RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef(
      options_.graph_execution_options.compile_options, module.get(),
      &loading_result->bef));

  // Step 3: Initialize runtime states using special BEF functions.
  ASSIGN_OR_RETURN_IN_INIT(
      loading_result->bef_file,
      tfrt::CreateBefFileFromBefBuffer(
          *options_.graph_execution_options.runtime, loading_result->bef));
  RETURN_IF_ERROR_IN_INIT(RunInitializers(
      /*initializers_and_signatures=*/{},
      options_.graph_execution_options.model_metadata,
      loading_result->bef_file.get(), *options_.graph_execution_options.runtime,
      loading_result->resource_context.get(), *fallback_state_));

  // Store loading_result in cache.
  const auto* loading_result_ptr = loading_result.get();
  loading_result_cache_[joined_signature.name] = std::move(loading_result);
  return {*loading_result_ptr};
}

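// Returns the cached LoadingResult for `names` if one exists; otherwise joins
// the signatures, compiles them via LoadJoinedSignature(), and caches the
// result under the joined name. Lookup and insertion are serialized by
// `loading_result_cache_mu_`.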
StatusOr<std::reference_wrapper<const SavedModelImpl::LoadingResult>>
SavedModelImpl::GetOrCreateLoadingResult(absl::Span<const std::string> names) {
  const auto joined_name = absl::StrJoin(names, kSignatureJoiningDelimiter);
  tensorflow::mutex_lock l(loading_result_cache_mu_);
  const auto iter = loading_result_cache_.find(joined_name);
  if (iter != loading_result_cache_.end()) return {*iter->second};

  TF_ASSIGN_OR_RETURN(
      const auto joined_signature,
      JoinSignatures(names, signatures_, meta_graph_def_.signature_def()));

  return LoadJoinedSignature(joined_signature);
}

}  // namespace tfrt_stub
}  // namespace tensorflow