/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/graph_executor/graph_executor.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "learning/brain/experimental/tfrt/native_lowering/saved_model/saved_model_translate.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/tpu/tpu_resources.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_stub {
namespace {

constexpr char kDeadlineExceededMessage[] = "Deadline exceeded.";
constexpr char kTensorNameJoiningDelimiter[] = "-";
constexpr char kArgumentTypeJoiningDelimiter[] = "^";

}  // namespace

StatusOr<std::unique_ptr<RequestInfo>> SetUpRequestContext(
    const GraphExecutionRunOptions& run_options,
    const SessionMetadata& model_metadata, tfrt::HostContext* host,
    tensorflow::tfrt_stub::WorkQueueInterface* work_queue,
    tfrt::ResourceContext* resource_context,
    const tensorflow::tfrt_stub::FallbackState& fallback_state) {
  DCHECK(host);
  DCHECK(work_queue);
  // Create request context and prepare deadline tracker.
  // TODO(tfrt-devs): Consider using an ID unique within each model to reduce
  // contention.
  int64_t request_id = work_queue->id();
  if (request_id == 0) request_id = tfrt::GetUniqueInt();
  tfrt::RequestContextBuilder request_context_builder(
      host, resource_context, request_id, run_options.enable_cost_measurement);

  // TODO(b/198671794): `intra_op_threadpool` should be passed through Run()
  // directly.
  tensorflow::thread::ThreadPoolInterface* intra_op_threadpool = nullptr;

  // TODO(b/198671794): The per-request queue should be passed through Run()
  // directly.
  TF_ASSIGN_OR_RETURN(auto request_queue,
                      work_queue->InitializeRequest(&request_context_builder,
                                                    &intra_op_threadpool));

  auto request_info = std::make_unique<RequestInfo>();

  // If a per-request queue is not provided, use the original queue in the
  // tensorflow::Executor::Args::Runner.
  auto* inter_op_queue = request_queue ? request_queue.get() : work_queue;
  request_info->runner = [inter_op_queue](std::function<void()> f) {
    inter_op_queue->AddTask(std::move(f));
  };

  request_info->request_queue = std::move(request_queue);

  TF_RETURN_IF_ERROR(tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
      &request_context_builder, &fallback_state.device_manager(),
      &fallback_state.process_function_library_runtime(), intra_op_threadpool,
      model_metadata, &request_info->runner));

  TF_RETURN_IF_ERROR(
      tensorflow::SetUpTfJitRtRequestContext(&request_context_builder));
  tfrt::RequestOptions request_options;
  request_options.priority = run_options.priority;
  request_context_builder.set_request_options(request_options);

  auto expected_req_ctx = std::move(request_context_builder).build();
  if (!expected_req_ctx) {
    return tensorflow::errors::Internal(
        tfrt::StrCat(expected_req_ctx.takeError()));
  }

  request_info->tfrt_request_context = std::move(expected_req_ctx.get());

  return request_info;
}

tensorflow::Status GraphExecutionRunOnFunction(
    const GraphExecutionOptions& options,
    const GraphExecutionRunOptions& run_options,
    absl::string_view signature_name, const tfrt::Function& func,
    absl::Span<const tensorflow::Tensor> inputs,
    absl::Span<const tensorflow::Tensor> captures,
    std::vector<tensorflow::Tensor>* outputs,
    tfrt::ResourceContext* resource_context, const Runtime& runtime,
    const FallbackState& fallback_state,
    tfrt::RequestDeadlineTracker& req_deadline_tracker) {
  auto* host = runtime.core_runtime()->GetHostContext();

  TF_ASSIGN_OR_RETURN(
      auto request_info,
      SetUpRequestContext(run_options, options.model_metadata, host,
                          run_options.work_queue ? run_options.work_queue
                                                 : runtime.work_queue(),
                          resource_context, fallback_state));

  tensorflow::profiler::TraceMeProducer traceme(
      // To TraceMeConsumers in RunHandlerThreadPool::WorkerLoop.
      [request_id = request_info->tfrt_request_context->id(), signature_name,
       &options] {
        return tensorflow::profiler::TraceMeEncode(
            "TfrtModelRun",
            {{"_r", 1},
             {"id", request_id},
             {"signature", signature_name},
             {"model_id", absl::StrCat(options.model_metadata.name(), ":",
                                       options.model_metadata.version())}});
      },
      tensorflow::profiler::ContextType::kTfrtExecutor,
      request_info->tfrt_request_context->id());

  // Only configure timer when the deadline is set.
  if (run_options.deadline.has_value()) {
    auto deadline = run_options.deadline.value();
    if (absl::ToChronoTime(absl::Now()) > deadline) {
      return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
    }
    req_deadline_tracker.CancelRequestOnDeadline(
        deadline, request_info->tfrt_request_context);
  }

  tfrt::ExecutionContext exec_ctx{request_info->tfrt_request_context};
  if (run_options.work_queue) {
    // TODO(b/198671794): Avoid creating `request_queue` when the `work_queue`
    // in `run_options` is specified.
    exec_ctx.set_work_queue(run_options.work_queue);
  } else if (request_info->request_queue) {
    exec_ctx.set_work_queue(request_info->request_queue.get());
  } else {
    exec_ctx.set_work_queue(runtime.work_queue());
  }

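  // Materialize the function arguments as AsyncValues. The cleanup below
  // releases the references held in `arguments` when this function returns.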
  llvm::SmallVector<tfrt::AsyncValue*, 4> arguments;
  auto cleanup = tensorflow::gtl::MakeCleanup([&]() {
    for (auto* argument : arguments) argument->DropRef();
  });

  // The first argument is a chain for side-effects. Since SavedModel::Run()
  // only returns when side-effects are visible, we can use a ready chain here.
  arguments.push_back(tfrt::GetReadyChain().release());

  for (const auto& input : inputs) {
    arguments.push_back(
        tfrt::MakeAvailableAsyncValueRef<FallbackTensor>(input).release());
  }

  DCHECK(captures.empty()) << "signature should have no captures, which is "
                              "guaranteed by the compiler";

  if (arguments.size() != func.argument_types().size())
    return tensorflow::errors::Internal("incorrect number of inputs.");

  llvm::SmallVector<tfrt::RCReference<tfrt::AsyncValue>, 4> chain_and_results;
  chain_and_results.resize(func.result_types().size());

  // Hand over the execution to thread pool.
  std::array<tfrt::RCReference<tfrt::AsyncValue>, 1> executed = {
      EnqueueWork(exec_ctx, [&]() -> tfrt::Chain {
        func.Execute(exec_ctx, arguments, chain_and_results);
        return {};
      })};

  // Wait for the function execution before checking chain and results.
  exec_ctx.work_queue().Await(executed);

  // Wait for all results including the side-effect chain. This ensures that
  // all side-effects are visible when SavedModel::Run() returns.
  exec_ctx.work_queue().Await(chain_and_results);

  DCHECK(!chain_and_results.empty());
  tfrt::RCReference<tfrt::AsyncValue>& chain = chain_and_results[0];
  auto results = llvm::drop_begin(chain_and_results, 1);

  tensorflow::StatusGroup status_group;

  if (chain->IsError()) {
    status_group.Update(CreateTfErrorStatus(chain->GetError()));
  }

  for (tfrt::RCReference<tfrt::AsyncValue>& result : results) {
    DCHECK(result->IsAvailable());

    if (result->IsError()) {
      status_group.Update(CreateTfErrorStatus(result->GetError()));
      outputs->push_back(tensorflow::Tensor());
      continue;
    }

    // The result must be a host tensor. This is guaranteed as the compiler
    // will insert necessary device transfer operations in the graph.
    DCHECK(result->IsType<FallbackTensor>());
    const auto& host_tensor = result->get<FallbackTensor>().tensor();
    // Make a copy of the tensor here as different result AsyncValues might
    // point to the same underlying tensor.
    outputs->push_back(host_tensor);
  }

  // TODO(b/171926578): Explicitly clear the context data. Remove this after
  // b/171926578 is fixed.
  exec_ctx.request_ctx()->ClearData();

  // Check if the error is due to cancellation.
  // TODO(tfrt-devs): Report the cancellation reason from the runtime.
  if (request_info->tfrt_request_context->IsCancelled()) {
    // Currently a request can only be cancelled by an expired timer.
    return tensorflow::errors::DeadlineExceeded(kDeadlineExceededMessage);
  }

  return status_group.as_summary_status();
}

std::unique_ptr<tfrt::ResourceContext> CreateResourceContext(
    const tensorflow::tfrt_stub::Runtime& runtime,
    tfrt::tpu::TpuModelResource* tpu_model_resource,
    tensorflow::TfrtTpuInfraTarget tpu_target) {
  auto resource_context = std::make_unique<tfrt::ResourceContext>();
  runtime.CreateRuntimeResources(resource_context.get());

  // TODO(b/178227859): We should make TPU resource init code pluggable, as
  // opposed to linking it in. We can do this by adding a callback with
  // `Runtime::AddCreateRuntimeResourceFn`.
  if (tpu_target == tensorflow::TfrtTpuInfraTarget::kTpurt) {
    AddTpuResources(resource_context.get(), tpu_model_resource);
  }
  return resource_context;
}

StatusOr<std::unique_ptr<GraphExecutor>> GraphExecutor::Create(
    Options options, const FallbackState& fallback_state,
    tfrt::tpu::TpuModelResource* tpu_model_resource,
    tensorflow::GraphDef graph_def) {
  if (options.runtime == nullptr) {
    return errors::InvalidArgument("options.runtime must be non-null");
  }

  TfrtGraphExecutionState::Options graph_execution_state_options;
  graph_execution_state_options.run_placer_grappler_on_functions =
      options.run_placer_grappler_on_functions;
  graph_execution_state_options.enable_tfrt_gpu = options.enable_tfrt_gpu;

  TF_ASSIGN_OR_RETURN(
      auto graph_execution_state,
      TfrtGraphExecutionState::Create(graph_execution_state_options,
                                      std::move(graph_def), fallback_state));
  return std::make_unique<GraphExecutor>(std::move(options), fallback_state,
                                         tpu_model_resource,
                                         std::move(graph_execution_state));
}

namespace {

// Sort the strings in `names` and store the results in `sorted_names`. In
// addition, the original index in `names` for the item `sorted_names[i]` is
// stored in `original_indices[i]`.
void CreateSortedNamesAndOriginalIndices(absl::Span<const std::string> names,
                                         std::vector<std::string>& sorted_names,
                                         std::vector<int>& original_indices) {
  DCHECK(sorted_names.empty());
  DCHECK(original_indices.empty());

  // Generate indices.
  original_indices.resize(names.size());
  std::iota(original_indices.begin(), original_indices.end(), 0);

  // Sort indices by comparing the corresponding names.
  std::sort(original_indices.begin(), original_indices.end(),
            [&](int x, int y) { return names[x] < names[y]; });

  // Use sorted indices to generate sorted names.
  sorted_names.reserve(names.size());
  for (int original_index : original_indices) {
    DCHECK_LT(original_index, names.size());
    sorted_names.push_back(names[original_index]);
  }
}

}  // namespace


tensorflow::Status GraphExecutor::Run(
    const RunOptions& run_options,
    absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
    absl::Span<const std::string> output_tensor_names,
    absl::Span<const std::string> target_tensor_names,
    std::vector<tensorflow::Tensor>* outputs) {
  // TODO(b/192498110): Validate input type.

  // Sort the input/output names to have a stable order, so that the
  // `joined_name`, which is used as the cache key, will be the same as long as
  // the same set of inputs/outputs are specified.
  std::vector<std::string> input_names;
  input_names.reserve(inputs.size());
  for (const auto& p : inputs) input_names.push_back(p.first);
  std::vector<std::string> sorted_input_names;
  std::vector<int> input_original_indices;
  CreateSortedNamesAndOriginalIndices(input_names, sorted_input_names,
                                      input_original_indices);
  // We also need to create sorted input dtypes as they are needed for the
  // compilation.
  std::vector<tensorflow::DataType> sorted_input_dtypes;
  sorted_input_dtypes.reserve(inputs.size());
  for (int original_index : input_original_indices) {
    sorted_input_dtypes.push_back(inputs.at(original_index).second.dtype());
  }

  std::vector<std::string> sorted_output_names;
  std::vector<int> output_original_indices;
  CreateSortedNamesAndOriginalIndices(output_tensor_names, sorted_output_names,
                                      output_original_indices);

  // For target node names, we only need to sort them. The original indices are
  // not needed.
  std::vector<std::string> sorted_target_node_names(
      target_tensor_names.begin(), target_tensor_names.end());
  std::sort(sorted_target_node_names.begin(), sorted_target_node_names.end());

  // Load the client graph.
  TF_ASSIGN_OR_RETURN(
      const LoadedClientGraph& loaded_client_graph,
      GetOrCreateLoadedClientGraph(
          sorted_input_names, sorted_input_dtypes, sorted_output_names,
          sorted_target_node_names, run_options.work_queue));

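  // Look up the entry function for this client graph in the compiled BEF file.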
  const auto* func = loaded_client_graph.bef_file->GetFunction(
      tensorflow::kImportModelDefaultGraphFuncName);
  DCHECK(func);

  // Create the actual arguments to the compiled function, which are sorted
  // according to the input tensor names.
  std::vector<tensorflow::Tensor> flat_inputs;
  flat_inputs.reserve(inputs.size());
  for (int original_index : input_original_indices) {
    flat_inputs.push_back(inputs.at(original_index).second);
  }

  std::vector<tensorflow::Tensor> flat_outputs;
  TF_RETURN_IF_ERROR(GraphExecutionRunOnFunction(
      options_, run_options, loaded_client_graph.name, *func, flat_inputs,
      /*captures=*/{}, &flat_outputs,
      loaded_client_graph.resource_context.get(), runtime(), fallback_state_,
      req_deadline_tracker_));

  // Create the outputs from the actual function results, which are sorted
  // according to the output tensor names.
  auto flat_output_iter = flat_outputs.begin();
  outputs->resize(flat_outputs.size());
  for (int original_index : output_original_indices) {
    (*outputs)[original_index] = std::move(*flat_output_iter);
    ++flat_output_iter;
  }

  return OkStatus();
}

tensorflow::Status GraphExecutor::Extend(const GraphDef& graph) {
  return graph_execution_state_->Extend(graph);
}

StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>>
GraphExecutor::ImportAndCompileClientGraph(
    const GraphExecutor::ClientGraph& client_graph) {
  auto loaded_client_graph = std::make_unique<LoadedClientGraph>();
  loaded_client_graph->name = client_graph.name;
  loaded_client_graph->resource_context = CreateResourceContext(
      runtime(), tpu_model_resource_, options_.compile_options.tpu_target);

  // Step 1 of loading: Import the client graph from proto to an MLIR module.
  auto import_start_time = absl::Now();
  mlir::MLIRContext context;
  ASSIGN_OR_RETURN_IN_IMPORT(
      auto module, ImportClientGraphToMlirModule(client_graph, &context));
  auto import_duration = absl::Now() - import_start_time;
  LOG(INFO) << "TFRT finished importing client graph (" << &client_graph
            << "). Took " << absl::ToInt64Milliseconds(import_duration)
            << " ms. Client graph name: " << client_graph.name;

  // Step 2 of loading: Compile the MLIR module from TF dialect to TFRT dialect
  // (in BEF).
  // TODO(b/229261464): Unify the sync and async lowering passes so we do not
  // need this branch.
  auto compile_start_time = absl::Now();
  if (options_.compile_options.compile_to_sync_tfrt_dialect) {
    ASSIGN_OR_RETURN_IN_COMPILE(
        loaded_client_graph->bef,
        tfrt::CompileTfMlirModuleToSyncBef(module.get()));
  } else {
    ASSIGN_OR_RETURN_IN_COMPILE(loaded_client_graph->bef,
                                CompileMlirModuleToBef(module.get()));
  }
  ASSIGN_OR_RETURN_IN_COMPILE(
      loaded_client_graph->bef_file,
      tfrt::CreateBefFileFromBefBuffer(runtime(), loaded_client_graph->bef));
  auto compile_duration = absl::Now() - compile_start_time;
  LOG(INFO) << "TFRT finished compiling client graph (" << &client_graph
            << "). Took " << absl::ToInt64Milliseconds(compile_duration)
            << " ms. Client graph name: " << client_graph.name;

  return loaded_client_graph;
}

StatusOr<std::unique_ptr<GraphExecutor::LoadedClientGraph>>
GraphExecutor::LoadClientGraph(
    const GraphExecutor::ClientGraph& client_graph,
    tensorflow::tfrt_stub::WorkQueueInterface* work_queue) {
  LOG(INFO) << "TFRT loading client graph (" << &client_graph << ") "
            << client_graph.name;
  TF_ASSIGN_OR_RETURN(auto loaded_client_graph,
                      ImportAndCompileClientGraph(client_graph));

  // Step 3 of loading: Initialize runtime states using special BEF functions.
  auto init_start_time = absl::Now();
  RETURN_IF_ERROR_IN_INIT(InitBef(loaded_client_graph->bef_file.get(),
                                  loaded_client_graph->resource_context.get(),
                                  work_queue));
  auto init_duration = absl::Now() - init_start_time;
  LOG(INFO) << "TFRT finished initializing client graph (" << &client_graph
            << "). Took " << absl::ToInt64Milliseconds(init_duration)
            << " ms. Client graph name: " << client_graph.name;

  return loaded_client_graph;
}

tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
GraphExecutor::ImportClientGraphToMlirModule(
    const GraphExecutor::ClientGraph& client_graph,
    mlir::MLIRContext* context) const {
  tensorflow::GraphImportConfig graph_import_config;
  graph_import_config.prune_unused_nodes = true;
  graph_import_config.enable_shape_inference = false;
  graph_import_config.inputs = client_graph.input_nodes;
  graph_import_config.outputs = client_graph.output_nodes;
  graph_import_config.control_outputs = client_graph.target_nodes;

  // Optimize the graph.
  TF_ASSIGN_OR_RETURN(
      auto optimized_graph,
      graph_execution_state_->CreateOptimizedGraph(graph_import_config));

  LOG(INFO) << "TFRT import client graph (" << &client_graph
            << "): Functionalization took "
            << absl::ToInt64Milliseconds(
                   optimized_graph.functionalization_duration)
            << " ms. Client graph name: " << client_graph.name;
  LOG(INFO) << "TFRT import client graph (" << &client_graph
            << "): Grappler took "
            << absl::ToInt64Milliseconds(optimized_graph.grappler_duration)
            << " ms. Client graph name: " << client_graph.name;

  // Convert the optimized graph to an MLIR module.
  return tensorflow::ConvertGraphToMlir(
      *optimized_graph.graph, /*debug_info=*/{},
      optimized_graph.graph->flib_def(), graph_import_config, context);
}

StatusOr<tfrt::BefBuffer> GraphExecutor::CompileMlirModuleToBef(
    mlir::ModuleOp module) const {
  tfrt::BefBuffer bef;
  TF_RETURN_IF_ERROR(
      tensorflow::ConvertTfMlirToBef(options_.compile_options, module, &bef));
  return bef;
}

tensorflow::Status GraphExecutor::InitBef(
    tfrt::BEFFile* bef_file, tfrt::ResourceContext* resource_context,
    tensorflow::tfrt_stub::WorkQueueInterface* work_queue) {
  auto* host = runtime().core_runtime()->GetHostContext();
  TF_ASSIGN_OR_RETURN(
      auto request_info,
      SetUpRequestContext(/*run_options=*/{}, /*model_metadata=*/{}, host,
                          work_queue ? work_queue : runtime().work_queue(),
                          resource_context, fallback_state_));

  tfrt::ExecutionContext exec_ctx(request_info->tfrt_request_context);

  // Run "_tfrt_fallback_init" first to initialize fallback-specific states. It
  // is a special function created by the compiler, which calls a sequence of
  // tfrt_fallback_async.createop ops to create all fallback ops used in this
  // BEF.
  TF_RETURN_IF_ERROR(
      RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_fallback_init"));

  // After we have initialized all the resources in the original graph, we can
  // run the "_tfrt_resource_init" function to set these resources in runtime
  // states, so that later they can be efficiently retrieved without any
  // locking.
  TF_RETURN_IF_ERROR(
      RunRuntimeInitializer(exec_ctx, bef_file, "_tfrt_resource_init"));

  return OkStatus();
}

StatusOr<std::reference_wrapper<const GraphExecutor::LoadedClientGraph>>
GraphExecutor::GetOrCreateLoadedClientGraph(
    absl::Span<const std::string> input_tensor_names,
    absl::Span<const tensorflow::DataType> input_tensor_dtypes,
    absl::Span<const std::string> output_tensor_names,
    absl::Span<const std::string> target_tensor_names,
    tensorflow::tfrt_stub::WorkQueueInterface* work_queue) {
  // The format of the joined name is illustrated by the following example:
  //   input1-input2^output1-output2^target1-target2
  const auto joined_name = absl::StrCat(
      absl::StrJoin(input_tensor_names, kTensorNameJoiningDelimiter),
      kArgumentTypeJoiningDelimiter,
      absl::StrJoin(output_tensor_names, kTensorNameJoiningDelimiter),
      kArgumentTypeJoiningDelimiter,
      absl::StrJoin(target_tensor_names, kTensorNameJoiningDelimiter));

  tensorflow::mutex_lock l(loaded_client_graphs_mu_);

  // Cache hit; return immediately.
  const auto iter = loaded_client_graphs_.find(joined_name);
  if (iter != loaded_client_graphs_.end()) return {*iter->second};

  // Cache miss; populate a `ClientGraph` and load it.
  tensorflow::GraphImportConfig::InputArrays input_nodes;
  DCHECK_EQ(input_tensor_names.size(), input_tensor_dtypes.size());
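  // Build the importer's input array info from the sorted names and dtypes;
  // input shapes are left with unknown rank.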
  for (int i = 0; i < input_tensor_names.size(); ++i) {
    const auto& input_name = input_tensor_names[i];
    auto input_dtype = input_tensor_dtypes[i];

    tensorflow::ArrayInfo array_info;
    array_info.imported_dtype = input_dtype;
    array_info.shape.set_unknown_rank(true);
    input_nodes[input_name] = array_info;
  }
  ClientGraph client_graph{
      joined_name,
      std::move(input_nodes),
      {output_tensor_names.begin(), output_tensor_names.end()},
      {target_tensor_names.begin(), target_tensor_names.end()}};
  TF_ASSIGN_OR_RETURN(auto loaded_client_graph,
                      LoadClientGraph(client_graph, work_queue));

  // Store the new loaded client graph in cache and return.
  const auto* loaded_client_graph_ptr = loaded_client_graph.get();
  loaded_client_graphs_[joined_name] = std::move(loaded_client_graph);
  return {*loaded_client_graph_ptr};
}

}  // namespace tfrt_stub
}  // namespace tensorflow