1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
16 
17 #include <algorithm>
18 #include <array>
19 #include <cstdint>
20 #include <functional>
21 #include <memory>
22 #include <numeric>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "learning/brain/experimental/tfrt/native_lowering/saved_model/saved_model_translate.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/string_view.h"
31 #include "absl/time/clock.h"
32 #include "absl/time/time.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
35 #include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
36 #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/lib/gtl/cleanup.h"
39 #include "tensorflow/core/platform/errors.h"
40 #include "tensorflow/core/platform/status.h"
41 #include "tensorflow/core/platform/statusor.h"
42 #include "tensorflow/core/platform/threadpool_interface.h"
43 #include "tensorflow/core/platform/types.h"
44 #include "tensorflow/core/profiler/lib/connected_traceme.h"
45 #include "tensorflow/core/profiler/lib/traceme_encode.h"
46 #include "tensorflow/core/protobuf/config.pb.h"
47 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
48 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
49 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
50 #include "tensorflow/core/tfrt/runtime/runtime.h"
51 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
52 #include "tensorflow/core/tfrt/tpu/tpu_resources.h"
53 #include "tensorflow/core/tfrt/utils/error_util.h"
54 #include "tensorflow/core/tfrt/utils/fallback_tensor.h"
55 #include "tensorflow/core/tfrt/utils/utils.h"
56 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
57 #include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
58 #include "tfrt/host_context/async_value.h"  // from @tf_runtime
59 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
60 #include "tfrt/host_context/chain.h"  // from @tf_runtime
61 #include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
62 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
63 #include "tfrt/host_context/function.h"  // from @tf_runtime
64 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
65 #include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
66 #include "tfrt/host_context/resource_context.h"  // from @tf_runtime
67 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
68 #include "tfrt/support/ref_count.h"  // from @tf_runtime
69 #include "tfrt/support/string_util.h"  // from @tf_runtime
70 
71 namespace tensorflow {
72 namespace tfrt_stub {
73 namespace {
74 
75 constexpr char kDeadlineExceededMessage[] = "Deadline exceeded.";
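// The two delimiters below are used by GraphExecutor::GetOrCreateLoadedClientGraph
// to join the sorted input/output/target names into the cache key for loaded
// client graphs, e.g. "input1-input2^output1-output2^target1-target2".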
76 constexpr char kTensorNameJoiningDelimiter[] = "-";
77 constexpr char kArgumentTypeJoiningDelimiter[] = "^";
78 
79 }  // namespace
80 
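// Sets up a tfrt::RequestContext for a single run: it assigns a request id,
// obtains an (optional) per-request inter-op queue from `work_queue`, wires the
// fallback state (device manager, process function library runtime) and the
// TF-JitRt state into the request context, and returns everything bundled in a
// RequestInfo.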
81 StatusOr<std::unique_ptr<RequestInfo>> SetUpRequestContext(
82     const GraphExecutionRunOptions& run_options,
83     const SessionMetadata& model_metadata, tfrt::HostContext* host,
84     tensorflow::tfrt_stub::WorkQueueInterface* work_queue,
85     tfrt::ResourceContext* resource_context,
86     const tensorflow::tfrt_stub::FallbackState& fallback_state) {
87   DCHECK(host);
88   DCHECK(work_queue);
89   // Create request context and prepare deadline tracker.
90   // TODO(tfrt-devs): Consider using an ID unique within each model to reduce
91   // contention.
92   int64_t request_id = work_queue->id();
93   if (request_id == 0) request_id = tfrt::GetUniqueInt();
94   tfrt::RequestContextBuilder request_context_builder(
95       host, resource_context, request_id, run_options.enable_cost_measurement);
96 
97   // TODO(b/198671794): `intra_op_threadpool` should be passed through Run()
98   // directly.
99   tensorflow::thread::ThreadPoolInterface* intra_op_threadpool = nullptr;
100 
101   // TODO(b/198671794): The per-request queue should be passed through Run()
102   // directly.
103   TF_ASSIGN_OR_RETURN(auto request_queue,
104                       work_queue->InitializeRequest(&request_context_builder,
105                                                     &intra_op_threadpool));
106 
107   auto request_info = std::make_unique<RequestInfo>();
108 
109   // If a per-request queue is not provided, use the original queue in the
110   // tensorflow::Executor::Args::Runner.
111   auto* inter_op_queue = request_queue ? request_queue.get() : work_queue;
112   request_info->runner = [inter_op_queue](std::function<void()> f) {
113     inter_op_queue->AddTask(std::move(f));
114   };
115 
116   request_info->request_queue = std::move(request_queue);
117 
118   TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
119       &request_context_builder, &fallback_state.device_manager(),
120       &fallback_state.process_function_library_runtime(), intra_op_threadpool,
121       model_metadata, &request_info->runner));
122 
123   TF_RETURN_IF_ERROR(
124       tensorflow::SetUpTfJitRtRequestContext(&request_context_builder));
125   tfrt::RequestOptions request_options;
126   request_options.priority = run_options.priority;
127   request_context_builder.set_request_options(request_options);
128 
129   auto expected_req_ctx = std::move(request_context_builder).build();
130   if (!expected_req_ctx) {
131     return tensorflow::errors::Internal(
132         tfrt::StrCat(expected_req_ctx.takeError()));
133   }
134 
135   request_info->tfrt_request_context = std::move(expected_req_ctx.get());
136 
137   return request_info;
138 }
139 
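// Runs the compiled BEF function `func` for one request: it builds a request
// context via SetUpRequestContext, enforces the optional deadline, enqueues the
// function on the selected work queue, awaits the side-effect chain and all
// results, copies the resulting host tensors into `outputs`, and folds any
// per-result errors into a single summary status.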
140 tensorflow::Status GraphExecutionRunOnFunction(
141     const GraphExecutionOptions& options,
142     const GraphExecutionRunOptions& run_options,
143     absl::string_view signature_name, const tfrt::Function& func,
144     absl::Span<const tensorflow::Tensor> inputs,
145     absl::Span<const tensorflow::Tensor> captures,
146     std::vector<tensorflow::Tensor>* outputs,
147     tfrt::ResourceContext* resource_context, const Runtime& runtime,
148     const FallbackState& fallback_state,
149     tfrt::RequestDeadlineTracker& req_deadline_tracker) {
150   auto* host = runtime.core_runtime()->GetHostContext();
151 
152   TF_ASSIGN_OR_RETURN(
153       auto request_info,
154       SetUpRequestContext(run_options, options.model_metadata, host,
155                           run_options.work_queue ? run_options.work_queue
156                                                  : runtime.work_queue(),
157                           resource_context, fallback_state));
158 
159   tensorflow::profiler::TraceMeProducer traceme(
160       // To TraceMeConsumers in RunHandlerThreadPool::WorkerLoop.
161       [request_id = request_info->tfrt_request_context->id(), signature_name,
162        &options] {
163         return tensorflow::profiler::TraceMeEncode(
164             "TfrtModelRun",
165             {{"_r", 1},
166              {"id", request_id},
167              {"signature", signature_name},
168              {"model_id", absl::StrCat(options.model_metadata.name(), ":",
169                                        options.model_metadata.version())}});
170       },
171       tensorflow::profiler::ContextType::kTfrtExecutor,
172       request_info->tfrt_request_context->id());
173 
174   // Only configure the timer when a deadline is set.
175   if (run_options.deadline.has_value()) {
176     auto deadline = run_options.deadline.value();
177     if (absl::ToChronoTime(absl::Now()) > deadline) {
178       return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
179     }
180     req_deadline_tracker.CancelRequestOnDeadline(
181         deadline, request_info->tfrt_request_context);
182   }
183 
184   tfrt::ExecutionContext exec_ctx{request_info->tfrt_request_context};
185   if (run_options.work_queue) {
186     // TODO(b/198671794): Avoid creating `request_queue` when the `work_queue`
187     // in `run_options` is specified.
188     exec_ctx.set_work_queue(run_options.work_queue);
189   } else if (request_info->request_queue) {
190     exec_ctx.set_work_queue(request_info->request_queue.get());
191   } else {
192     exec_ctx.set_work_queue(runtime.work_queue());
193   }
194 
195   llvm::SmallVector<tfrt::AsyncValue*, 4> arguments;
196   auto cleanup = tensorflow::gtl::MakeCleanup([&]() {
197     for (auto* argument : arguments) argument->DropRef();
198   });
199 
200   // The first argument is a chain for side-effects. Since SavedModel::Run()
201   // only returns when side-effects are visible, we can use a ready chain here.
202   arguments.push_back(tfrt::GetReadyChain().release());
203 
204   for (const auto& input : inputs) {
205     arguments.push_back(
206         tfrt::MakeAvailableAsyncValueRef<FallbackTensor>(input).release());
207   }
208 
209   DCHECK(captures.empty()) << "signature should have no captures, which is "
210                               "guaranteed by the compiler";
211 
212   if (arguments.size() != func.argument_types().size())
213     return tensorflow::errors::Internal("incorrect number of inputs.");
214 
215   llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 4> chain_and_results;
216   chain_and_results.resize(func.result_types().size());
217 
218   // Hand over the execution to thread pool.
219   std::array<tfrt::RCReference<tfrt::AsyncValue>, 1> executed = {
220       EnqueueWork(exec_ctx, [&]() -> tfrt::Chain {
221         func.Execute(exec_ctx, arguments, chain_and_results);
222         return {};
223       })};
224 
225   // Wait for the function execution before checking chain and results.
226   exec_ctx.work_queue().Await(executed);
227 
228   // Wait for all results including the side-effect chain. This ensures that all
229   // side-effects are visible when SavedModel::Run() returns.
230   exec_ctx.work_queue().Await(chain_and_results);
231 
232   DCHECK(!chain_and_results.empty());
233 
234   tfrt::RCReference<tfrt::AsyncValue>& chain = chain_and_results[0];
235   auto results = llvm::drop_begin(chain_and_results, 1);
236 
237   tensorflow::StatusGroup status_group;
238 
239   if (chain->IsError()) {
240     status_group.Update(CreateTfErrorStatus(chain->GetError()));
241   }
242 
243   for (tfrt::RCReference<tfrt::AsyncValue>& result : results) {
244     DCHECK(result->IsAvailable());
245 
246     if (result->IsError()) {
247       status_group.Update(CreateTfErrorStatus(result->GetError()));
248       outputs->push_back(tensorflow::Tensor());
249       continue;
250     }
251 
252     // The result must be a host tensor. This is guaranteed as the compiler
253     // will insert necessary device transfer operations in the graph.
254     DCHECK(result->IsType<FallbackTensor>());
255     const auto& host_tensor = result->get<FallbackTensor>().tensor();
256     // Make a copy of the tensor here, as different result AsyncValues might
257     // point to the same underlying tensor.
258     outputs->push_back(host_tensor);
259   }
260 
261   // TODO(b/171926578): Explicitly clear the context data. Remove this once
262   // b/171926578 is fixed.
263   exec_ctx.request_ctx()->ClearData();
264 
265   // Check if error is due to cancellation.
266   // TODO(tfrt-devs): report cancellation reason from runtime.
267   if (request_info->tfrt_request_context->IsCancelled()) {
268     // Currently a request can only be cancelled by an expired timer.
269     return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
270   }
271 
272   return status_group.as_summary_status();
273 }
274 
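// Creates the per-model resource context and populates it with the runtime's
// resources; TPU resources are added only when the TPU infra target is kTpurt.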
275 std::unique_ptr<tfrt::ResourceContext> CreateResourceContext(
276     const tensorflow::tfrt_stub::Runtime& runtime,
277     tfrt::tpu::TpuModelResource* tpu_model_resource,
278     tensorflow::TfrtTpuInfraTarget tpu_target) {
279   auto resource_context = std::make_unique<tfrt::ResourceContext>();
280   runtime.CreateRuntimeResources(resource_context.get());
281 
282   // TODO(b/178227859): We should make TPU resource init code pluggable, as
283   // opposed to linking it in. We can do this by adding a callback with
284   // `Runtime::AddCreateRuntimeResourceFn`.
285   if (tpu_target == tensorflow::TfrtTpuInfraTarget::kTpurt) {
286     AddTpuResources(resource_context.get(), tpu_model_resource);
287   }
288   return resource_context;
289 }
290 
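// Creates a GraphExecutor from a GraphDef: validates that a runtime is
// provided, builds a TfrtGraphExecutionState (optionally running
// placer/grappler on functions), and wraps both in a GraphExecutor.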
291 StatusOr<std::unique_ptr<GraphExecutor>> GraphExecutor::Create(
292     Options options, const FallbackState& fallback_state,
293     tfrt::tpu::TpuModelResource* tpu_model_resource,
294     tensorflow::GraphDef graph_def) {
295   if (options.runtime == nullptr) {
296     return errors::InvalidArgument("options.runtime must be non-null");
297   }
298 
299   TfrtGraphExecutionState::Options graph_execution_state_options;
300   graph_execution_state_options.run_placer_grappler_on_functions =
301       options.run_placer_grappler_on_functions;
302   graph_execution_state_options.enable_tfrt_gpu = options.enable_tfrt_gpu;
303 
304   TF_ASSIGN_OR_RETURN(
305       auto graph_execution_state,
306       TfrtGraphExecutionState::Create(graph_execution_state_options,
307                                       std::move(graph_def), fallback_state));
308   return std::make_unique<GraphExecutor>(std::move(options), fallback_state,
309                                          tpu_model_resource,
310                                          std::move(graph_execution_state));
311 }
312 
313 namespace {
314 
315 // Sort the strings in `names` and store the results in `sorted_names`. In
316 // addition, the original index in `names` for the item `sorted_names[i]` is
317 // stored in `original_indices[i]`.
318 void CreateSortedNamesAndOriginalIndices(absl::Span<const std::string> names,
319                                          std::vector<std::string>& sorted_names,
320                                          std::vector<int>& original_indices) {
321   DCHECK(sorted_names.empty());
322   DCHECK(original_indices.empty());
323 
324   // Generate indices.
325   original_indices.resize(names.size());
326   std::iota(original_indices.begin(), original_indices.end(), 0);
327 
328   // Sort indices by comparing the corresponding names.
329   std::sort(original_indices.begin(), original_indices.end(),
330             [&](int x, int y) { return names[x] < names[y]; });
331 
332   // Use sorted indices to generate sorted names.
333   sorted_names.reserve(names.size());
334   for (int original_index : original_indices) {
335     DCHECK_LT(original_index, names.size());
336     sorted_names.push_back(names[original_index]);
337   }
338 }
339 
340 }  // namespace
341 
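// GraphExecutor::Run executes one feed/fetch request against the graph: it
// sorts the input/output/target names to obtain a stable cache key, loads (or
// reuses) the corresponding compiled client graph, runs it, and then restores
// the caller's original output ordering.
//
// Illustrative usage sketch (editorial addition, not part of the original
// file; `executor`, `x_tensor`, and the tensor names below are hypothetical):
//
//   std::vector<tensorflow::Tensor> outputs;
//   TF_RETURN_IF_ERROR(executor->Run(
//       /*run_options=*/{},
//       /*inputs=*/{{"x:0", x_tensor}},
//       /*output_tensor_names=*/{"y:0"},
//       /*target_tensor_names=*/{}, &outputs));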
342 tensorflow::Status GraphExecutor::Run(
343     const RunOptions& run_options,
344     absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
345     absl::Span<const std::string> output_tensor_names,
346     absl::Span<const std::string> target_tensor_names,
347     std::vector<tensorflow::Tensor>* outputs) {
348   // TODO(b/192498110): Validate input type.
349 
350   // Sort the input/output names to have a stable order, so that the
351   // `joined_name`, which is used as the cache key, will be the same as long as
352   // the same set of inputs/outputs are specified.
353   std::vector<std::string> input_names;
354   input_names.reserve(inputs.size());
355   for (const auto& p : inputs) input_names.push_back(p.first);
356   std::vector<std::string> sorted_input_names;
357   std::vector<int> input_original_indices;
358   CreateSortedNamesAndOriginalIndices(input_names, sorted_input_names,
359                                       input_original_indices);
360   // We also need to create sorted input dtypes as they are needed for the
361   // compilation.
362   std::vector<tensorflow::DataType> sorted_input_dtypes;
363   sorted_input_dtypes.reserve(inputs.size());
364   for (int original_index : input_original_indices) {
365     sorted_input_dtypes.push_back(inputs.at(original_index).second.dtype());
366   }
367 
368   std::vector<std::string> sorted_output_names;
369   std::vector<int> output_original_indices;
370   CreateSortedNamesAndOriginalIndices(output_tensor_names, sorted_output_names,
371                                       output_original_indices);
372 
373   // For target node names, we only need to sort them. The original indices are
374   // not needed.
375   std::vector<std::string> sorted_target_node_names(target_tensor_names.begin(),
376                                                     target_tensor_names.end());
377   std::sort(sorted_target_node_names.begin(), sorted_target_node_names.end());
378 
379   // Load the client graph.
380   TF_ASSIGN_OR_RETURN(
381       const LoadedClientGraph& loaded_client_graph,
382       GetOrCreateLoadedClientGraph(
383           sorted_input_names, sorted_input_dtypes, sorted_output_names,
384           sorted_target_node_names, run_options.work_queue));
385 
386   const auto* func = loaded_client_graph.bef_file->GetFunction(
387       tensorflow::kImportModelDefaultGraphFuncName);
388   DCHECK(func);
389 
390   // Create the actual arguments to the compiled function, which are sorted
391   // according to the input tensor names.
392   std::vector<tensorflow::Tensor> flat_inputs;
393   flat_inputs.reserve(inputs.size());
394   for (int original_index : input_original_indices) {
395     flat_inputs.push_back(inputs.at(original_index).second);
396   }
397 
398   std::vector<tensorflow::Tensor> flat_outputs;
399   TF_RETURN_IF_ERROR(GraphExecutionRunOnFunction(
400       options_, run_options, loaded_client_graph.name, *func, flat_inputs,
401       /*captures=*/{}, &flat_outputs,
402       loaded_client_graph.resource_context.get(), runtime(), fallback_state_,
403       req_deadline_tracker_));
404 
405   // Create the outputs from the actual function results, which are sorted
406   // according to the output tensor names.
407   auto flat_output_iter = flat_outputs.begin();
408   outputs->resize(flat_outputs.size());
409   for (int original_index : output_original_indices) {
410     (*outputs)[original_index] = std::move(*flat_output_iter);
411     ++flat_output_iter;
412   }
413 
414   return OkStatus();
415 }
416 
417 tensorflow::Status GraphExecutor::Extend(const GraphDef& graph) {
418   return graph_execution_state_->Extend(graph);
419 }
420 
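// Loading a client graph happens in three steps; the first two (import the
// graph into an MLIR module, compile the module to BEF) are implemented here,
// and the third (running the BEF initialization functions) is done in
// LoadClientGraph below.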
421 StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>>
422 GraphExecutor::ImportAndCompileClientGraph(
423     const GraphExecutor::ClientGraph& client_graph) {
424   auto loaded_client_graph = std::make_unique<LoadedClientGraph>();
425   loaded_client_graph->name = client_graph.name;
426   loaded_client_graph->resource_context = CreateResourceContext(
427       runtime(), tpu_model_resource_, options_.compile_options.tpu_target);
428 
429   // Step 1 of loading: Import the client graph from proto to an MLIR module.
430   auto import_start_time = absl::Now();
431   mlir::MLIRContext context;
432   ASSIGN_OR_RETURN_IN_IMPORT(
433       auto module, ImportClientGraphToMlirModule(client_graph, &context));
434   auto import_duration = absl::Now() - import_start_time;
435   LOG(INFO) << "TFRT finished importing client graph (" << &client_graph
436             << "). Took " << absl::ToInt64Milliseconds(import_duration)
437             << " ms. Client graph name: " << client_graph.name;
438 
439   // Step 2 of loading: Compile the MLIR module from TF dialect to TFRT dialect
440   // (in BEF).
441   // TODO(b/229261464): Unify the sync and async lowering passes so we do not
442   // need this branch.
443   auto compile_start_time = absl::Now();
444   if (options_.compile_options.compile_to_sync_tfrt_dialect) {
445     ASSIGN_OR_RETURN_IN_COMPILE(
446         loaded_client_graph->bef,
447         tfrt::CompileTfMlirModuleToSyncBef(module.get()));
448   } else {
449     ASSIGN_OR_RETURN_IN_COMPILE(loaded_client_graph->bef,
450                                 CompileMlirModuleToBef(module.get()));
451   }
452   ASSIGN_OR_RETURN_IN_COMPILE(
453       loaded_client_graph->bef_file,
454       tfrt::CreateBefFileFromBefBuffer(runtime(), loaded_client_graph->bef));
455   auto compile_duration = absl::Now() - compile_start_time;
456   LOG(INFO) << "TFRT finished compiling client graph (" << &client_graph
457             << "). Took " << absl::ToInt64Milliseconds(compile_duration)
458             << " ms. Client graph name: " << client_graph.name;
459 
460   return loaded_client_graph;
461 }
462 
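// Imports and compiles the client graph, then runs the BEF initializers
// ("_tfrt_fallback_init" and "_tfrt_resource_init") via InitBef.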
463 StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>>
464 GraphExecutor::LoadClientGraph(
465     const GraphExecutor::ClientGraph& client_graph,
466     tensorflow::tfrt_stub::WorkQueueInterface* work_queue) {
467   LOG(INFO) << "TFRT loading client graph (" << &client_graph << ") "
468             << client_graph.name;
469   TF_ASSIGN_OR_RETURN(auto loaded_client_graph,
470                       ImportAndCompileClientGraph(client_graph));
471 
472   // Step 3 of loading: Initialize runtime states using special BEF functions.
473   auto init_start_time = absl::Now();
474   RETURN_IF_ERROR_IN_INIT(InitBef(loaded_client_graph->bef_file.get(),
475                                   loaded_client_graph->resource_context.get(),
476                                   work_queue));
477   auto init_duration = absl::Now() - init_start_time;
478   LOG(INFO) << "TFRT finished initializing client graph (" << &client_graph
479             << "). Took " << absl::ToInt64Milliseconds(init_duration)
480             << " ms. Client graph name: " << client_graph.name;
481 
482   return loaded_client_graph;
483 }
484 
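// Converts the client graph to an MLIR module: the graph is first optimized
// (functionalization and Grappler) through TfrtGraphExecutionState, and the
// optimized graph is then imported into the TF MLIR dialect.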
485 tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
486 GraphExecutor::ImportClientGraphToMlirModule(
487     const GraphExecutor::ClientGraph& client_graph,
488     mlir::MLIRContext* context) const {
489   tensorflow::GraphImportConfig graph_import_config;
490   graph_import_config.prune_unused_nodes = true;
491   graph_import_config.enable_shape_inference = false;
492   graph_import_config.inputs = client_graph.input_nodes;
493   graph_import_config.outputs = client_graph.output_nodes;
494   graph_import_config.control_outputs = client_graph.target_nodes;
495 
496   // Optimize the graph.
497   TF_ASSIGN_OR_RETURN(
498       auto optimized_graph,
499       graph_execution_state_->CreateOptimizedGraph(graph_import_config));
500 
501   LOG(INFO) << "TFRT import client graph (" << &client_graph
502             << "): Functionalization took "
503             << absl::ToInt64Milliseconds(
504                    optimized_graph.functionalization_duration)
505             << " ms. Client graph name: " << client_graph.name;
506   LOG(INFO) << "TFRT import client graph (" << &client_graph
507             << "): Grappler took "
508             << absl::ToInt64Milliseconds(optimized_graph.grappler_duration)
509             << " ms. Client graph name: " << client_graph.name;
510 
511   // Convert the optimized graph to an MLIR module.
512   return tensorflow::ConvertGraphToMlir(
513       *optimized_graph.graph, /*debug_info=*/{},
514       optimized_graph.graph->flib_def(), graph_import_config, context);
515 }
516 
517 StatusOr<tfrt::BefBuffer> GraphExecutor::CompileMlirModuleToBef(
518     mlir::ModuleOp module) const {
519   tfrt::BefBuffer bef;
520   TF_RETURN_IF_ERROR(
521       tensorflow::ConvertTfMlirToBef(options_.compile_options, module, &bef));
522   return bef;
523 }
524 
525 tensorflow::Status GraphExecutor::InitBef(
526     tfrt::BEFFile* bef_file, tfrt::ResourceContext* resource_context,
527     tensorflow::tfrt_stub::WorkQueueInterface* work_queue) {
528   auto* host = runtime().core_runtime()->GetHostContext();
529   TF_ASSIGN_OR_RETURN(
530       auto request_info,
531       SetUpRequestContext(/*run_options=*/{}, /*model_metadata=*/{}, host,
532                           work_queue ? work_queue : runtime().work_queue(),
533                           resource_context, fallback_state_));
534 
535   tfrt::ExecutionContext exec_ctx(request_info->tfrt_request_context);
536 
537   // Run "_tfrt_fallback_init" first to initialize fallback-specific states. It
538   // is a special function created by the compiler, which calls a sequence of
539   // tfrt_fallback_async.createop to create all fallback ops used in this BEF.
540   TF_RETURN_IF_ERROR(
541       RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_fallback_init"));
542 
543   // After we have initialized all the resources in the original graph, we can
544   // run the "_tfrt_resource_init" function to set these resources in runtime
545   // states, so that they can later be retrieved efficiently without any locking.
546   TF_RETURN_IF_ERROR(
547       RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_resource_init"));
548 
549   return OkStatus();
550 }
551 
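// Returns the cached LoadedClientGraph for the given (sorted) signature, or
// compiles and caches a new one. Lookups and compilations are serialized by
// loaded_client_graphs_mu_, so concurrent Run() calls for a new signature
// compile it only once.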
552 StatusOr<std::reference_wrapper<const GraphExecutor::LoadedClientGraph>>
553 GraphExecutor::GetOrCreateLoadedClientGraph(
554     absl::Span<const std::string> input_tensor_names,
555     absl::Span<const tensorflow::DataType> input_tensor_dtypes,
556     absl::Span<const std::string> output_tensor_names,
557     absl::Span<const std::string> target_tensor_names,
558     tensorflow::tfrt_stub::WorkQueueInterface* work_queue) {
559   // The format of the joined name is illustrated by the following example:
560   // input1-input2^output1-output2^target1-target2
561   const auto joined_name = absl::StrCat(
562       absl::StrJoin(input_tensor_names, kTensorNameJoiningDelimiter),
563       kArgumentTypeJoiningDelimiter,
564       absl::StrJoin(output_tensor_names, kTensorNameJoiningDelimiter),
565       kArgumentTypeJoiningDelimiter,
566       absl::StrJoin(target_tensor_names, kTensorNameJoiningDelimiter));
567 
568   tensorflow::mutex_lock l(loaded_client_graphs_mu_);
569 
570   // Cache hit; return immediately.
571   const auto iter = loaded_client_graphs_.find(joined_name);
572   if (iter != loaded_client_graphs_.end()) return {*iter->second};
573 
574   // Cache miss; populate a `ClientGraph` and load it.
575   tensorflow::GraphImportConfig::InputArrays input_nodes;
576   DCHECK_EQ(input_tensor_names.size(), input_tensor_dtypes.size());
577   for (int i = 0; i < input_tensor_names.size(); ++i) {
578     const auto& input_name = input_tensor_names[i];
579     auto input_dtype = input_tensor_dtypes[i];
580 
581     tensorflow::ArrayInfo array_info;
582     array_info.imported_dtype = input_dtype;
583     array_info.shape.set_unknown_rank(true);
584     input_nodes[input_name] = array_info;
585   }
586   ClientGraph client_graph{
587       joined_name,
588       std::move(input_nodes),
589       {output_tensor_names.begin(), output_tensor_names.end()},
590       {target_tensor_names.begin(), target_tensor_names.end()}};
591   TF_ASSIGN_OR_RETURN(auto loaded_client_graph,
592                       LoadClientGraph(client_graph, work_queue));
593 
594   // Store the new loaded client graph in cache and return.
595   const auto* loaded_client_graph_ptr = loaded_client_graph.get();
596   loaded_client_graphs_[joined_name] = std::move(loaded_client_graph);
597   return {*loaded_client_graph_ptr};
598 }
599 
600 }  // namespace tfrt_stub
601 }  // namespace tensorflow
602