/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/saved_model/saved_model.h"

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
#include "tensorflow/core/tfrt/mla/mla_utils.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/saved_model/saved_model_import_input.h"
#include "tensorflow/core/tfrt/tpu/tpu_resources.h"  // NOLINT(unused-includes): For tfrt::tpu::TpuModelResource
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
#include "tfrt/metrics/common_metrics.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_stub {
namespace {

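// Delimiters used to compose joined names. For example, the signature-joining
// delimiter below turns the (illustrative) signature names {"sig_a", "sig_b"}
// into the lazy-loading cache key "sig_a+sig_b" in GetOrCreateLoadingResult().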
constexpr absl::string_view kSignatureJoiningDelimiter = "+";
constexpr absl::string_view kTensorNameJoiningDelimiter = "-";
constexpr absl::string_view kArgumentTypeJoiningDelimiter = "^";

using SignatureMap = absl::flat_hash_map<std::string, internal::Signature>;
using ::tensorflow::SessionMetadata;
using ::tensorflow::StatusOr;

struct InitializersAndSignatures {
  llvm::SmallVector<std::string, 4> initializers;
  SignatureMap signature_map;
};

auto* saved_model_read_meta_graph_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/read_meta_graph_time",
        "Record the time of reading meta_graph from disk.", "model_name");

auto* saved_model_functionalization_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/functionalization_time",
        "Record the functionalization time for the savedmodel.", "model_name");

auto* saved_model_grappler_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/grappler_time",
        "Record the grappler time for the savedmodel.", "model_name");

auto* saved_model_mla_check_time_milli_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/mla_check_time",
        "Record the MLA check time for the savedmodel.", "model_name");

auto* saved_model_import_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/import_time",
        "Record the MLIR import time for the savedmodel.", "model_name");

auto* saved_model_compile_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/compile_time",
        "Record the compilation time for the savedmodel.", "model_name");

auto* saved_model_init_time_seconds =
    tensorflow::monitoring::Gauge<int64_t, 1>::New(
        "/tensorflow/tfrt/saved_model/init_time",
        "Record the initialization time for the savedmodel.", "model_name");

tensorflow::Tensor CreateScalarStringTensor(absl::string_view str) {
  return tensorflow::Tensor(tensorflow::tstring(str));
}

// Create the tensor for the bound input, which can be a variable or an asset.
//
// TODO(chky): For V2 models, the bound input can also be a resource.
StatusOr<tensorflow::Tensor> CreateTensorFromBoundInput(
    mlir::Operation* bound_input, absl::string_view saved_model_dir,
    absl::flat_hash_map<std::string, tensorflow::Tensor>* variables) {
  // Assets are files in the saved model directory. We pass their filenames to
  // functions so that they can be used.
  if (auto asset = llvm::dyn_cast<mlir::tf_saved_model::AssetOp>(bound_input)) {
    // The filename in the asset is a relative path. So we prefix it with the
    // directory path.
    return CreateScalarStringTensor(
        tensorflow::io::JoinPath(saved_model_dir, asset.filename().str()));
  }

  return tensorflow::errors::Internal(
      "Failed to create captured tensors: unknown bound input type.");
}

StatusOr<SignatureMap> GetFunctionSignaturesFromTFSavedModelMLIR(
    absl::string_view saved_model_dir, mlir::ModuleOp module) {
  absl::flat_hash_map<std::string, tensorflow::Tensor> variables;
  SignatureMap signatures;

  tensorflow::StatusGroup status_group;
  TF_RETURN_IF_ERROR(tensorflow::MapFunctionSignaturesFromTFSavedModelMLIR(
      module, [&status_group, &variables, &signatures, saved_model_dir](
                  const tensorflow::TFRTSavedModelSignatureInfo& sig_info) {
        auto& signature = signatures[std::string(sig_info.func_name)];

        auto copy = [](llvm::ArrayRef<llvm::StringRef> src,
                       std::vector<std::string>* dst) {
          transform(src, std::back_inserter(*dst),
                    [](llvm::StringRef x) { return x.str(); });
        };
        copy(sig_info.input_names, &signature.input_names);
        copy(sig_info.output_names, &signature.output_names);
        copy(sig_info.input_devices, &signature.input_devices);

        DCHECK(signature.input_specs.empty());
        signature.input_specs.reserve(sig_info.input_specs.size());
        for (auto& spec : sig_info.input_specs) {
          signature.input_specs.push_back(TensorSpec(spec.first, spec.second));
        }

        DCHECK(signature.output_specs.empty());
        signature.output_specs.reserve(sig_info.output_specs.size());
        for (auto& spec : sig_info.output_specs) {
          signature.output_specs.push_back(
              TensorSpec(spec.first, spec.second));
        }

        for (auto* bound_input : sig_info.bound_inputs) {
          auto capture = CreateTensorFromBoundInput(
              bound_input, saved_model_dir, &variables);
          if (!capture.ok()) {
            status_group.Update(capture.status());
            // Insert a placeholder tensor on error so that the capture
            // indices stay aligned.
            signature.captures.push_back(tensorflow::Tensor());
          } else {
            signature.captures.push_back(*std::move(capture));
          }
        }
      }));

  if (!status_group.ok()) return status_group.as_concatenated_status();

  return signatures;
}

tensorflow::Status RunInitializers(
    const InitializersAndSignatures& initializers_and_signatures,
    const SessionMetadata& model_metadata, tfrt::BEFFile* bef_file,
    const Runtime& runtime, tfrt::ResourceContext* resource_context,
    const FallbackState& fallback_state) {
  auto* host = runtime.core_runtime()->GetHostContext();
  TF_ASSIGN_OR_RETURN(auto request_info,
                      SetUpRequestContext(/*run_options=*/{}, model_metadata,
                                          host, runtime.work_queue(),
                                          resource_context, fallback_state));

  tfrt::ExecutionContext exec_ctx(request_info->tfrt_request_context);

  // Run "_tfrt_fallback_init" first to initialize fallback-specific states. It
  // is a special function created by the compiler, which calls a sequence of
  // tfrt_fallback_async.createop to create all fallback ops used in this BEF.
  TF_RETURN_IF_ERROR(
      RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_fallback_init"));

  for (const auto& init : initializers_and_signatures.initializers) {
    // TODO(b/184771263): Consider using `GraphExecutionRunOnFunction()`
    // instead.

    auto* func = bef_file->GetFunction(init);
    assert(func);

    const auto& signature = initializers_and_signatures.signature_map.at(init);

    auto ready_chain = tfrt::GetReadyChain();

    // The actual arguments are the concatenation of the side-effect chain and
    // the assets.
    llvm::SmallVector<tfrt::AsyncValue*, 1> arguments;
    auto cleanup = tensorflow::gtl::MakeCleanup([&]() {
      for (auto* argument : arguments) argument->DropRef();
    });

    arguments.push_back(ready_chain.release());

    for (const auto& capture : signature.captures) {
      arguments.push_back(
          tfrt::MakeAvailableAsyncValueRef<FallbackTensor>(capture).release());
    }

    assert(arguments.size() == func->argument_types().size());

    llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 1> results;
    results.resize(func->result_types().size());
    assert(results.size() == 1);

    func->Execute(exec_ctx, arguments, results);

    // Wait for the function execution to finish, as well as the side-effects.
    host->Await(results);

    if (auto* error = results[0]->GetErrorIfPresent()) {
      return CreateTfErrorStatus(*error);
    }
  }

  // After we have initialized all the resources in the original graph, we can
  // run the "_tfrt_resource_init" function to set these resources in runtime
  // states, so that they can later be retrieved efficiently without locking.
  TF_RETURN_IF_ERROR(
      RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_resource_init"));

  return OkStatus();
}

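// Returns the names of the SignatureDefs that can be imported: only signatures
// whose inputs and outputs are all dense tensors (i.e. TensorInfo with the
// kName encoding) and contain no ref-typed tensors are kept.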
std::vector<std::string> FindNamesForValidSignatures(
    const tensorflow::MetaGraphDef& meta_graph_def) {
  std::vector<std::string> valid_signature_names;

  auto is_dense_tensor_info = [](const auto& named_tensor_info) {
    return !named_tensor_info.second.name().empty();
  };

  auto is_ref_type_tensor_info = [](const auto& named_tensor_info) {
    return tensorflow::IsRefType(named_tensor_info.second.dtype());
  };

  for (const auto& iter : meta_graph_def.signature_def()) {
    const auto& sig_key = iter.first;
    const auto& signature = iter.second;
    if (!std::all_of(signature.inputs().begin(), signature.inputs().end(),
                     is_dense_tensor_info) ||
        !std::all_of(signature.outputs().begin(), signature.outputs().end(),
                     is_dense_tensor_info)) {
      LOG(WARNING) << "Unsupported signature with non-dense tensors as "
                      "input/output. Name: "
                   << sig_key << "; Signature: " << signature.DebugString();
      continue;
    }
    if (std::any_of(signature.inputs().begin(), signature.inputs().end(),
                    is_ref_type_tensor_info) ||
        std::any_of(signature.outputs().begin(), signature.outputs().end(),
                    is_ref_type_tensor_info)) {
      LOG(WARNING) << "Unsupported signature with ref type tensors as "
                      "input/output. Name: "
                   << sig_key << "; Signature: " << signature.DebugString();
      continue;
    }
    valid_signature_names.push_back(sig_key);
  }
  return valid_signature_names;
}

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
    mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def,
    const FallbackState& fallback_state, std::string saved_model_dir,
    bool import_user_signatures, bool run_placer_grappler_on_functions,
    bool enable_tfrt_gpu) {
  std::vector<std::string> signature_names;
  if (import_user_signatures) {
    signature_names = FindNamesForValidSignatures(meta_graph_def);
    if (signature_names.empty())
      LOG(WARNING) << "No valid signature found for model: " << saved_model_dir;
  }

  // TfrtSavedModelMLIRImportInput implements the graph processing logic (e.g.
  // Placer and Grappler) used in DirectSession, which applies graph
  // transformations on each subgraph (i.e. signature). It reuses the
  // DirectSession code path to avoid problems caused by behavior differences
  // between code paths, and it is injected into the MLIR importer so that the
  // importer imports the transformed graph instead of the original graph.
  TF_ASSIGN_OR_RETURN(auto import_input,
                      TfrtSavedModelMLIRImportInput::Create(
                          fallback_state, &meta_graph_def, /*debug_info=*/{},
                          run_placer_grappler_on_functions, enable_tfrt_gpu));

  TF_ASSIGN_OR_RETURN(
      auto module,
      tensorflow::ConvertSavedModelV1ToMlirLite(
          import_input,
          /*exported_names=*/absl::MakeSpan(signature_names), context));

  LOG(INFO) << "TFRT ImportSavedModel: Functionalization took "
            << absl::ToInt64Milliseconds(
                   import_input.GetFunctionalizationDuration())
            << " ms.";
  LOG(INFO) << "TFRT ImportSavedModel: Grappler took "
            << absl::ToInt64Milliseconds(import_input.GetGrapplerDuration())
            << " ms.";

  saved_model_functionalization_time_seconds->GetCell(saved_model_dir)
      ->Set(absl::ToInt64Seconds(import_input.GetFunctionalizationDuration()));

  saved_model_grappler_time_seconds->GetCell(saved_model_dir)
      ->Set(absl::ToInt64Seconds(import_input.GetGrapplerDuration()));

  return module;
}

StatusOr<InitializersAndSignatures> GetInitializersAndSignatures(
    mlir::ModuleOp module, absl::string_view saved_model_dir) {
  InitializersAndSignatures result;
  TF_ASSIGN_OR_RETURN(
      result.signature_map,
      GetFunctionSignaturesFromTFSavedModelMLIR(saved_model_dir, module));
  for (auto session_initializer_name :
       mlir::tf_saved_model::GetSessionInitializerExportedName(module)) {
    result.initializers.push_back(session_initializer_name.str());
  }
  return result;
}

tensorflow::Status InitSavedModel(
    const InitializersAndSignatures& initializers_and_signatures,
    tfrt::BEFFile* bef_file, const SavedModel::Options& options,
    tfrt::ResourceContext* resource_context,
    const FallbackState& fallback_state) {
  TF_RETURN_IF_ERROR(
      RunInitializers(initializers_and_signatures,
                      options.graph_execution_options.model_metadata, bef_file,
                      *options.graph_execution_options.runtime,
                      resource_context, fallback_state));

  return OkStatus();
}

}  // namespace

SavedModel::~SavedModel() {}

tfrt::HostContext* SavedModel::GetHostContext() const {
  return runtime_->core_runtime()->GetHostContext();
}

namespace {

// Gets the signatures from `signature_defs` and inserts them into
// `signatures`.
void GetSignaturesFromSignatureDef(
    SignatureMap& signatures,
    const google::protobuf::Map<std::string, tensorflow::SignatureDef>& signature_defs,
    const SavedModel::Options& options) {
  for (const auto& p : signature_defs) {
    const std::string& signature_name = p.first;
    const tensorflow::SignatureDef& signature_def = p.second;
    DCHECK(signatures.find(signature_name) == signatures.end());
    auto& signature = signatures[signature_name];

    signature.input_names.reserve(signature_def.inputs().size());
    signature.input_specs.reserve(signature_def.inputs().size());
    for (const auto& p : signature_def.inputs()) {
      const std::string& input_tensor_name = p.first;
      const tensorflow::TensorInfo& tensor_info = p.second;
      signature.input_names.push_back(input_tensor_name);
      signature.input_specs.push_back(
          TensorSpec(tensor_info.dtype(), tensor_info.tensor_shape()));
    }

    signature.input_devices = std::vector<std::string>(
        signature_def.inputs().size(),
        options.graph_execution_options.compile_options.default_device);

    signature.output_names.reserve(signature_def.outputs().size());
    signature.output_specs.reserve(signature_def.outputs().size());
    for (const auto& p : signature_def.outputs()) {
      const std::string& output_tensor_name = p.first;
      const tensorflow::TensorInfo& tensor_info = p.second;
      signature.output_names.push_back(output_tensor_name);
      signature.output_specs.push_back(
          TensorSpec(tensor_info.dtype(), tensor_info.tensor_shape()));
    }
  }
}

void UpdateCompileOptions(SavedModel::Options& options) {
  // Disable DecomposeResourceOpsPass for now, as DecomposeResourceGather does
  // not work well with GPU (b/232819415).
  if (options.graph_execution_options.enable_tfrt_gpu) {
    options.graph_execution_options.compile_options.decompose_resource_ops =
        false;
  }
}

StatusOr<tensorflow::MetaGraphDef> ReadSavedModel(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags) {
  LOG(INFO) << "TFRT reading v1 savedmodel: " << saved_model_dir;
  const auto read_start_time = absl::Now();

  tensorflow::MetaGraphDef meta_graph_def;
  TF_RETURN_IF_ERROR(tensorflow::ReadMetaGraphDefFromSavedModel(
      std::string(saved_model_dir), tags, &meta_graph_def));

  const auto read_meta_graph_duration = absl::Now() - read_start_time;
  saved_model_read_meta_graph_time_seconds
      ->GetCell(std::string(saved_model_dir))
      ->Set(absl::ToInt64Seconds(read_meta_graph_duration));
  LOG(INFO) << "TFRT finished reading meta graph. Took "
            << absl::ToInt64Milliseconds(read_meta_graph_duration) << " ms.";
  return std::move(meta_graph_def);
}

}  // namespace

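// Example usage (a sketch only; `CreateRuntime()` is a hypothetical helper
// that returns a configured `tensorflow::tfrt_stub::Runtime`, and `Options` is
// assumed to be constructible from that runtime as declared in saved_model.h):
//
//   auto runtime = CreateRuntime();
//   SavedModel::Options options(runtime.get());
//   tensorflow::Status status;
//   auto saved_model = SavedModelImpl::LoadSavedModel(
//       std::move(options), "/path/to/saved_model", /*tags=*/{"serve"},
//       &status);
//   if (!status.ok()) { /* handle the error */ }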
std::unique_ptr<SavedModel> SavedModelImpl::LoadSavedModel(
    Options options, absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags, tensorflow::Status* status) {
  std::string saved_model_dir_str = "unused";
  if (options.maybe_load_from_mla) {
    const auto mla_check_start_time = absl::Now();
    const bool is_mla = IsMlarchive(saved_model_dir);
    const auto mla_check_duration = absl::Now() - mla_check_start_time;
    saved_model_mla_check_time_milli_seconds
        ->GetCell(std::string(saved_model_dir))
        ->Set(absl::ToInt64Milliseconds(mla_check_duration));
    LOG(INFO) << "TFRT finished checking MLA. Took "
              << absl::ToInt64Milliseconds(mla_check_duration) << " ms.";
    if (is_mla) {
      LOG(INFO) << "TFRT got an MLArchive dir: " << saved_model_dir
                << ". Continuing to find the actual saved_model_dir in it.";
      const auto statusor_saved_model_dir =
          GetSavedModelDirFromMlaDir(saved_model_dir);
      if (!statusor_saved_model_dir.ok()) {
        *status = statusor_saved_model_dir.status();
        return nullptr;
      }
      saved_model_dir_str = *statusor_saved_model_dir;
      saved_model_dir = saved_model_dir_str;
      LOG(INFO) << "TFRT found from MLArchive a saved model: "
                << saved_model_dir;
    }  // Not an MLA; `saved_model_dir` is ready to use.
  }

  auto meta_graph_def = ReadSavedModel(saved_model_dir, tags);
  if (!meta_graph_def.ok()) {
    *status = meta_graph_def.status();
    return nullptr;
  }

  LOG(INFO) << "TFRT loading v1 savedmodel: " << saved_model_dir;
  tfrt::metrics::AddTFRTVersionMetric();

  UpdateTpuTargetByBridgeCompatibility(options.graph_execution_options,
                                       meta_graph_def->graph_def());
  UpdateCompileOptions(options);

  auto saved_model =
      [&]() -> tensorflow::StatusOr<std::unique_ptr<SavedModel>> {
    mlir::MLIRContext context;

    const bool lazy_loading_enabled =
        meta_graph_def->signature_def_size() > options.lazy_loading_threshold;

    // Step 1: Import saved model from a proto to an MLIR module.
    const auto import_start_time = absl::Now();
    auto session_options =
        CreateDefaultSessionOptions(options.graph_execution_options);
    // Set optimize_for_static_graph to true since we won't extend the graph
    // later. If optimize_for_static_graph is set to false, FallbackState will
    // keep an extra unused copy of the graph, which unnecessarily consumes
    // memory.
    session_options.config.mutable_experimental()
        ->set_optimize_for_static_graph(true);

    // Create the fallback_state using the original function def library
    // without applying Placer or Grappler. This is OK for now because it is
    // only used for captured functions in certain tf.data ops.
    const auto& fdef_lib = meta_graph_def->graph_def().library();
    ASSIGN_OR_RETURN_IN_IMPORT(
        auto fallback_state, FallbackState::Create(session_options, fdef_lib));
    ASSIGN_OR_RETURN_IN_IMPORT(
        auto mlir_module,
        ImportSavedModel(
            &context, *meta_graph_def, *fallback_state,
            std::string(saved_model_dir),
            /*import_user_signatures=*/!lazy_loading_enabled,
            options.graph_execution_options.run_placer_grappler_on_functions,
            options.graph_execution_options.enable_tfrt_gpu));

    const auto import_duration = absl::Now() - import_start_time;
    saved_model_import_time_seconds->GetCell(std::string(saved_model_dir))
        ->Set(absl::ToInt64Seconds(import_duration));
    LOG(INFO) << "TFRT finished importing savedmodel. Took "
              << absl::ToInt64Milliseconds(import_duration) << " ms.";

    // Step 2: Compile the MLIR module from TF dialect to TFRT dialect (in
    // BEF).
    const auto compile_start_time = absl::Now();
    ASSIGN_OR_RETURN_IN_COMPILE(
        auto initializers_and_signatures,
        GetInitializersAndSignatures(mlir_module.get(), saved_model_dir));
    // If lazy loading is enabled, the user signatures are not exported via the
    // MLIR module, so we need to get them from the proto.
    // TODO(b/187228559): Unify the code paths for populating the signature
    // map.
    if (lazy_loading_enabled) {
      GetSignaturesFromSignatureDef(initializers_and_signatures.signature_map,
                                    meta_graph_def->signature_def(), options);
    }
    tfrt::BefBuffer bef;
    RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef(
        options.graph_execution_options.compile_options, mlir_module.get(),
        &bef));

    const auto compile_duration = absl::Now() - compile_start_time;
    saved_model_compile_time_seconds->GetCell(std::string(saved_model_dir))
        ->Set(absl::ToInt64Seconds(compile_duration));
    LOG(INFO) << "TFRT finished compiling savedmodel. Took "
              << absl::ToInt64Milliseconds(compile_duration) << " ms.";

    // Step 3: Initialize runtime states using special BEF functions.
    const auto init_start_time = absl::Now();
    ASSIGN_OR_RETURN_IN_INIT(
        auto bef_file, tfrt::CreateBefFileFromBefBuffer(
                           *options.graph_execution_options.runtime, bef));

    auto tpu_model_resource = std::make_unique<tfrt::tpu::TpuModelResource>();
    auto resource_context = CreateResourceContext(
        *options.graph_execution_options.runtime, tpu_model_resource.get(),
        options.graph_execution_options.compile_options.tpu_target);
    RETURN_IF_ERROR_IN_INIT(
        InitSavedModel(initializers_and_signatures, bef_file.get(), options,
                       resource_context.get(), *fallback_state));

    const auto init_duration = absl::Now() - init_start_time;
    saved_model_init_time_seconds->GetCell(std::string(saved_model_dir))
        ->Set(absl::ToInt64Seconds(init_duration));
    LOG(INFO) << "TFRT finished initializing savedmodel. Took "
              << absl::ToInt64Milliseconds(init_duration) << " ms.";

    ASSIGN_OR_RETURN_WITH_STAGE_INFO(
        "graph_executor creation", auto graph_executor,
        GraphExecutor::Create(options.graph_execution_options, *fallback_state,
                              tpu_model_resource.get(),
                              std::move(*meta_graph_def->mutable_graph_def())));

    // Finally, create the saved model.
    return {std::make_unique<SavedModelImpl>(
        std::move(options), *std::move(meta_graph_def), std::move(bef),
        std::move(bef_file),
        std::move(initializers_and_signatures.signature_map),
        std::move(fallback_state), std::move(tpu_model_resource),
        std::move(resource_context), std::move(graph_executor))};
  }();

  if (!saved_model.ok()) {
    *status = saved_model.status();
    return nullptr;
  }
  *status = OkStatus();
  return *std::move(saved_model);
}

SavedModelImpl::SavedModelImpl(
    Options options, tensorflow::MetaGraphDef meta_graph_def,
    tfrt::BefBuffer bef, tfrt::RCReference<tfrt::BEFFile> bef_file,
    SignatureMap signatures, std::unique_ptr<FallbackState> fallback_state,
    std::unique_ptr<tfrt::tpu::TpuModelResource> tpu_model_resource,
    std::unique_ptr<tfrt::ResourceContext> resource_context,
    std::unique_ptr<GraphExecutor> graph_executor)
    : SavedModel(options.graph_execution_options.runtime),
      options_(std::move(options)),
      meta_graph_def_(std::move(meta_graph_def)),
      bef_(std::move(bef)),
      bef_file_(std::move(bef_file)),
      req_deadline_tracker_(
          options.graph_execution_options.runtime->core_runtime()
              ->GetHostContext()),
      signatures_(std::move(signatures)),
      fallback_state_(std::move(fallback_state)),
      tpu_model_resource_(std::move(tpu_model_resource)),
      resource_context_(std::move(resource_context)),
      graph_executor_(std::move(graph_executor)),
      lazy_loading_enabled_(meta_graph_def_.signature_def_size() >
                            options.lazy_loading_threshold) {}

SavedModelImpl::~SavedModelImpl() = default;

std::vector<std::string> SavedModelImpl::GetFunctionNames() const {
  std::vector<std::string> result;
  for (const auto& entry : signatures_) {
    result.push_back(entry.first);
  }
  return result;
}

const tensorflow::MetaGraphDef& SavedModelImpl::GetMetaGraphDef() const {
  return meta_graph_def_;
}

std::optional<FunctionMetadata> SavedModelImpl::GetFunctionMetadata(
    absl::string_view func_name) const {
  auto iter = signatures_.find(func_name);
  if (iter == signatures_.end()) return std::nullopt;
  return FunctionMetadata(&iter->second);
}

namespace {
tensorflow::Status IsInputSpecsCorrect(
    absl::string_view name, const internal::Signature& signature,
    absl::Span<const tensorflow::Tensor> inputs) {
  TF_RET_CHECK(signature.input_specs.size() == inputs.size())
      << "signature " << name
      << " input size is wrong, expected: " << signature.input_specs.size()
      << ", actual: " << inputs.size();
  for (size_t i = 0; i < inputs.size(); ++i) {
    const auto& expected_input_spec = signature.input_specs[i];
    TF_RET_CHECK(expected_input_spec.dtype == inputs[i].dtype())
        << "signature " << name
        << " input dtype is wrong, expected: " << expected_input_spec.dtype
        << ", actual: " << inputs[i].dtype();
    TF_RET_CHECK(expected_input_spec.shape.IsCompatibleWith(inputs[i].shape()))
        << "signature " << name
        << " input shape is wrong, expected: " << expected_input_spec.shape
        << ", actual: " << inputs[i].shape();
  }
  return OkStatus();
}
}  // namespace

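// Example (a sketch): running a signature named "serving_default" with one
// input tensor, where `saved_model` was returned by LoadSavedModel() and the
// signature name is illustrative:
//
//   std::vector<tensorflow::Tensor> outputs;
//   TF_RETURN_IF_ERROR(saved_model->Run(/*run_options=*/{}, "serving_default",
//                                       {input_tensor}, &outputs));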
tensorflow::Status SavedModelImpl::Run(
    const RunOptions& run_options, absl::string_view name,
    absl::Span<const tensorflow::Tensor> inputs,
    std::vector<tensorflow::Tensor>* outputs) {
  TF_RET_CHECK(outputs) << "outputs must be provided";
  outputs->clear();

  auto sig_iter = signatures_.find(name);
  TF_RET_CHECK(sig_iter != signatures_.end())
      << "failed to find signature " << name << " in the graph";
  if (run_options.validate_input_specs) {
    TF_RETURN_IF_ERROR(IsInputSpecsCorrect(name, sig_iter->second, inputs));
  }
  if (run_options.validate_input_specs_dry_run) {
    const auto status = IsInputSpecsCorrect(name, sig_iter->second, inputs);
    if (!status.ok()) {
      LOG(ERROR) << "TFRT input specs validation failed: "
                 << status.error_message();
    }
  }
  std::vector<tensorflow::Tensor> captures;
  for (const auto& capture : sig_iter->second.captures) {
    captures.push_back(capture);
  }

  const tfrt::Function* func;
  tfrt::ResourceContext* resource_context;
  if (lazy_loading_enabled_) {
    // If lazy loading is enabled, no signature is loaded into `bef_file_`, so
    // we need to find the BEF from the cache or create one.
    TF_ASSIGN_OR_RETURN(const LoadingResult& loading_result,
                        GetOrCreateLoadingResult({std::string(name)}));
    func = loading_result.bef_file->GetFunction(
        tensorflow::kImportModelDefaultGraphFuncName);
    resource_context = loading_result.resource_context.get();
  } else {
    func = bef_file_->GetFunction({name.data(), name.size()});
    resource_context = resource_context_.get();
  }
  DCHECK(func);

  return GraphExecutionRunOnFunction(options_.graph_execution_options,
                                     run_options, name, *func, inputs, captures,
                                     outputs, resource_context, runtime(),
                                     *fallback_state_, req_deadline_tracker_);
}

struct SavedModelImpl::JoinedSignature {
  // A unique name for the joined signature.
  std::string name;
  // The feed nodes for the corresponding inputs. They might not be in the
  // original order, and if more than one original input maps to the same feed
  // node, only one is picked here.
  tensorflow::GraphImportConfig::InputArrays input_nodes;
  // The fetch nodes for the outputs, which should be in the original order.
  std::vector<std::string> output_nodes;
  // The target nodes that should be run but not returned as outputs.
  std::vector<std::string> target_nodes;
};

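// Example (a sketch): running two signatures in one batch, where "sig_a" and
// "sig_b" are illustrative names of signatures exported by the model:
//
//   std::vector<std::vector<tensorflow::Tensor>> multi_outputs;
//   TF_RETURN_IF_ERROR(saved_model->RunMultipleSignatures(
//       /*run_options=*/{}, {"sig_a", "sig_b"}, {{input_a}, {input_b}},
//       &multi_outputs));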
tensorflow::Status SavedModelImpl::RunMultipleSignatures(
    const RunOptions& run_options, absl::Span<const std::string> names,
    absl::Span<const std::vector<tensorflow::Tensor>> multi_inputs,
    std::vector<std::vector<tensorflow::Tensor>>* multi_outputs) {
  TF_RET_CHECK(names.size() == multi_inputs.size())
      << "the sizes of names and inputs should be the same";
  TF_RET_CHECK(multi_outputs) << "outputs must be provided";
  multi_outputs->clear();

  // Due to possible overlapping of feed nodes among user-specified inputs,
  // `JoinSignatures()` deduplicates them by feed tensor name and produces the
  // desired inputs in a new order. The same dedup logic is used here to
  // generate the flattened input values in the same order.
  //
  // Note that we don't need any deduplication or reordering for the fetch
  // nodes.
  //
  // TODO(tfrt-devs): Consider refactoring JoinSignatures so that we don't
  // have the implicit requirement that the same dedup logic must be used here
  // and in JoinSignatures().
  std::vector<std::pair<std::string /*tensor_name*/, tensorflow::Tensor>>
      flat_inputs;
  std::vector<std::string> flat_output_names;
  absl::flat_hash_set<std::string> visited_feed_tensor_names;

  const auto& signature_defs = meta_graph_def_.signature_def();
  for (int i = 0; i < names.size(); ++i) {
    const auto& signature_name = names[i];
    const auto& input_tensors = multi_inputs[i];
    auto sig_iter = signature_defs.find(signature_name);

    // Early out if any signature can't be found.
    TF_RET_CHECK(sig_iter != signature_defs.end())
        << "failed to find signature in the graph";
    const auto& signature_def = sig_iter->second;

    // `signatures_` keeps the user-specified input names that are in the same
    // order as `input_tensors`.
    const auto& signature = signatures_.at(signature_name);
    const auto& input_names = signature.input_names;
    if (run_options.validate_input_specs) {
      TF_RETURN_IF_ERROR(
          IsInputSpecsCorrect(signature_name, signature, input_tensors));
    }
    if (run_options.validate_input_specs_dry_run) {
      const auto status =
          IsInputSpecsCorrect(signature_name, signature, input_tensors);
      if (!status.ok()) {
        LOG(ERROR) << "TFRT input specs validation failed: "
                   << status.error_message();
      }
    }
    DCHECK(signature.captures.empty());

    TF_RET_CHECK(input_tensors.size() == signature_def.inputs().size())
        << "Incorrect input size for signature: " << signature_name
        << ": expected " << signature_def.inputs().size() << ", but got "
        << input_tensors.size();
    DCHECK_EQ(input_names.size(), signature_def.inputs().size());

    // Then we find the corresponding tensor names (i.e. node_name:output_idx)
    // for the inputs using the SignatureDef proto.
    //
    // TODO(tfrt-devs): Consider including tensor names in `signatures_` as
    // well, so that only `signatures_` is used here.
    for (int j = 0; j < input_tensors.size(); ++j) {
      const auto& tensor_info = signature_def.inputs().at(input_names[j]);

      // TODO(b/184675681): Support other encoding cases.
      //
      // TODO(b/184679394): Add unit test for this check.
      TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
          << "Only dense tensor is supported, but got encoding case "
          << tensor_info.encoding_case();

      const auto& tensor_name = tensor_info.name();

      // Skip if we have already visited this feed tensor. Otherwise, mark it
      // as visited and put it in `flat_inputs`. Note that the following code
      // uses the same logic as JoinSignatures() to deduplicate inputs by feed
      // tensor name, and generates the flat inputs in the same order.
      if (visited_feed_tensor_names.contains(tensor_name)) continue;
      visited_feed_tensor_names.insert(tensor_name);
      flat_inputs.push_back(std::make_pair(tensor_name, input_tensors[j]));
    }

    for (const auto& output_key : signature.output_names) {
      const auto& tensor_info = signature_def.outputs().at(output_key);

      VLOG(1) << "Importing Signature Output: output_key = " << output_key
              << ", tensor_info = " << tensor_info.DebugString();

      TF_RET_CHECK(tensor_info.encoding_case() == tensorflow::TensorInfo::kName)
          << "Only dense tensor is supported, but got encoding case "
          << tensor_info.encoding_case();

      flat_output_names.push_back(tensor_info.name());
    }
  }

  std::vector<tensorflow::Tensor> flat_outputs;

  TF_RETURN_IF_ERROR(
      graph_executor_->Run(run_options, flat_inputs, flat_output_names,
                           /*target_tensor_names=*/{}, &flat_outputs));

  // The outputs of the compiled function are in the user-specified order,
  // though they are flattened. So we just need to regroup the outputs for each
  // signature using its number of outputs.
  multi_outputs->resize(names.size());
  auto cur = flat_outputs.begin();
  for (size_t i = 0; i < names.size(); ++i) {
    const auto& signature_name = names[i];
    const size_t len = signature_defs.at(signature_name).outputs().size();
    std::move(cur, cur + len, std::back_inserter(multi_outputs->at(i)));
    cur += len;
    DCHECK_LE(std::distance(flat_outputs.begin(), cur), flat_outputs.size());
  }
  return OkStatus();
}

tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelImpl::ImportSubgraph(
    mlir::MLIRContext* context,
    const tensorflow::GraphImportConfig::InputArrays& input_nodes,
    const std::vector<std::string>& output_nodes,
    const std::vector<std::string>& target_nodes) {
  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  graph_import_config.inputs = input_nodes;
  graph_import_config.outputs = output_nodes;
  graph_import_config.control_outputs = target_nodes;

  // Optimize the graph.
  TF_ASSIGN_OR_RETURN(
      auto optimization_result,
      graph_executor_->graph_execution_state().CreateOptimizedGraph(
          graph_import_config));

  // Convert the optimized graph to an MLIR module.
  return tensorflow::ConvertGraphToMlir(
      *optimization_result.graph, /*debug_info=*/{},
      optimization_result.graph->flib_def(), graph_import_config, context);
}

tensorflow::Status SavedModelImpl::RunByTensorNames(
    const RunOptions& run_options,
    absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
    absl::Span<const std::string> output_tensor_names,
    absl::Span<const std::string> target_node_names,
    std::vector<tensorflow::Tensor>* outputs) {
  // TODO(b/192498110): Validate input type.

  return graph_executor_->Run(run_options, inputs, output_tensor_names,
                              target_node_names, outputs);
}

namespace {

using JoinedSignature = SavedModelImpl::JoinedSignature;

// Returns a joined signature for the signatures in `names`. For inputs, as
// their corresponding nodes may overlap, we deduplicate them by node, so the
// order of inputs in the joined signature may differ from the original order.
// For outputs, overlapping is fine, so we simply flatten them in the original
// order.
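// For example, if two (illustrative) signatures "sig_a" and "sig_b" both feed
// the tensor "x:0", the joined signature contains "x:0" only once in
// `input_nodes`, while `output_nodes` keeps every fetch in the original,
// possibly duplicated, order.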
StatusOr<JoinedSignature> JoinSignatures(
    absl::Span<const std::string> names, const SignatureMap& signature_map,
    const tensorflow::protobuf::Map<std::string, tensorflow::SignatureDef>&
        signature_def_map) {
  // Join all the names, all the inputs, and all the outputs.
  JoinedSignature joined_signature;
  joined_signature.name = absl::StrJoin(names, kSignatureJoiningDelimiter);

  // Keep the feed tensor names visited.
  absl::flat_hash_set<std::string> visited_feed_tensor_names;

  for (const auto& name : names) {
    const auto& signature_def = signature_def_map.at(name);

    // For inputs, we deduplicate possible overlapping feed nodes and create
    // the new input array.
    for (const auto& iter : signature_def.inputs()) {
      const auto& tensor_info = iter.second;

      // Skip if this feed node is already visited.
      if (visited_feed_tensor_names.contains(tensor_info.name())) continue;

      // Otherwise, we parse its tensor info and collect it for later
      // compilation.
      visited_feed_tensor_names.insert(tensor_info.name());

      // TODO(b/184675681): Support other encoding cases.
      //
      // TODO(b/184679394): Add unit test for this check.
      TF_RET_CHECK(tensor_info.encoding_case() ==
                   tensorflow::TensorInfo::kName)
          << "Only dense tensor is supported, but got encoding case "
          << tensor_info.encoding_case();

      VLOG(1) << "Importing Signature Input: input_key = " << iter.first
              << ", tensor_info = " << tensor_info.DebugString();

      tensorflow::ArrayInfo array_info;
      array_info.imported_dtype = tensor_info.dtype();

      if (tensor_info.has_tensor_shape()) {
        array_info.shape = tensor_info.tensor_shape();
      } else {
        // If there is no tensor shape in the tensor info, conservatively set
        // unknown_rank to true.
        array_info.shape.set_unknown_rank(true);
      }

      joined_signature.input_nodes.insert(
          std::pair<std::string, tensorflow::ArrayInfo>(
              tensor_info.name(), std::move(array_info)));
    }

    // For outputs, we simply flatten them in the original order, as it is
    // fine to have duplicated fetch nodes.
    const internal::Signature& signature = signature_map.at(name);
    for (const auto& output_key : signature.output_names) {
      const auto& tensor_info = signature_def.outputs().at(output_key);

      VLOG(1) << "Importing Signature Output: output_key = " << output_key
              << ", tensor_info = " << tensor_info.DebugString();

      TF_RET_CHECK(tensor_info.encoding_case() ==
                   tensorflow::TensorInfo::kName)
          << "Only dense tensor is supported, but got encoding case "
          << tensor_info.encoding_case();

      joined_signature.output_nodes.push_back(tensor_info.name());
    }
  }

  return joined_signature;
}


}  // namespace

// TODO(b/216379787): Reuse `GraphExecutor::LoadClientGraph()`.
StatusOr<std::reference_wrapper<const SavedModelImpl::LoadingResult>>
SavedModelImpl::LoadJoinedSignature(const JoinedSignature& joined_signature) {
  // Step 1: Import the combined subgraph from proto to an MLIR module.
  mlir::MLIRContext context;
  ASSIGN_OR_RETURN_IN_IMPORT(
      auto module, ImportSubgraph(&context, joined_signature.input_nodes,
                                  joined_signature.output_nodes,
                                  joined_signature.target_nodes));

  // Step 2: Compile the MLIR module from TF dialect to TFRT dialect (in BEF).
  auto loading_result = std::make_unique<LoadingResult>();
  loading_result->name = joined_signature.name;
  loading_result->resource_context = CreateResourceContext(
      runtime(), tpu_model_resource_.get(),
      options_.graph_execution_options.compile_options.tpu_target);

  RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef(
      options_.graph_execution_options.compile_options, module.get(),
      &loading_result->bef));

  // Step 3: Initialize runtime states using special BEF functions.
  ASSIGN_OR_RETURN_IN_INIT(
      loading_result->bef_file,
      tfrt::CreateBefFileFromBefBuffer(
          *options_.graph_execution_options.runtime, loading_result->bef));
  RETURN_IF_ERROR_IN_INIT(RunInitializers(
      /*initializers_and_signatures=*/{},
      options_.graph_execution_options.model_metadata,
      loading_result->bef_file.get(),
      *options_.graph_execution_options.runtime,
      loading_result->resource_context.get(), *fallback_state_));

  // Store loading_result in the cache.
  const auto* loading_result_ptr = loading_result.get();
  loading_result_cache_[joined_signature.name] = std::move(loading_result);
  return {*loading_result_ptr};
}

StatusOr<std::reference_wrapper<const SavedModelImpl::LoadingResult>>
SavedModelImpl::GetOrCreateLoadingResult(absl::Span<const std::string> names) {
  const auto joined_name = absl::StrJoin(names, kSignatureJoiningDelimiter);
  tensorflow::mutex_lock l(loading_result_cache_mu_);
  const auto iter = loading_result_cache_.find(joined_name);
  if (iter != loading_result_cache_.end()) return {*iter->second};

  TF_ASSIGN_OR_RETURN(
      const auto joined_signature,
      JoinSignatures(names, signatures_, meta_graph_def_.signature_def()));

  return LoadJoinedSignature(joined_signature);
}

}  // namespace tfrt_stub
}  // namespace tensorflow