1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor.h"

#include <utility>

#include "absl/cleanup/cleanup.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_event.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_stream.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_timer.h"
#include "tensorflow/core/tpu/tpu_api.h"
25
26 using stream_executor::DeviceMemoryBase;
27
28 namespace tensorflow {
29 namespace tpu {
30
31 namespace {
32 using ::stream_executor::port::Status;
33 } // namespace
34
~TpuExecutor()35 TpuExecutor::~TpuExecutor() {
36 tpu::ExecutorApiFn()->TpuExecutor_FreeFn(executor_);
37 }
38
Init(int device_ordinal,::stream_executor::DeviceOptions device_options)39 Status TpuExecutor::Init(int device_ordinal,
40 ::stream_executor::DeviceOptions device_options) {
41 StatusHelper status;
42 SE_DeviceOptions* options =
43 tpu::ExecutorApiFn()->TpuExecutor_NewDeviceOptionsFn(
44 device_options.flags());
45 tpu::ExecutorApiFn()->TpuExecutor_InitFn(executor_, device_ordinal, options,
46 status.c_status);
47 tpu::ExecutorApiFn()->TpuExecutor_FreeDeviceOptionsFn(options);
48 return status.status();
49 }
50
PlatformDeviceCount()51 int TpuExecutor::PlatformDeviceCount() {
52 return tpu::ExecutorApiFn()->TpuExecutor_PlatformDeviceCountFn(executor_);
53 }
54
SyncAndForgetFailedStreams()55 void TpuExecutor::SyncAndForgetFailedStreams() {
56 tpu::ExecutorApiFn()->TpuExecutor_SyncAndForgetFailedStreamsFn(executor_);
57 }
58
SynchronizeAllActivity()59 bool TpuExecutor::SynchronizeAllActivity() {
60 return tpu::ExecutorApiFn()->TpuExecutor_SynchronizeAllActivityFn(executor_);
61 }
62
BlockHostUntilDone(Stream * stream)63 Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
64 StatusHelper status;
65 tpu::ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn(
66 executor_, get_stream(stream->implementation()), status.c_status);
67 return status.status();
68 }
69
BlockUntilDoneOrFailed()70 Status TpuExecutor::BlockUntilDoneOrFailed() {
71 StatusHelper status;
72 tpu::ExecutorApiFn()->TpuExecutor_BlockUntilDoneOrFailedFn(executor_,
73 status.c_status);
74 return status.status();
75 }
76
GetStatus(Stream * stream)77 Status TpuExecutor::GetStatus(Stream* stream) {
78 StatusHelper status;
79 tpu::ExecutorApiFn()->TpuExecutor_GetStatusFn(
80 executor_, get_stream(stream->implementation()), status.c_status);
81 return status.status();
82 }
83
GetCoreLocationExternal() const84 tpu::TpuCoreLocationExternal TpuExecutor::GetCoreLocationExternal() const {
85 return tpu::TpuCoreLocationExternal(
86 tpu::ExecutorApiFn()->TpuExecutor_GetCoreLocationFn(executor_));
87 }
88
AllocateStream(Stream * stream)89 bool TpuExecutor::AllocateStream(Stream* stream) {
90 return tpu::ExecutorApiFn()->TpuExecutor_AllocateStreamFn(
91 executor_, get_stream(stream->implementation()));
92 }
93
DeallocateStream(Stream * stream)94 void TpuExecutor::DeallocateStream(Stream* stream) {
95 tpu::ExecutorApiFn()->TpuExecutor_DeallocateStreamFn(
96 executor_, get_stream(stream->implementation()));
97 tpu_platform().mutex().Lock();
98 stream_map().erase(stream->implementation());
99 tpu_platform().mutex().Unlock();
100 }
101
CreateStreamDependency(Stream * dependent,Stream * other)102 bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
103 return tpu::ExecutorApiFn()->TpuExecutor_CreateStreamDependencyFn(
104 executor_, get_stream(dependent->implementation()),
105 get_stream(other->implementation()));
106 }
107
AllocateEvent(Event * event)108 Status TpuExecutor::AllocateEvent(Event* event) { return OkStatus(); }
109
DeallocateEvent(Event * event)110 Status TpuExecutor::DeallocateEvent(Event* event) {
111 tpu_platform().EraseEvent(event->implementation());
112 return OkStatus();
113 }
114
115 // AllocateTimer/DeallocateTimer have no specialization.
AllocateTimer(Timer * timer)116 bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
117
DeallocateTimer(Timer * timer)118 void TpuExecutor::DeallocateTimer(Timer* timer) {}
119
StartTimer(Stream * stream,::stream_executor::Timer * timer)120 bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
121 return tpu::ExecutorApiFn()->TpuExecutor_StartTimerFn(
122 executor_, get_stream(stream->implementation()),
123 timer_map_.at(timer->implementation()));
124 }
125
StopTimer(Stream * stream,::stream_executor::Timer * timer)126 bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
127 return tpu::ExecutorApiFn()->TpuExecutor_StopTimerFn(
128 executor_, get_stream(stream->implementation()),
129 timer_map_.at(timer->implementation()));
130 }
131
PollForEventStatus(stream_executor::Event * event)132 stream_executor::Event::Status TpuExecutor::PollForEventStatus(
133 stream_executor::Event* event) {
134 auto se_event = tpu_platform().LookupEvent(event->implementation());
135 return stream_executor::Event::Status(
136 tpu::ExecutorApiFn()->TpuExecutor_PollForEventStatusFn(executor_,
137 se_event));
138 }
139
RecordEvent(Stream * stream,::stream_executor::Event * event)140 Status TpuExecutor::RecordEvent(Stream* stream,
141 ::stream_executor::Event* event) {
142 StatusHelper status;
143 auto se_event = tpu_platform().LookupEvent(event->implementation());
144 tpu::ExecutorApiFn()->TpuExecutor_RecordEventFn(
145 executor_, get_stream(stream->implementation()), se_event,
146 status.c_status);
147 return status.status();
148 }
149
WaitForEvent(Stream * stream,::stream_executor::Event * event)150 Status TpuExecutor::WaitForEvent(Stream* stream,
151 ::stream_executor::Event* event) {
152 StatusHelper status;
153 auto se_event = tpu_platform().LookupEvent(event->implementation());
154 tpu::ExecutorApiFn()->TpuExecutor_WaitForEventFn(
155 executor_, get_stream(stream->implementation()), se_event,
156 status.c_status);
157 return status.status();
158 }
159
// Implementations for Timer, Stream, and Event.
// Each of these implementations must be mapped to an internal equivalent. We
// therefore allocate the internal Timer, Stream, and Event objects here and
// record the mapping from each implementation to its internal value. The
// "wrapper" interfaces are responsible for deallocating the internal value when
// they are destroyed.
165
166 // Called by Timer::Timer
167 std::unique_ptr<::stream_executor::internal::TimerInterface>
GetTimerImplementation()168 TpuExecutor::GetTimerImplementation() {
169 SE_Timer* tpu_timer = tpu::ExecutorApiFn()->TpuTimer_NewFn(executor_);
170 auto ptr = std::make_unique<TpuTimer>(tpu_timer);
171 timer_map_[ptr.get()] = tpu_timer;
172 return ptr;
173 }
174
175 // Called by Stream::Stream
176 std::unique_ptr<::stream_executor::internal::StreamInterface>
GetStreamImplementation()177 TpuExecutor::GetStreamImplementation() {
178 SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
179 auto ptr = std::make_unique<tpu::TpuStream>(tpu_stream);
180 tpu_platform().mutex().Lock();
181 stream_map()[ptr.get()] = tpu_stream;
182 tpu_platform().mutex().Unlock();
183 return ptr;
184 }
185
186 // Called by Event::Event
187 std::unique_ptr<::stream_executor::internal::EventInterface>
CreateEventImplementation()188 TpuExecutor::CreateEventImplementation() {
189 SE_Event* tpu_event = tpu::ExecutorApiFn()->TpuEvent_NewFn(executor_);
190 auto ptr = std::make_unique<TpuEvent>(tpu_event);
191 tpu_platform().InsertEvent(ptr.get(), tpu_event);
192 return ptr;
193 }
194
Allocate(uint64_t size,int64_t memory_space)195 DeviceMemoryBase TpuExecutor::Allocate(uint64_t size, int64_t memory_space) {
196 SE_DeviceMemoryBase se_base = tpu::ExecutorApiFn()->TpuExecutor_AllocateFn(
197 executor_, size, memory_space);
198 return ApiConverter::FromC(se_base);
199 }
200
Deallocate(const DeviceMemoryBase & memory)201 void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
202 SE_DeviceMemoryBase se_base = ApiConverter::ToC(memory);
203 tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
204 }
205
Deallocate(DeviceMemoryBase * memory)206 void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
207 SE_DeviceMemoryBase se_base = ApiConverter::ToC(*memory);
208 tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
209 }
210
DeviceMemoryUsage(int64_t * free,int64_t * total) const211 bool TpuExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const {
212 int64_t _free;
213 int64_t _total;
214 if (tpu::ExecutorApiFn()->TpuExecutor_DeviceMemoryUsageFn(executor_, &_free,
215 &_total)) {
216 *free = _free;
217 *total = _total;
218 return true;
219 }
220 return false;
221 }
222
223 std::optional<stream_executor::AllocatorStats>
GetAllocatorStats()224 TpuExecutor::GetAllocatorStats() {
225 SE_AllocatorStats c_stats;
226 if (tpu::ExecutorApiFn()->TpuExecutor_GetAllocatorStatsFn(executor_,
227 &c_stats)) {
228 ::stream_executor::AllocatorStats stats;
229 stats.num_allocs = c_stats.num_allocs;
230 stats.bytes_in_use = c_stats.bytes_in_use;
231 stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
232 stats.largest_alloc_size = c_stats.largest_alloc_size;
233 if (c_stats.has_bytes_limit) {
234 stats.bytes_limit = c_stats.bytes_limit;
235 }
236 stats.bytes_reserved = c_stats.bytes_reserved;
237 stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
238 if (c_stats.has_bytes_reservable_limit) {
239 stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
240 }
241 stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
242 return stats;
243 }
244 return {};
245 }
246
WaitForInfeedReady(int32_t infeed_queue_index)247 Status TpuExecutor::WaitForInfeedReady(int32_t infeed_queue_index) {
248 StatusHelper status;
249 tpu::ExecutorApiFn()->TpuExecutor_WaitForInfeedReadyFn(
250 executor_, infeed_queue_index, status.c_status);
251 return status.status();
252 }
253
WaitForOutfeedReady(int32_t outfeed_queue_index)254 Status TpuExecutor::WaitForOutfeedReady(int32_t outfeed_queue_index) {
255 StatusHelper status;
256 tpu::ExecutorApiFn()->TpuExecutor_WaitForOutfeedReadyFn(
257 executor_, outfeed_queue_index, status.c_status);
258 return status.status();
259 }
260
DequeueOutfeed(int32_t outfeed_queue_index,absl::Span<uint8> bytes,StatusCallback done)261 void TpuExecutor::DequeueOutfeed(int32_t outfeed_queue_index,
262 absl::Span<uint8> bytes, StatusCallback done) {
263 StatusHelper status;
264 tpu::ExecutorApiFn()->TpuExecutor_DequeueOutfeedFn(
265 executor_, outfeed_queue_index, bytes.data(), bytes.size(),
266 status.c_status);
267 done(status.status());
268 }
269
EnqueueInfeed(int32_t infeed_queue_index,absl::Span<const uint8> bytes)270 Status TpuExecutor::EnqueueInfeed(int32_t infeed_queue_index,
271 absl::Span<const uint8> bytes) {
272 StatusHelper status;
273 tpu::ExecutorApiFn()->TpuExecutor_EnqueueInfeedFn(
274 executor_, infeed_queue_index, bytes.data(), bytes.size(),
275 status.c_status);
276 return status.status();
277 }
278
Memcpy(Stream * stream,void * host_dst,const::stream_executor::DeviceMemoryBase & device_src,uint64_t size)279 bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
280 const ::stream_executor::DeviceMemoryBase& device_src,
281 uint64_t size) {
282 SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src);
283 return tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
284 executor_, get_stream(stream->implementation()), host_dst, &se_base,
285 size);
286 }
287
Memcpy(Stream * stream,::stream_executor::DeviceMemoryBase * device_dst,const void * host_src,uint64_t size)288 bool TpuExecutor::Memcpy(Stream* stream,
289 ::stream_executor::DeviceMemoryBase* device_dst,
290 const void* host_src, uint64_t size) {
291 SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst);
292 return tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
293 executor_, get_stream(stream->implementation()), &se_base, host_src,
294 size);
295 }
296
SynchronousMemcpy(::stream_executor::DeviceMemoryBase * device_dst,const void * host_src,uint64_t size)297 Status TpuExecutor::SynchronousMemcpy(
298 ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src,
299 uint64_t size) {
300 StatusHelper status;
301 SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst);
302 tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyFromHostFn(
303 executor_, &se_base, host_src, size, status.c_status);
304 return status.status();
305 }
306
SynchronousMemcpy(void * host_dst,const::stream_executor::DeviceMemoryBase & device_src,uint64_t size)307 Status TpuExecutor::SynchronousMemcpy(
308 void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
309 uint64_t size) {
310 StatusHelper status;
311 SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src);
312 tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyToHostFn(
313 executor_, host_dst, &se_base, size, status.c_status);
314 return status.status();
315 }
316
SynchronousMemcpyDeviceToDevice(::stream_executor::DeviceMemoryBase * device_dst,const::stream_executor::DeviceMemoryBase & device_src,uint64_t size)317 Status TpuExecutor::SynchronousMemcpyDeviceToDevice(
318 ::stream_executor::DeviceMemoryBase* device_dst,
319 const ::stream_executor::DeviceMemoryBase& device_src, uint64_t size) {
320 return ::stream_executor::port::UnimplementedError(
321 "This operation not supported on TPU");
322 }
323
MemcpyDeviceToDevice(Stream * stream,::stream_executor::DeviceMemoryBase * gpu_dst,const::stream_executor::DeviceMemoryBase & host_src,uint64_t size)324 bool TpuExecutor::MemcpyDeviceToDevice(
325 Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst,
326 const ::stream_executor::DeviceMemoryBase& host_src, uint64_t size) {
327 LOG(FATAL) << __func__ << " not supported on TpuExecutor";
328 }
329
UnloadAllPrograms()330 Status TpuExecutor::UnloadAllPrograms() {
331 StatusHelper status;
332 tpu::ExecutorApiFn()->TpuExecutor_UnloadAllProgramsFn(executor_,
333 status.c_status);
334 return status.status();
335 }
336
EnqueueCompactionOnStreamForHbm(Stream * compaction_stream)337 Status TpuExecutor::EnqueueCompactionOnStreamForHbm(Stream* compaction_stream) {
338 StatusHelper status;
339 tpu::ExecutorApiFn()->TpuExecutor_EnqueueCompactionOnStreamForHbmFn(
340 executor_, get_stream(compaction_stream->implementation()),
341 status.c_status);
342 return status.status();
343 }
344
345 struct HostCallbackContext {
346 std::function<Status()> callback;
347 };
348
HostCallbackTrampoline(void * ctx)349 TF_Status* HostCallbackTrampoline(void* ctx) {
350 HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
351 Status status = host_ctx->callback();
352 TF_Status* c_status = tpu::ExecutorApiFn()->TpuStatus_CreateFn(
353 status.code(), status.error_message().c_str());
354 delete host_ctx;
355 return c_status;
356 }
357
HostCallback(Stream * stream,std::function<Status ()> callback)358 bool TpuExecutor::HostCallback(Stream* stream,
359 std::function<Status()> callback) {
360 HostCallbackContext* ctx = new HostCallbackContext{callback};
361 return tpu::ExecutorApiFn()->TpuExecutor_HostCallbackFn(
362 executor_, get_stream(stream->implementation()), &HostCallbackTrampoline,
363 ctx);
364 }
365
366 TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
CreateDeviceDescription() const367 TpuExecutor::CreateDeviceDescription() const {
368 StatusHelper status;
369 SE_DeviceDescription* description =
370 tpu::ExecutorApiFn()->TpuDeviceDescription_NewFn();
371 absl::Cleanup cleanup = [description]() {
372 tpu::ExecutorApiFn()->TpuDeviceDescription_FreeFn(description);
373 };
374 tpu::ExecutorApiFn()->TpuExecutor_CreateDeviceDescriptionFn(
375 executor_, description, status.c_status);
376 if (status.status().ok()) {
377 stream_executor::internal::DeviceDescriptionBuilder builder;
378 CHECK_NE(description->device_vendor, nullptr);
379 builder.set_device_vendor(description->device_vendor);
380 builder.set_name(description->name);
381 builder.set_clock_rate_ghz(description->clock_rate_ghz);
382 builder.set_core_count(description->core_count);
383 builder.set_ecc_enabled(description->ecc_enabled);
384 builder.set_device_memory_size(description->device_memory_size);
385 builder.set_platform_version(description->platform_version);
386 return builder.Build();
387 }
388 return status.status();
389 }
390
391 } // namespace tpu
392 } // namespace tensorflow
393