xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/renamed_device.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
18 
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/lib/core/threadpool_interface.h"
21 #include "tensorflow/core/util/device_name_utils.h"
22 
23 namespace tensorflow {
24 
25 // Wraps a device with a new name, delegating work to the wrapped device.
26 //
27 // This class is used to wrap local devices when using clusterspec propagation
28 // where the name of a particular device may change in the context of a given
29 // session.
30 class RenamedDevice : public Device {
31  public:
32   static std::unique_ptr<Device> NewRenamedDevice(
33       const string& new_base, Device* underlying, bool owns_underlying,
34       bool isolate_session_state,
35       thread::ThreadPoolInterface* underlying_threadpool = nullptr);
36 
37   ~RenamedDevice() override;
38 
UnderlyingDevice()39   const DeviceBase* UnderlyingDevice() const override {
40     return underlying_device_->UnderlyingDevice();
41   }
UnderlyingDevice()42   DeviceBase* UnderlyingDevice() override {
43     return underlying_device_->UnderlyingDevice();
44   }
45 
tensorflow_cpu_worker_threads()46   const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
47     if (underlying_threadpool_) {
48       return Device::tensorflow_cpu_worker_threads();
49     }
50     return underlying_device_->tensorflow_cpu_worker_threads();
51   }
52 
tensorflow_accelerator_device_info()53   const DeviceBase::AcceleratorDeviceInfo* tensorflow_accelerator_device_info()
54       const override {
55     return underlying_device_->tensorflow_accelerator_device_info();
56   }
57 
GetAllocator(AllocatorAttributes attr)58   Allocator* GetAllocator(AllocatorAttributes attr) override {
59     return underlying_device_->GetAllocator(attr);
60   }
61 
GetScopedAllocator(AllocatorAttributes attr,int64_t step_id)62   Allocator* GetScopedAllocator(AllocatorAttributes attr,
63                                 int64_t step_id) override {
64     return underlying_device_->GetScopedAllocator(attr, step_id);
65   }
66 
GetScopedAllocatorMgr()67   ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
68     return underlying_device_->GetScopedAllocatorMgr();
69   }
70 
eigen_cpu_device()71   const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
72     // Use the underlying threadpool only if the underlying device supports
73     // eigen_cpu_device.
74     if (underlying_threadpool_ && underlying_device_->has_eigen_cpu_device()) {
75       return Device::eigen_cpu_device();
76     }
77     return underlying_device_->eigen_cpu_device();
78   }
79 
tensorflow_device_thread_pool()80   thread::ThreadPool* tensorflow_device_thread_pool() override {
81     // Use the underlying threadpool instead of tensorflow_device_thread_pool
82     // of the underlying device only if tensorflow_device_thread_pool is defined
83     // for the underlying device.
84     if (underlying_threadpool_ &&
85         underlying_device_->tensorflow_device_thread_pool() != nullptr) {
86       return Device::tensorflow_device_thread_pool();
87     }
88     return underlying_device_->tensorflow_device_thread_pool();
89   }
90 
has_eigen_cpu_device()91   bool has_eigen_cpu_device() const override {
92     return underlying_device_->has_eigen_cpu_device();
93   }
94 
95 
MakeGpuDevice()96   PerOpGpuDevice* MakeGpuDevice() override {
97     return underlying_device_->MakeGpuDevice();
98   }
99 
ReinitializeGpuDevice(OpKernelContext * context,PerOpGpuDevice * device,DeviceContext * dc,Allocator * allocator)100   Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
101                                DeviceContext* dc,
102                                Allocator* allocator) override {
103     return underlying_device_->ReinitializeGpuDevice(context, device, dc,
104                                                      allocator);
105   }
106 
MakeTensorFromProto(const TensorProto & tensor_proto,const AllocatorAttributes alloc_attrs,Tensor * tensor)107   Status MakeTensorFromProto(const TensorProto& tensor_proto,
108                              const AllocatorAttributes alloc_attrs,
109                              Tensor* tensor) override {
110     return underlying_device_->MakeTensorFromProto(tensor_proto, alloc_attrs,
111                                                    tensor);
112   }
113 
CopyTensorInSameDevice(const Tensor * input_tensor,Tensor * output_tensor,const DeviceContext * device_context,StatusCallback done)114   void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
115                               const DeviceContext* device_context,
116                               StatusCallback done) override {
117     underlying_device_->CopyTensorInSameDevice(input_tensor, output_tensor,
118                                                device_context, std::move(done));
119   }
120 
121   // Below are virtual methods defined on Device
122 
Compute(OpKernel * op_kernel,OpKernelContext * context)123   void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
124     underlying_device_->Compute(op_kernel, context);
125   }
126 
ComputeAsync(AsyncOpKernel * op_kernel,OpKernelContext * context,AsyncOpKernel::DoneCallback done)127   void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
128                     AsyncOpKernel::DoneCallback done) override {
129     underlying_device_->ComputeAsync(op_kernel, context, std::move(done));
130   }
131 
Sync()132   Status Sync() override { return underlying_device_->Sync(); }
133 
MaybeRewriteGraph(std::unique_ptr<Graph> * graph)134   Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override {
135     return underlying_device_->MaybeRewriteGraph(graph);
136   }
137 
TryGetDeviceContext(DeviceContext ** out_context)138   Status TryGetDeviceContext(DeviceContext** out_context) override {
139     return underlying_device_->TryGetDeviceContext(out_context);
140   }
141 
142   // Returns the resource manager associated w/ this device.
resource_manager()143   ResourceMgr* resource_manager() override {
144     if (isolate_session_state_) {
145       return Device::resource_manager();
146     } else {
147       return underlying_device_->resource_manager();
148     }
149   }
150 
IsLocal()151   bool IsLocal() const override { return underlying_device_->IsLocal(); }
152 
IsRemoteCallAllowed()153   bool IsRemoteCallAllowed() const override {
154     return underlying_device_->IsRemoteCallAllowed();
155   }
156 
157  private:
158   RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
159                 bool owns_underlying, bool isolate_session_state,
160                 thread::ThreadPoolInterface* underlying_threadpool);
161   Device* const underlying_device_;
162   const bool owns_underlying_device_;
163   const bool isolate_session_state_;
164 
165   std::unique_ptr<thread::ThreadPool> underlying_threadpool_;
166   // eigen_worker_threads_ is stored here so that we can pass the pointer
167   // of eigen_worker_threads_.workers to the parent class.
168   DeviceBase::CpuWorkerThreads eigen_worker_threads_;
169 };
170 
171 }  // namespace tensorflow
172 
173 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
174