/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

// Wraps a device with a new name, delegating work to the wrapped device.
//
// This class is used to wrap local devices when using clusterspec propagation
// where the name of a particular device may change in the context of a given
// session.
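//
// Illustrative usage sketch (the device pointer and the base name below are
// hypothetical and not defined in this header; they only show how the factory
// declared further down might be called):
//
//   Device* local_device = ...;  // an existing local device to wrap
//   std::unique_ptr<Device> renamed = RenamedDevice::NewRenamedDevice(
//       /*new_base=*/"/job:worker/replica:0/task:1",
//       /*underlying=*/local_device,
//       /*owns_underlying=*/false,
//       /*isolate_session_state=*/true);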
class RenamedDevice : public Device {
 public:
  static std::unique_ptr<Device> NewRenamedDevice(
      const string& new_base, Device* underlying, bool owns_underlying,
      bool isolate_session_state,
      thread::ThreadPoolInterface* underlying_threadpool = nullptr);

  ~RenamedDevice() override;

  const DeviceBase* UnderlyingDevice() const override {
    return underlying_device_->UnderlyingDevice();
  }
  DeviceBase* UnderlyingDevice() override {
    return underlying_device_->UnderlyingDevice();
  }

  const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
    if (underlying_threadpool_) {
      return Device::tensorflow_cpu_worker_threads();
    }
    return underlying_device_->tensorflow_cpu_worker_threads();
  }

  const DeviceBase::AcceleratorDeviceInfo* tensorflow_accelerator_device_info()
      const override {
    return underlying_device_->tensorflow_accelerator_device_info();
  }

  Allocator* GetAllocator(AllocatorAttributes attr) override {
    return underlying_device_->GetAllocator(attr);
  }

  Allocator* GetScopedAllocator(AllocatorAttributes attr,
                                int64_t step_id) override {
    return underlying_device_->GetScopedAllocator(attr, step_id);
  }

  ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
    return underlying_device_->GetScopedAllocatorMgr();
  }

  const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
    // Use the underlying threadpool only if the underlying device supports
    // eigen_cpu_device.
    if (underlying_threadpool_ && underlying_device_->has_eigen_cpu_device()) {
      return Device::eigen_cpu_device();
    }
    return underlying_device_->eigen_cpu_device();
  }

  thread::ThreadPool* tensorflow_device_thread_pool() override {
    // Use the underlying threadpool instead of tensorflow_device_thread_pool
    // of the underlying device only if tensorflow_device_thread_pool is
    // defined for the underlying device.
    if (underlying_threadpool_ &&
        underlying_device_->tensorflow_device_thread_pool() != nullptr) {
      return Device::tensorflow_device_thread_pool();
    }
    return underlying_device_->tensorflow_device_thread_pool();
  }

  bool has_eigen_cpu_device() const override {
    return underlying_device_->has_eigen_cpu_device();
  }

  PerOpGpuDevice* MakeGpuDevice() override {
    return underlying_device_->MakeGpuDevice();
  }

  Status ReinitializeGpuDevice(OpKernelContext* context,
                               PerOpGpuDevice* device, DeviceContext* dc,
                               Allocator* allocator) override {
    return underlying_device_->ReinitializeGpuDevice(context, device, dc,
                                                     allocator);
  }

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override {
    return underlying_device_->MakeTensorFromProto(tensor_proto, alloc_attrs,
                                                   tensor);
  }

  void CopyTensorInSameDevice(const Tensor* input_tensor,
                              Tensor* output_tensor,
                              const DeviceContext* device_context,
                              StatusCallback done) override {
    underlying_device_->CopyTensorInSameDevice(
        input_tensor, output_tensor, device_context, std::move(done));
  }

  // Below are virtual methods defined on Device

  void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
    underlying_device_->Compute(op_kernel, context);
  }

  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override {
    underlying_device_->ComputeAsync(op_kernel, context, std::move(done));
  }

  Status Sync() override { return underlying_device_->Sync(); }

  Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override {
    return underlying_device_->MaybeRewriteGraph(graph);
  }

  Status TryGetDeviceContext(DeviceContext** out_context) override {
    return underlying_device_->TryGetDeviceContext(out_context);
  }

  // Returns the resource manager associated w/ this device.
  ResourceMgr* resource_manager() override {
    if (isolate_session_state_) {
      return Device::resource_manager();
    } else {
      return underlying_device_->resource_manager();
    }
  }

  bool IsLocal() const override { return underlying_device_->IsLocal(); }

  bool IsRemoteCallAllowed() const override {
    return underlying_device_->IsRemoteCallAllowed();
  }

 private:
  RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
                bool owns_underlying, bool isolate_session_state,
                thread::ThreadPoolInterface* underlying_threadpool);
  Device* const underlying_device_;
  const bool owns_underlying_device_;
  const bool isolate_session_state_;

  std::unique_ptr<thread::ThreadPool> underlying_threadpool_;
  // eigen_worker_threads_ is stored here so that we can pass the pointer
  // of eigen_worker_threads_.workers to the parent class.
  DeviceBase::CpuWorkerThreads eigen_worker_threads_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_