1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
16
17 #include "tensorflow/core/common_runtime/device_mgr.h"
18 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
19 #include "tensorflow/core/platform/status.h"
20 #include "tensorflow/core/public/session_options.h"
21 #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime
22 #include "tfrt/support/error_util.h" // from @tf_runtime
23
24 namespace tensorflow {
25 namespace tfd {
26
InitEagerContext(DynamicDeviceMgr * device_mgr,const SessionOptions & session_opts,ContextDevicePlacementPolicy default_device_placement_policy,bool is_async)27 tfrt::Expected<OwnedEagerContext> InitEagerContext(
28 DynamicDeviceMgr* device_mgr, const SessionOptions& session_opts,
29 ContextDevicePlacementPolicy default_device_placement_policy,
30 bool is_async) {
31 // Copied from TFE_NewContext.
32 std::vector<std::unique_ptr<tensorflow::Device>> devices;
33 tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(
34 session_opts, "/job:localhost/replica:0/task:0", &devices);
35 if (!status.ok()) {
36 return tfrt::MakeStringError(tfrt::StrCat(status.error_message()));
37 }
38
39 if (device_mgr != nullptr) {
40 Status s = device_mgr->AddDevices(std::move(devices));
41 DCHECK(s.ok()) << "Failed to initialize device manager.";
42 auto r = new tensorflow::IntraProcessRendezvous(device_mgr);
43
44 OwnedEagerContext owned_eager_context{new tensorflow::EagerContext(
45 session_opts, default_device_placement_policy, is_async, device_mgr,
46 /*device_mgr_owned=*/false, r)};
47
48 #if !defined(IS_MOBILE_PLATFORM)
49 owned_eager_context->SetDistributedManager(
50 std::make_unique<tensorflow::EagerContextDistributedManager>(
51 owned_eager_context.get()));
52 #endif
53
54 return std::move(owned_eager_context);
55 }
56
57 auto owned_device_mgr =
58 std::make_unique<tensorflow::StaticDeviceMgr>(std::move(devices));
59 auto r = new tensorflow::IntraProcessRendezvous(owned_device_mgr.get());
60
61 OwnedEagerContext owned_eager_context{new tensorflow::EagerContext(
62 session_opts, default_device_placement_policy, is_async,
63 owned_device_mgr.release(), /*device_mgr_owned=*/true, r)};
64
65 #if !defined(IS_MOBILE_PLATFORM)
66 owned_eager_context->SetDistributedManager(
67 std::make_unique<tensorflow::EagerContextDistributedManager>(
68 owned_eager_context.get()));
69 #endif
70
71 return std::move(owned_eager_context);
72 }
73
InitEagerContext()74 tfrt::Expected<OwnedEagerContext> InitEagerContext() {
75 tensorflow::SessionOptions session_opts;
76 return InitEagerContext(
77 /*device_mgr=*/nullptr, session_opts,
78 tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
79 /*is_async=*/false);
80 }
81
GetEagerContext(tfrt::ExecutionContext exec_ctx)82 tfrt::Expected<EagerContext*> GetEagerContext(tfrt::ExecutionContext exec_ctx) {
83 tfrt::ResourceContext* resource_context = exec_ctx.resource_context();
84 tensorflow::tfd::EagerContextResource* eager_context_resource =
85 resource_context
86 ->GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
87 tensorflow::tfd::kEagerContextResourceName);
88 return eager_context_resource->GetTFEagerContext();
89 }
90
GetFallbackOp(tfrt::string_view op_name,tfrt::HostContext * host)91 tfrt::Expected<tfrt::CoreRuntimeOp> GetFallbackOp(tfrt::string_view op_name,
92 tfrt::HostContext* host) {
93 auto* runtime = tfrt::CoreRuntime::GetFromHostContext(host);
94 assert(runtime);
95 // TODO(b/161993570): Cleanup this magic string constant.
96 constexpr tfrt::string_view kRuntimeFallbackOpHandlerName = "tf";
97
98 tfrt::OpHandler* op_handler = nullptr;
99 // TODO(b/165334630): Cleanup GPU macros.
100 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
101 op_handler = runtime->GetOpHandler(kRuntimeFallbackOpHandlerName);
102 #else
103 constexpr tfrt::string_view kKernelFallbackOpHandlerName = "tfkernel";
104 op_handler = runtime->GetOpHandler(kKernelFallbackOpHandlerName);
105 if (op_handler == nullptr) {
106 op_handler = runtime->GetOpHandler(kRuntimeFallbackOpHandlerName);
107 }
108 #endif
109 assert(op_handler && "fallback op_handler not found");
110
111 return runtime->MakeOp(op_name, op_handler);
112 }
113
114 } // namespace tfd
115 } // namespace tensorflow
116