xref: /aosp_15_r20/external/tensorflow/tensorflow/core/runtime_fallback/runtime/kernel_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/public/session_options.h"
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {

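// Builds a tensorflow::EagerContext for the local host
// ("/job:localhost/replica:0/task:0"). If `device_mgr` is non-null, the newly
// created devices are registered with it and the returned context does not
// own the device manager; otherwise a StaticDeviceMgr is created and owned by
// the context.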
tfrt::Expected<OwnedEagerContext> InitEagerContext(
    DynamicDeviceMgr* device_mgr, const SessionOptions& session_opts,
    ContextDevicePlacementPolicy default_device_placement_policy,
    bool is_async) {
  // Copied from TFE_NewContext.
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(
      session_opts, "/job:localhost/replica:0/task:0", &devices);
  if (!status.ok()) {
    return tfrt::MakeStringError(tfrt::StrCat(status.error_message()));
  }

  if (device_mgr != nullptr) {
    Status s = device_mgr->AddDevices(std::move(devices));
    DCHECK(s.ok()) << "Failed to initialize device manager.";
    auto r = new tensorflow::IntraProcessRendezvous(device_mgr);

    OwnedEagerContext owned_eager_context{new tensorflow::EagerContext(
        session_opts, default_device_placement_policy, is_async, device_mgr,
        /*device_mgr_owned=*/false, r)};

#if !defined(IS_MOBILE_PLATFORM)
    owned_eager_context->SetDistributedManager(
        std::make_unique<tensorflow::EagerContextDistributedManager>(
            owned_eager_context.get()));
#endif

    return std::move(owned_eager_context);
  }

  auto owned_device_mgr =
      std::make_unique<tensorflow::StaticDeviceMgr>(std::move(devices));
  auto r = new tensorflow::IntraProcessRendezvous(owned_device_mgr.get());

  OwnedEagerContext owned_eager_context{new tensorflow::EagerContext(
      session_opts, default_device_placement_policy, is_async,
      owned_device_mgr.release(), /*device_mgr_owned=*/true, r)};

#if !defined(IS_MOBILE_PLATFORM)
  owned_eager_context->SetDistributedManager(
      std::make_unique<tensorflow::EagerContextDistributedManager>(
          owned_eager_context.get()));
#endif

  return std::move(owned_eager_context);
}

tfrt::Expected<OwnedEagerContext> InitEagerContext() {
  tensorflow::SessionOptions session_opts;
  return InitEagerContext(
      /*device_mgr=*/nullptr, session_opts,
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      /*is_async=*/false);
}

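// A minimal usage sketch for the no-argument overload above (illustrative
// only; tfrt::Expected wraps llvm::Expected, so errors are consumed with
// takeError()):
//
//   tfrt::Expected<OwnedEagerContext> ctx = InitEagerContext();
//   if (!ctx) {
//     llvm::errs() << "failed to create EagerContext: "
//                  << llvm::toString(ctx.takeError()) << "\n";
//   }
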
tfrt::Expected<EagerContext*> GetEagerContext(tfrt::ExecutionContext exec_ctx) {
  tfrt::ResourceContext* resource_context = exec_ctx.resource_context();
  tensorflow::tfd::EagerContextResource* eager_context_resource =
      resource_context
          ->GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
              tensorflow::tfd::kEagerContextResourceName);
  return eager_context_resource->GetTFEagerContext();
}

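// A minimal usage sketch (illustrative): inside a fallback kernel that
// receives a tfrt::ExecutionContext, GetEagerContext creates the
// EagerContextResource on first use and returns the EagerContext it holds:
//
//   tfrt::Expected<EagerContext*> eager_ctx = GetEagerContext(exec_ctx);
//   if (!eager_ctx) { /* propagate eager_ctx.takeError() */ }
//   tensorflow::EagerContext* ctx = *eager_ctx;
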
tfrt::Expected<tfrt::CoreRuntimeOp> GetFallbackOp(tfrt::string_view op_name,
                                                  tfrt::HostContext* host) {
  auto* runtime = tfrt::CoreRuntime::GetFromHostContext(host);
  assert(runtime);
  // TODO(b/161993570): Cleanup this magic string constant.
  constexpr tfrt::string_view kRuntimeFallbackOpHandlerName = "tf";

  tfrt::OpHandler* op_handler = nullptr;
  // TODO(b/165334630): Cleanup GPU macros.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  op_handler = runtime->GetOpHandler(kRuntimeFallbackOpHandlerName);
#else
  constexpr tfrt::string_view kKernelFallbackOpHandlerName = "tfkernel";
  op_handler = runtime->GetOpHandler(kKernelFallbackOpHandlerName);
  if (op_handler == nullptr) {
    op_handler = runtime->GetOpHandler(kRuntimeFallbackOpHandlerName);
  }
#endif
  assert(op_handler && "fallback op_handler not found");

  return runtime->MakeOp(op_name, op_handler);
}

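// A minimal usage sketch (illustrative; "tf.Relu" is a placeholder op name,
// and the exact invocation depends on how the caller sets up its
// tfrt::OpInvocation):
//
//   tfrt::Expected<tfrt::CoreRuntimeOp> op = GetFallbackOp("tf.Relu", host);
//   if (!op) { /* propagate op.takeError() */ }
//   // Invoke *op with a tfrt::OpInvocation describing arguments, attributes,
//   // and results.
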
}  // namespace tfd
}  // namespace tensorflow