/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/device_options.h"
#include "tensorflow/compiler/xla/stream_executor/event.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
#include "tensorflow/compiler/xla/stream_executor/temporary_device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/timer.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_stream.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tpu {

class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
 public:
  using Status = ::stream_executor::port::Status;
  template <typename T>
  using StatusOr = ::stream_executor::port::StatusOr<T>;
  using StatusCallback = std::function<void(const Status&)>;
  using Stream = ::stream_executor::Stream;
  using Event = ::stream_executor::Event;
  using Timer = ::stream_executor::Timer;
  using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
  using StreamInterface = ::stream_executor::internal::StreamInterface;
  using StreamExecutorInterface =
      ::stream_executor::internal::StreamExecutorInterface;

  using TimerMap =
      absl::flat_hash_map<stream_executor::internal::TimerInterface*,
                          SE_Timer*>;

  explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
                       SE_StreamExecutor* executor)
      : platform_(platform), executor_(executor) {}

  ~TpuExecutor() override;

  Status Init(int device_ordinal,
              ::stream_executor::DeviceOptions device_options) override;

  DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;

  Status AllocateEvent(Event* event) override;

  bool AllocateStream(Stream* stream) override;

  bool AllocateTimer(Timer* timer) override;

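  // Blocks the calling host thread until all work previously enqueued on
  // `stream` has completed.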
  Status BlockHostUntilDone(::stream_executor::Stream* stream) override;

  Status BlockUntilDoneOrFailed();

  StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
  CreateDeviceDescription() const override;

  bool CreateStreamDependency(Stream* dependent, Stream* other) override;

  void DeallocateStream(Stream* stream) override;

  void Deallocate(const DeviceMemoryBase& memory);

  void Deallocate(DeviceMemoryBase* memory) override;

  Status DeallocateEvent(Event* event) override;

  void DeallocateTimer(Timer* timer) override;

  bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override;

  void DequeueOutfeed(int32_t outfeed_queue_index, absl::Span<uint8> bytes,
                      StatusCallback done);

  Status EnqueueInfeed(int32_t infeed_queue_index,
                       absl::Span<const uint8> bytes);

  std::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;

  tpu::TpuCoreLocationExternal GetCoreLocationExternal() const override;

  Status GetStatus(Stream* stream) override;

  std::unique_ptr<::stream_executor::internal::StreamInterface>
  GetStreamImplementation() override;

  std::unique_ptr<::stream_executor::internal::TimerInterface>
  GetTimerImplementation() override;

  std::unique_ptr<::stream_executor::internal::EventInterface>
  CreateEventImplementation() override;

  bool HostCallback(Stream* stream, std::function<Status()> callback) override;

  bool Memcpy(Stream* stream, void* host_dst,
              const ::stream_executor::DeviceMemoryBase& device_src,
              uint64_t size) override;

  bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
              const void* host_src, uint64_t size) override;

  bool MemcpyDeviceToDevice(Stream* stream,
                            ::stream_executor::DeviceMemoryBase* gpu_dst,
                            const ::stream_executor::DeviceMemoryBase& host_src,
                            uint64_t size) override;

  void SyncAndForgetFailedStreams();
  bool SynchronizeAllActivity() override;

  Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
                           const void* host_src, uint64_t size) override;
  Status SynchronousMemcpy(
      void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
      uint64_t size) override;
  Status SynchronousMemcpyDeviceToDevice(
      ::stream_executor::DeviceMemoryBase* device_dst,
      const ::stream_executor::DeviceMemoryBase& device_src,
      uint64_t size) override;

  int PlatformDeviceCount() override;

  Event::Status PollForEventStatus(Event* event) override;
  Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
  Status WaitForEvent(Stream* stream,
                      ::stream_executor::Event* event) override;

  bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
  bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;

  Status WaitForInfeedReady(int32_t infeed_queue_index);

  Status WaitForOutfeedReady(int32_t outfeed_queue_index);

  Status UnloadAllPrograms() override;

  Status EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) override;

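  // Returns the TpuPlatformInterface this executor was created with.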
  const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
    return *platform_;
  }

  ::tensorflow::tpu::TpuPlatformInterface& platform() override {
    return *platform_;
  }

  // TODO(henrytan): convert this to override once the base interface is
  // changed to TpuExecutorInterface.
  StatusOr<std::unique_ptr<
      tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
  CreateTemporaryDeviceMemory(int64_t memory_space, int64_t byte_offset,
                              int64_t size) override {
    LOG(FATAL) << "Unimplemented.";
  }

  // -- Unimplemented (stubbed out) methods.
  std::unique_ptr<stream_executor::internal::KernelInterface>
  CreateKernelImplementation() override {
    LOG(FATAL) << "Not yet implemented";
  }

  void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset,
                     uint64_t size) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status MemZero(Stream* stream, DeviceMemoryBase* location,
                 uint64_t size) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
                  uint64_t size) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
    LOG(FATAL) << "not yet implemented";
  }
  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
    LOG(FATAL) << "not yet implemented";
  }

  void* HostMemoryAllocate(uint64_t size) override {
    LOG(FATAL) << "not yet implemented";
  }
  void HostMemoryDeallocate(void* mem) override {
    LOG(FATAL) << "not yet implemented";
  }
  bool HostMemoryRegister(void* mem, uint64_t size) override {
    LOG(FATAL) << "not yet implemented";
  }
  bool HostMemoryUnregister(void* mem) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status SynchronousMemZero(DeviceMemoryBase* location,
                            uint64_t size) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                           uint64_t size) override {
    LOG(FATAL) << "not yet implemented";
  }

  SE_StreamExecutor* se_executor() { return executor_; }

 private:
  TpuPlatform& tpu_platform() {
    return *(tensorflow::down_cast<TpuPlatform*>(platform_));
  }

  TpuPlatform::StreamMap& stream_map() {
    return *(tpu_platform().stream_map());
  }

  SE_Stream* get_stream(StreamInterface* ptr) {
    absl::MutexLock m(&tpu_platform().mutex());
    return stream_map()[ptr];
  }

  TimerMap timer_map_;
  tensorflow::tpu::TpuPlatformInterface* platform_;
  SE_StreamExecutor* executor_;
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_