xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_executor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/synchronization/mutex.h"
21 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
22 #include "tensorflow/compiler/xla/stream_executor/device_options.h"
23 #include "tensorflow/compiler/xla/stream_executor/event.h"
24 #include "tensorflow/compiler/xla/stream_executor/lib/status.h"
25 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
26 #include "tensorflow/compiler/xla/stream_executor/stream.h"
27 #include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
28 #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
29 #include "tensorflow/compiler/xla/stream_executor/temporary_device_memory.h"
30 #include "tensorflow/compiler/xla/stream_executor/timer.h"
31 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h"
32 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_interface.h"
33 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h"
34 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h"
35 #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_stream.h"
36 #include "tensorflow/core/platform/casts.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace tensorflow {
40 namespace tpu {
41 
42 class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
43  public:
44   using Status = ::stream_executor::port::Status;
45   template <typename T>
46   using StatusOr = ::stream_executor::port::StatusOr<T>;
47   using StatusCallback = std::function<void(const Status&)>;
48   using Stream = ::stream_executor::Stream;
49   using Event = ::stream_executor::Event;
50   using Timer = ::stream_executor::Timer;
51   using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
52   using StreamInterface = ::stream_executor::internal::StreamInterface;
53   using StreamExecutorInterface =
54       ::stream_executor::internal::StreamExecutorInterface;
55 
56   using TimerMap =
57       absl::flat_hash_map<stream_executor::internal::TimerInterface*,
58                           SE_Timer*>;
59 
TpuExecutor(::tensorflow::tpu::TpuPlatformInterface * platform,SE_StreamExecutor * executor)60   explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
61                        SE_StreamExecutor* executor)
62       : platform_(platform), executor_(executor) {}
63 
64   ~TpuExecutor() override;
65 
66   Status Init(int device_ordinal,
67               ::stream_executor::DeviceOptions device_options) override;
68 
69   DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
70 
71   Status AllocateEvent(Event* event) override;
72 
73   bool AllocateStream(Stream* stream) override;
74 
75   bool AllocateTimer(Timer* timer) override;
76 
77   Status BlockHostUntilDone(::stream_executor::Stream* stream) override;
78 
79   Status BlockUntilDoneOrFailed();
80 
81   StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
82   CreateDeviceDescription() const override;
83 
84   bool CreateStreamDependency(Stream* dependent, Stream* other) override;
85 
86   void DeallocateStream(Stream* stream) override;
87 
88   void Deallocate(const DeviceMemoryBase& memory);
89 
90   void Deallocate(DeviceMemoryBase* memory) override;
91 
92   Status DeallocateEvent(Event* event) override;
93 
94   void DeallocateTimer(Timer* timer) override;
95 
96   bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override;
97 
98   void DequeueOutfeed(int32_t outfeed_queue_index, absl::Span<uint8> bytes,
99                       StatusCallback done);
100 
101   Status EnqueueInfeed(int32_t infeed_queue_index,
102                        absl::Span<const uint8> bytes);
103 
104   std::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;
105 
106   tpu::TpuCoreLocationExternal GetCoreLocationExternal() const override;
107 
108   Status GetStatus(Stream* stream) override;
109 
110   std::unique_ptr<::stream_executor::internal::StreamInterface>
111   GetStreamImplementation() override;
112 
113   std::unique_ptr<::stream_executor::internal::TimerInterface>
114   GetTimerImplementation() override;
115 
116   std::unique_ptr<::stream_executor::internal::EventInterface>
117   CreateEventImplementation() override;
118 
119   bool HostCallback(Stream* stream, std::function<Status()> callback) override;
120 
121   bool Memcpy(Stream* stream, void* host_dst,
122               const ::stream_executor::DeviceMemoryBase& device_src,
123               uint64_t size) override;
124 
125   bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
126               const void* host_src, uint64_t size) override;
127 
128   bool MemcpyDeviceToDevice(Stream* stream,
129                             ::stream_executor::DeviceMemoryBase* gpu_dst,
130                             const ::stream_executor::DeviceMemoryBase& host_src,
131                             uint64_t size) override;
132 
133   void SyncAndForgetFailedStreams();
134   bool SynchronizeAllActivity() override;
135 
136   Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
137                            const void* host_src, uint64_t size) override;
138   Status SynchronousMemcpy(
139       void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
140       uint64_t size) override;
141   Status SynchronousMemcpyDeviceToDevice(
142       ::stream_executor::DeviceMemoryBase* device_dst,
143       const ::stream_executor::DeviceMemoryBase& device_src,
144       uint64_t size) override;
145 
146   int PlatformDeviceCount() override;
147 
148   Event::Status PollForEventStatus(Event* event) override;
149   Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
150   Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override;
151 
152   bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
153   bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;
154 
155   Status WaitForInfeedReady(int32_t infeed_queue_index);
156 
157   Status WaitForOutfeedReady(int32_t outfeed_queue_index);
158 
159   Status UnloadAllPrograms() override;
160 
161   Status EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) override;
162 
platform()163   const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
164     return *platform_;
165   }
166 
platform()167   ::tensorflow::tpu::TpuPlatformInterface& platform() override {
168     return *platform_;
169   }
170 
171   // TODO(henrytan): convert this to override once the base interface is changed
172   // to TpuExecutorInterface.
173   StatusOr<std::unique_ptr<
174       tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
CreateTemporaryDeviceMemory(int64_t memory_space,int64_t byte_offset,int64_t size)175   CreateTemporaryDeviceMemory(int64_t memory_space, int64_t byte_offset,
176                               int64_t size) override {
177     LOG(FATAL) << "Unimplemented.";
178   }
179 
180   // -- Unimplemented (stubbed out) methods.
181   std::unique_ptr<stream_executor::internal::KernelInterface>
CreateKernelImplementation()182   CreateKernelImplementation() override {
183     LOG(FATAL) << "Not yet implemented";
184   }
185 
GetSubBuffer(DeviceMemoryBase * parent,uint64_t offset,uint64_t size)186   void* GetSubBuffer(DeviceMemoryBase* parent, uint64_t offset,
187                      uint64_t size) override {
188     LOG(FATAL) << "not yet implemented";
189   }
MemZero(Stream * stream,DeviceMemoryBase * location,uint64_t size)190   Status MemZero(Stream* stream, DeviceMemoryBase* location,
191                  uint64_t size) override {
192     LOG(FATAL) << "not yet implemented";
193   }
Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64_t size)194   Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
195                   uint64_t size) override {
196     LOG(FATAL) << "not yet implemented";
197   }
EnablePeerAccessTo(StreamExecutorInterface * other)198   Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
199     LOG(FATAL) << "not yet implemented";
200   }
CanEnablePeerAccessTo(StreamExecutorInterface * other)201   bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
202     LOG(FATAL) << "not yet implemented";
203   }
204 
HostMemoryAllocate(uint64_t size)205   void* HostMemoryAllocate(uint64_t size) override {
206     LOG(FATAL) << "not yet implemented";
207   }
HostMemoryDeallocate(void * mem)208   void HostMemoryDeallocate(void* mem) override {
209     LOG(FATAL) << "not yet implemented";
210   }
HostMemoryRegister(void * mem,uint64_t size)211   bool HostMemoryRegister(void* mem, uint64_t size) override {
212     LOG(FATAL) << "not yet implemented";
213   }
HostMemoryUnregister(void * mem)214   bool HostMemoryUnregister(void* mem) override {
215     LOG(FATAL) << "not yet implemented";
216   }
SynchronousMemZero(DeviceMemoryBase * location,uint64_t size)217   Status SynchronousMemZero(DeviceMemoryBase* location,
218                             uint64_t size) override {
219     LOG(FATAL) << "not yet implemented";
220   }
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64_t size)221   Status SynchronousMemSet(DeviceMemoryBase* location, int value,
222                            uint64_t size) override {
223     LOG(FATAL) << "not yet implemented";
224   }
225 
se_executor()226   SE_StreamExecutor* se_executor() { return executor_; }
227 
228  private:
tpu_platform()229   TpuPlatform& tpu_platform() {
230     return *(tensorflow::down_cast<TpuPlatform*>(platform_));
231   }
232 
stream_map()233   TpuPlatform::StreamMap& stream_map() {
234     return *(tpu_platform().stream_map());
235   }
236 
get_stream(StreamInterface * ptr)237   SE_Stream* get_stream(StreamInterface* ptr) {
238     absl::MutexLock m(&tpu_platform().mutex());
239     return stream_map()[ptr];
240   }
241 
242   TimerMap timer_map_;
243   tensorflow::tpu::TpuPlatformInterface* platform_;
244   SE_StreamExecutor* executor_;
245 };
246 
247 }  // namespace tpu
248 }  // namespace tensorflow
249 
250 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
251