1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tfrt/eager/tfrt_context.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/core/common_runtime/process_util.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
24 #include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"
25 #include "tensorflow/core/tfrt/common/global_state.h"
26 #include "tensorflow/core/tfrt/eager/core_runtime/op_handler_registry.h"
27 #include "tensorflow/core/tpu/virtual_device.h"
28 #include "tensorflow/core/util/device_name_utils.h"
29 #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime
30 #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
31 #include "tfrt/host_context/host_allocator.h" // from @tf_runtime
32
33 namespace tfrt {
34 namespace tf {
35
TfrtContext(const tensorflow::SessionOptions & opts,tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,bool is_async)36 TfrtContext::TfrtContext(
37 const tensorflow::SessionOptions& opts,
38 tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
39 bool is_async) {
40 tensorflow::tfd::EagerContextResource* eager_context_resource =
41 resource_context_
42 .GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
43 tensorflow::tfd::kEagerContextResourceName, opts,
44 default_device_placement_policy, is_async);
45 auto eager_context_expected = eager_context_resource->GetTFEagerContext();
46 DCHECK(eager_context_expected) << StrCat(eager_context_expected.takeError());
47 eager_context_ = eager_context_expected.get();
48
49 eager_ctx_thread_pool_ = std::make_unique<ThreadPoolInterfaceWrapper>(
50 eager_context_->GetThreadPool()->AsEigenThreadPool());
51
52 local_thread_pool_.reset(tensorflow::NewThreadPoolFromSessionOptions(opts));
53
54 local_thread_pool_wrapper_ = std::make_unique<ThreadPoolInterfaceWrapper>(
55 local_thread_pool_->AsEigenThreadPool());
56
57 tf_thread_pool_work_queue_ =
58 std::make_unique<tensorflow::tfrt_stub::TfThreadPoolWorkQueue>(
59 /*intra_op_threadpool=*/local_thread_pool_wrapper_.get(),
60 /*inter_op_threadpool=*/eager_ctx_thread_pool_.get());
61 LOG(INFO) << "Created work queue from TF thread pool. inter op thread pool "
62 << "# threads: " << eager_ctx_thread_pool_->NumThreads()
63 << " intra op thread pool # threads: "
64 << local_thread_pool_wrapper_->NumThreads();
65
66 // Default cpu device name is "/job:localhost/replica:0/task:0/device:CPU:0".
67 const std::string& host_cpu_name = eager_context_->HostCPU()->name();
68
69 auto diag_handler = [](const DecodedDiagnostic& diag) {
70 LOG(ERROR) << diag.message;
71 };
72
73 auto rt = CoreRuntime::Create(diag_handler, CreateMallocAllocator(),
74 CreateMultiThreadedWorkQueue(
75 /*num_threads=*/4,
76 /*num_blocking_threads=*/64),
77 host_cpu_name);
78 DCHECK(rt) << StrCat(rt.takeError());
79 corert_ = std::move(rt.get());
80 host_context_ = corert_->GetHostContext();
81
82 // Create multiple (currently virtual) CPU devices according to options.
83 // TODO(b/174877837): Support multiple physical cpu devices.
84 int requested_num_cpus = 1;
85 auto iter = opts.config.device_count().find("CPU");
86 if (iter != opts.config.device_count().end()) {
87 requested_num_cpus = iter->second;
88 }
89
90 std::string cpu_name_prefix{host_cpu_name};
91 cpu_name_prefix.pop_back(); // remove the `id` from host cpu device name.
92 for (int i = 1; i < requested_num_cpus; ++i) {
93 host_context_->GetDeviceManager()->MaybeAddDevice(TakeRef(
94 new CpuDevice(absl::StrCat(cpu_name_prefix, std::to_string(i)))));
95 }
96
97 // Specifically register RuntimeFallbackOpHandler.
98 auto runtime_fallback_op_handler =
99 tensorflow::tfd::CreateRuntimeFallbackOpHandler(corert_.get(), "");
100 DCHECK(runtime_fallback_op_handler)
101 << StrCat(runtime_fallback_op_handler.takeError());
102 fallback_op_handler_ = runtime_fallback_op_handler.get();
103 corert_->RegisterOpHandler("tf", fallback_op_handler_);
104
105 RegisterOpHandlers(corert_.get(), &resource_context_,
106 eager_context_->local_device_mgr());
107
108 // Set the global host context singleton.
109 tensorflow::tfrt_global::GlobalHostContext::Set(corert_->GetHostContext());
110 }
111
HostCPUParsedName() const112 const tensorflow::DeviceNameUtils::ParsedName& TfrtContext::HostCPUParsedName()
113 const {
114 return eager_context_->HostCPU()->parsed_name();
115 }
116
IsAsync() const117 bool TfrtContext::IsAsync() const { return eager_context_->Executor().Async(); }
118
~TfrtContext()119 TfrtContext::~TfrtContext() {}
120
121 } // namespace tf
122 } // namespace tfrt
123