// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
#include "tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"

namespace tpu_driver {
namespace {

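// Expands inside the deferred-execution lambdas below: returns early (out of
// the enclosing lambda) with an ErrorEvent if `target_op_id` is not present
// in `container`. CheckHandleExists() is a member of PodTpuDriver.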
#define CHECK_EXISTS_OR_RETURN(container, target_op_id, operation_id)  \
  {                                                                    \
    auto p = CheckHandleExists(container, target_op_id, operation_id); \
    if (p != nullptr) return p;                                        \
  }

using xla::OkStatus;
using xla::Status;
using xla::WorkerThread;

const char kPodTpuDriverPrefix[] = "grpc+pod://";

class PodTpuDriver;

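// Event handed out by PodTpuDriver. It does not own the underlying per-host
// event; it is keyed by an operation id that the driver resolves to the real
// event once the corresponding deferred operation has executed.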
class PodEvent : public Event {
 public:
  explicit PodEvent(PodTpuDriver* driver, int64_t operation_id)
      : driver_(driver), operation_id_(operation_id) {}
  int64_t operation_id() const { return operation_id_; }

  xla::Status Await() override;

  std::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) override;

  void AddCallback(std::function<void(Status)> callback) override;

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
};

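// An event that has already completed, with a fixed (typically error) status.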
class ErrorEvent : public PodEvent {
 public:
  explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id,
                      Status status)
      : PodEvent(driver, operation_id), status_(std::move(status)) {}

  xla::Status Await() override { return status_; }
  std::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) override {
    return status_;
  }
  void AddCallback(std::function<void(Status)> callback) override {
    callback(status_);
  }

 private:
  Status status_;
};

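// Aggregates a set of events and becomes ready only once all of them have
// completed. The status passed to the last completion callback is the status
// reported to this event's callbacks.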
class CombinedEvent : public PodEvent {
 public:
  explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
                         std::vector<std::shared_ptr<Event>> events)
      : PodEvent(driver, operation_id), events_(std::move(events)) {
    for (auto& event : events_) {
      event->AddCallback([this](Status s) { IncrementAndCheckComplete(s); });
    }
  }

  xla::Status Await() override {
    for (auto& event : events_) {
      TF_RETURN_IF_ERROR(event->Await());
    }
    return OkStatus();
  }

  std::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) override {
    for (auto& event : events_) {
      auto start_time = absl::Now();
      auto status = event->AwaitWithTimeout(duration);
      duration -= absl::Now() - start_time;
      if (status == std::nullopt) {
        return std::nullopt;
      }
      TF_RETURN_IF_ERROR(status.value());
    }
    return OkStatus();
  }

  void AddCallback(std::function<void(Status)> callback)
      ABSL_LOCKS_EXCLUDED(mu_) override {
    bool all_events_completed = false;
    {
      absl::MutexLock l(&mu_);
      all_events_completed = events_completed_ == events_.size();
      // Register under the same lock as the completion check so that a
      // completion racing with this call cannot strand the callback.
      if (!all_events_completed) {
        callbacks_.push_back(std::move(callback));
      }
    }
    // Invoke outside the mutex, mirroring IncrementAndCheckComplete().
    if (all_events_completed) {
      callback(event_status_);
    }
  }

 private:
  void IncrementAndCheckComplete(Status s) ABSL_LOCKS_EXCLUDED(mu_) {
    std::vector<std::function<void(Status)>> callbacks;
    {
      absl::MutexLock l(&mu_);

      event_status_ = s;
      events_completed_++;
      if (events_completed_ == events_.size()) {
        // Copy callbacks to a temporary to be invoked outside the mutex.
        callbacks.assign(callbacks_.begin(), callbacks_.end());
        callbacks_.clear();
      } else {
        return;
      }
    }

    for (const auto& callback : callbacks) {
      callback(event_status_);
    }
  }

  absl::Mutex mu_;
  std::vector<std::shared_ptr<Event>> events_;
  std::vector<std::function<void(Status)>> callbacks_ ABSL_GUARDED_BY(mu_);
  int64_t events_completed_ ABSL_GUARDED_BY(mu_) = 0;
  Status event_status_;
};

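// Buffer handle returned immediately by the Allocate() calls below. The
// underlying per-host buffer is created later, when the deferred allocation
// runs; readiness is surfaced through the wrapped PodEvent.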
class PodBufferHandle : public BufferHandle {
 public:
  explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id,
                           int64_t size_in_bytes,
                           std::optional<xla::ShapeProto> shape,
                           int64_t core_id)
      : driver_(driver),
        operation_id_(operation_id),
        size_in_bytes_(size_in_bytes),
        shape_(shape),
        event_(std::make_shared<PodEvent>(driver_, operation_id_)),
        core_id_(core_id) {}

  std::shared_ptr<Event> OnReady() override { return event_; }
  int64_t size_in_bytes() override { return size_in_bytes_; }
  std::optional<xla::ShapeProto> shape() override { return shape_; }

  int64_t operation_id() const { return operation_id_; }
  int64_t core_id() const { return core_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  const int64_t size_in_bytes_;
  const std::optional<xla::ShapeProto> shape_;
  std::shared_ptr<PodEvent> event_;
  const int64_t core_id_;
};

class PodCompiledProgramHandle : public CompiledProgramHandle {
 public:
  explicit PodCompiledProgramHandle(PodTpuDriver* driver, int64_t operation_id)
      : driver_(driver),
        operation_id_(operation_id),
        event_(std::make_shared<PodEvent>(driver_, operation_id_)) {}

  std::shared_ptr<Event> OnReady() override { return event_; }

  xla::Status program_shape(xla::ProgramShapeProto* program_shape) override;

  int64_t operation_id() const { return operation_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  std::shared_ptr<PodEvent> event_;
};

class PodLoadedProgramHandle : public LoadedProgramHandle {
 public:
  explicit PodLoadedProgramHandle(PodTpuDriver* driver, int64_t operation_id,
                                  int64_t core_id)
      : driver_(driver),
        operation_id_(operation_id),
        core_id_(core_id),
        event_(std::make_shared<PodEvent>(driver_, operation_id_)) {}

  std::shared_ptr<Event> OnReady() override { return event_; }

  int64_t operation_id() const { return operation_id_; }
  int64_t core_id() const { return core_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  const int64_t core_id_;
  std::shared_ptr<PodEvent> event_;
};

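// Bookkeeping record for an operation that has been scheduled but whose
// dependencies may not all have completed. Once incomplete_deps drains,
// create_fn is invoked to produce underlying_event, and any buffered
// callbacks are transferred onto it.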
struct EventInFlight {
  EventInFlight()
      : underlying_event(nullptr),
        create_fn(nullptr),
        incomplete_deps(),
        callbacks() {}

  std::shared_ptr<Event> underlying_event;
  std::function<std::shared_ptr<Event>(void)> create_fn;

  absl::flat_hash_set<int64_t> incomplete_deps;
  std::vector<std::function<void(Status)>> callbacks;
};

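// A composite driver that fans requests out to one gRPC TpuDriver per pod
// worker and presents the result to clients as a single host with many
// cores. The worker string is the pod prefix followed by a comma-separated
// worker list, e.g. (hypothetical addresses)
// "grpc+pod://10.0.0.1:8470,10.0.0.2:8470".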
class PodTpuDriver : public TpuDriver {
 public:
  explicit PodTpuDriver(const TpuDriverConfig& config,
                        std::shared_ptr<::grpc::ChannelCredentials> creds)
      : config_(config),
        creds_(creds),
        event_thread_(tensorflow::Env::Default(), "grpc_pod_event_thread") {
    std::vector<std::string> workers = absl::StrSplit(
        absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ',');

    int worker_count = 0;

    // Flag for environments where the local core count equals the total core
    // count of the TPU system, which means we are connecting to separate TPU
    // systems or running in a test environment.
    bool in_local_core_environment = false;

    for (const auto& worker : workers) {
      TpuDriverConfig worker_config(config_);
      *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker);
      auto tpu_driver = CreateGrpcTpuDriver(worker_config, creds_).value();

      SystemInfo driver_info;
      tpu_driver->QuerySystemInfo(&driver_info);

      if (driver_info.core_count() == driver_info.local_core_size()) {
        drivers_.insert({worker_count, std::move(tpu_driver)});
        in_local_core_environment = true;
      } else {
        drivers_.insert({driver_info.host_id(), std::move(tpu_driver)});
      }

      worker_count++;
    }

    absl::flat_hash_set<std::tuple<int, int, int>> processed_chips;

    for (int driver_num = 0; driver_num < workers.size(); ++driver_num) {
      SystemInfo driver_info;
      drivers_[driver_num]->QuerySystemInfo(&driver_info);

      for (const auto& tpu_chip : driver_info.tpu_chip()) {
        std::tuple<int, int, int> coord{tpu_chip.chip_coord().x(),
                                        tpu_chip.chip_coord().y(),
                                        tpu_chip.chip_coord().z()};
        // We only want to add chips that we have not seen before if we are in
        // a TPU pod slice, or we are only seeing local cores (e.g. we are
        // connected to individual TPUs or we are in a test environment).
        if (!processed_chips.contains(coord) ||
            driver_info.core_count() == driver_info.local_core_size()) {
          *(pod_info_.add_tpu_chip()) = tpu_chip;
          processed_chips.insert(coord);
        }
      }

      *(pod_info_.mutable_cpu()) = driver_info.cpu();
    }

    // Process all the unique chips that we have seen.
    int core_count = 0;
    for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) {
      for (auto& tpu_core : *tpu_chip.mutable_core()) {
        int current_core = tpu_core.id();
        if (in_local_core_environment) {
          current_core = core_count;
        }

        core_to_driver_.insert(
            {current_core, drivers_[tpu_chip.host_id()].get()});
        core_to_driver_id_.insert({current_core, tpu_chip.host_id()});
        core_to_driver_core_.insert({current_core, tpu_core.id()});

        tpu_core.set_id(current_core);
        tpu_core.set_core_on_host_index(current_core);
        *(pod_info_.add_local_core()) = tpu_core;

        core_count++;
      }

      // We are setting host_id to zero because we want this to look like one
      // host with many cores from the perspective of tpu_client.cc.
      tpu_chip.set_host_id(0);
    }

    pod_info_.set_chip_count(pod_info_.tpu_chip_size());
    pod_info_.set_core_count(pod_info_.local_core_size());

    // We want this to look like one host with many TPU chips/cores connected.
    pod_info_.set_host_count(1);
    pod_info_.set_host_id(0);
  }

  ~PodTpuDriver() override {
    // TODO(frankchn): Unload all handles, and wait for all events to finish.
  }

  void QuerySystemInfo(SystemInfo* system_info) override {
    *system_info = pod_info_;
  }

  xla::Status Reset() override {
    for (auto& driver : drivers_) {
      TF_RETURN_IF_ERROR(driver.second->Reset());
    }
    return OkStatus();
  }

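  // Each operation below follows the same deferred-execution pattern: obtain
  // a fresh operation id, collect the operation ids of everything it must
  // wait on, hand ScheduleRequest() a lambda that performs the real work on
  // the per-host driver, and immediately return a Pod* handle keyed by the
  // new operation id.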
  std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, int64_t num_bytes,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    ScheduleRequest(
        operation_id,
        [this, core_id, region, num_bytes, operation_id]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              underlying_buffers_.insert(
                  {operation_id,
                   core_to_driver_[core_id]->Allocate(
                       core_to_driver_core_[core_id], region, num_bytes, {})});
              return underlying_buffers_[operation_id]->OnReady();
            },
        deps);

    return std::make_unique<PodBufferHandle>(this, operation_id, num_bytes,
                                             std::nullopt, core_id);
  }

  std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    ScheduleRequest(
        operation_id,
        [this, core_id, region, shape, operation_id]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              underlying_buffers_.insert(
                  {operation_id,
                   core_to_driver_[core_id]->Allocate(
                       core_to_driver_core_[core_id], region, shape, {})});
              return underlying_buffers_[operation_id]->OnReady();
            },
        deps);

    return std::make_unique<PodBufferHandle>(
        this, operation_id, ComputeBytesFromShape(shape), shape, core_id);
  }

  std::unique_ptr<BufferHandle> AllocateTuple(
      int32_t core_id, MemoryRegion region,
      absl::Span<BufferHandle* const> children,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    std::vector<int64_t> children_ids;
    const size_t children_ids_size = children.size();
    children_ids.reserve(children_ids_size);
    for (size_t i = 0; i < children_ids_size; ++i) {
      auto child_op_id =
          static_cast<PodBufferHandle* const>(children[i])->operation_id();
      deps.insert(child_op_id);
      children_ids.push_back(child_op_id);
    }

    ScheduleRequest(
        operation_id,
        [this, core_id, region, children_ids, operation_id]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
              std::vector<BufferHandle*> child_buffers;
              child_buffers.reserve(children_ids.size());
              for (size_t i = 0; i < children_ids.size(); ++i) {
                CHECK_EXISTS_OR_RETURN(underlying_buffers_, children_ids[i],
                                       operation_id);
                child_buffers.push_back(
                    underlying_buffers_[children_ids[i]].get());
              }

              underlying_buffers_.insert(
                  {operation_id, core_to_driver_[core_id]->AllocateTuple(
                                     core_to_driver_core_[core_id], region,
                                     child_buffers, {})});
              return underlying_buffers_[operation_id]->OnReady();
            },
        deps);

    return std::make_unique<PodBufferHandle>(this, operation_id, 0,
                                             std::nullopt, core_id);
  }

  std::shared_ptr<Event> Deallocate(
      std::unique_ptr<BufferHandle> handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    deps.insert(static_cast<PodBufferHandle*>(handle.get())->operation_id());

    auto op_id = static_cast<PodBufferHandle*>(handle.get())->operation_id();
    auto core_id = static_cast<PodBufferHandle*>(handle.get())->core_id();

    ScheduleRequest(
        operation_id,
        [this, operation_id, op_id, core_id]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
              CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);

              auto buf_iter = underlying_buffers_.find(op_id);
              auto underlying_hn = std::move(buf_iter->second);
              underlying_buffers_.erase(buf_iter);

              return core_to_driver_[core_id]->Deallocate(
                  std::move(underlying_hn), {});
            },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> TransferToDevice(
      const void* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());

    auto op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
    auto core_id = static_cast<PodBufferHandle*>(dst)->core_id();

    ScheduleRequest(
        operation_id,
        [this, src, operation_id, op_id, core_id]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
              CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);

              auto buf_iter = underlying_buffers_.find(op_id);
              return core_to_driver_[core_id]->TransferToDevice(
                  src, buf_iter->second.get(), {});
            },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> TransferFromDevice(
      const BufferHandle* src, void* dst,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());

    auto op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
    auto core_id = static_cast<const PodBufferHandle*>(src)->core_id();

    ScheduleRequest(
        operation_id,
        [this, dst, operation_id, op_id, core_id]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
              CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);
              auto buf_iter = underlying_buffers_.find(op_id);
              return core_to_driver_[core_id]->TransferFromDevice(
                  buf_iter->second.get(), dst, {});
            },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

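  // Device-to-device copies stay within a single per-host driver when both
  // buffers live on the same host; otherwise the data is bounced through a
  // temporary staging buffer in this process's memory.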
  std::shared_ptr<Event> TransferFromDeviceToDevice(
      const BufferHandle* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) override {
    auto src_core_id = static_cast<const PodBufferHandle*>(src)->core_id();
    auto dst_core_id = static_cast<PodBufferHandle*>(dst)->core_id();

    auto src_driver_id = core_to_driver_id_[src_core_id];
    auto dst_driver_id = core_to_driver_id_[dst_core_id];

    if (src_driver_id == dst_driver_id) {
      // Both buffers live on the same host, so we can schedule the transfer
      // normally.
      int64_t operation_id = GetOperationId();
      auto deps = GetDependencyOperationIds(wait_for);
      deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
      deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());

      auto src_op_id =
          static_cast<const PodBufferHandle*>(src)->operation_id();
      auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id();

      ScheduleRequest(
          operation_id,
          [this, operation_id, src_op_id, dst_op_id, dst_core_id]()
              ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
                CHECK_EXISTS_OR_RETURN(underlying_buffers_, src_op_id,
                                       operation_id);
                CHECK_EXISTS_OR_RETURN(underlying_buffers_, dst_op_id,
                                       operation_id);

                auto src_iter = underlying_buffers_.find(src_op_id);
                auto dst_iter = underlying_buffers_.find(dst_op_id);
                return core_to_driver_[dst_core_id]
                    ->TransferFromDeviceToDevice(src_iter->second.get(),
                                                 dst_iter->second.get(), {});
              },
          deps);
      return std::make_shared<PodEvent>(this, operation_id);
    } else {
      // src and dst are on different hosts, so we have to bounce the data
      // through this process's host memory.
      auto dst_size = dst->size_in_bytes();
      char* host_buf = new char[dst_size];

      auto src_event = TransferFromDevice(src, host_buf, wait_for);
      auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()});
      // Capture src_event to keep it alive until the copy completes, then
      // free the staging buffer.
      dst_event->AddCallback(
          [src_event, host_buf](xla::Status status) { delete[] host_buf; });
      return dst_event;
    }
  }

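  // Compilation is fanned out to every per-host driver; the returned handle
  // becomes ready only when all per-host compilations have finished.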
  std::unique_ptr<CompiledProgramHandle> CompileProgram(
      const xla::HloProto& source, int32_t num_replicas,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);

    ScheduleRequest(
        operation_id,
        [this, operation_id, source,
         num_replicas]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          auto cph_iterator =
              underlying_cph_
                  .insert(
                      {operation_id,
                       std::vector<std::unique_ptr<CompiledProgramHandle>>()})
                  .first;

          std::vector<std::shared_ptr<Event>> collected_events;
          for (int i = 0; i < drivers_.size(); ++i) {
            auto current_cph =
                drivers_[i]->CompileProgram(source, num_replicas, {});
            cph_iterator->second.push_back(std::move(current_cph));
            collected_events.push_back(cph_iterator->second[i]->OnReady());
          }
          return std::make_shared<CombinedEvent>(this, operation_id,
                                                 collected_events);
        },
        deps);

    return std::make_unique<PodCompiledProgramHandle>(this, operation_id);
  }

  std::unique_ptr<LoadedProgramHandle> LoadProgram(
      int32_t core_id, const CompiledProgramHandle* handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    deps.insert(
        static_cast<const PodCompiledProgramHandle*>(handle)->operation_id());
    auto cph_op_id =
        static_cast<const PodCompiledProgramHandle*>(handle)->operation_id();

    ScheduleRequest(
        operation_id,
        [this, operation_id, cph_op_id, core_id]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
              CHECK_EXISTS_OR_RETURN(underlying_cph_, cph_op_id, operation_id);
              auto cph_iter = underlying_cph_.find(cph_op_id);

              underlying_lph_.insert(
                  {operation_id,
                   core_to_driver_[core_id]->LoadProgram(
                       core_to_driver_core_[core_id],
                       cph_iter->second[core_to_driver_id_[core_id]].get(),
                       {})});

              return underlying_lph_[operation_id]->OnReady();
            },
        deps);

    return std::make_unique<PodLoadedProgramHandle>(this, operation_id,
                                                    core_id);
  }

  std::shared_ptr<Event> UnloadProgram(
      std::unique_ptr<LoadedProgramHandle> handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    deps.insert(
        static_cast<PodLoadedProgramHandle*>(handle.get())->operation_id());
    auto op_id =
        static_cast<PodLoadedProgramHandle*>(handle.get())->operation_id();
    auto core_id =
        static_cast<PodLoadedProgramHandle*>(handle.get())->core_id();

    ScheduleRequest(
        operation_id,
        [this, operation_id, op_id, core_id]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
              CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id);
              auto lph_iter = underlying_lph_.find(op_id);
              auto event = core_to_driver_[core_id]->UnloadProgram(
                  std::move(lph_iter->second), {});
              underlying_lph_.erase(lph_iter);

              return event;
            },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> ExecuteProgram(
      LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
      absl::Span<BufferHandle* const> outputs,
      const xla::DeviceAssignmentProto& device_assignment,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();

    auto deps = GetDependencyOperationIds(wait_for);
    deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id());

    auto op_id = static_cast<PodLoadedProgramHandle*>(program)->operation_id();
    auto core_id = static_cast<PodLoadedProgramHandle*>(program)->core_id();

    std::vector<int64_t> input_op_ids;
    std::vector<int64_t> output_op_ids;
    input_op_ids.reserve(inputs.size());
    output_op_ids.reserve(outputs.size());

    for (auto* input : inputs) {
      auto input_dep =
          static_cast<PodBufferHandle* const>(input)->operation_id();
      input_op_ids.push_back(input_dep);
      deps.insert(input_dep);
    }
    for (auto* output : outputs) {
      auto output_dep =
          static_cast<PodBufferHandle* const>(output)->operation_id();
      output_op_ids.push_back(output_dep);
      deps.insert(output_dep);
    }

    ScheduleRequest(
        operation_id,
        [this, operation_id, core_id, op_id, input_op_ids, output_op_ids,
         device_assignment]()
            ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
              std::vector<BufferHandle*> underlying_inputs;
              std::vector<BufferHandle*> underlying_outputs;

              underlying_inputs.reserve(input_op_ids.size());
              for (auto input_op_id : input_op_ids) {
                CHECK_EXISTS_OR_RETURN(underlying_buffers_, input_op_id,
                                       operation_id);
                underlying_inputs.push_back(
                    underlying_buffers_[input_op_id].get());
              }
              underlying_outputs.reserve(output_op_ids.size());
              for (auto output_op_id : output_op_ids) {
                CHECK_EXISTS_OR_RETURN(underlying_buffers_, output_op_id,
                                       operation_id);
                underlying_outputs.push_back(
                    underlying_buffers_[output_op_id].get());
              }

              CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id);
              LoadedProgramHandle* handle = underlying_lph_[op_id].get();
              return core_to_driver_[core_id]->ExecuteProgram(
                  handle, underlying_inputs, underlying_outputs,
                  device_assignment, {});
            },
        deps);

    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::unique_ptr<TpuLinearizer> GetLinearizer() override {
    return drivers_[0]->GetLinearizer();
  }

  // Helper methods for Event scheduling.

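  // Blocks until the event either completes or acquires an underlying event,
  // then waits on that event without holding the lock. Returns std::nullopt
  // on timeout, and any recorded abnormal status once the event has been
  // erased from events_.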
  std::optional<Status> WaitForEvent(int64_t event_id, absl::Duration duration)
      ABSL_LOCKS_EXCLUDED(mu_) {
    std::shared_ptr<Event> underlying_event;

    {
      absl::MutexLock l(&mu_);
      auto event = events_.find(event_id);

      if (event == events_.end()) {
        auto event_status = abnormal_event_status_.find(event_id);
        if (event_status == abnormal_event_status_.end()) {
          return OkStatus();
        } else {
          return event_status->second;
        }
      }

      auto done = [this, event_id]() {
        mu_.AssertHeld();
        // The event was either completed and erased from the map or we have
        // an underlying event available to us.
        return events_.count(event_id) == 0 ||
               (events_[event_id]->underlying_event != nullptr &&
                events_[event_id]->underlying_event.use_count() != 0);
      };

      auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration);
      if (!status) {
        return std::nullopt;
      }

      if (events_.count(event_id) > 0) {
        underlying_event = events_[event_id]->underlying_event;
      } else {
        underlying_event = nullptr;
      }
    }

    // Wait for the underlying event without holding mu_, or else incoming
    // events will not be processed.
    if (underlying_event != nullptr) {
      return underlying_event->AwaitWithTimeout(duration);
    } else {
      absl::MutexLock l(&mu_);
      auto event_status = abnormal_event_status_.find(event_id);
      if (event_status == abnormal_event_status_.end()) {
        return OkStatus();
      } else {
        return event_status->second;
      }
    }
  }

  void AddCallbackForEvent(int64_t event_id, std::function<void(Status)> fn)
      ABSL_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock l(&mu_);
    auto event = events_.find(event_id);

    if (event == events_.end()) {
      auto event_status = abnormal_event_status_.find(event_id);
      if (event_status == abnormal_event_status_.end()) {
        fn(OkStatus());
      } else {
        fn(event_status->second);
      }
    } else {
      if (event->second->underlying_event != nullptr &&
          event->second->underlying_event.use_count() != 0) {
        event->second->underlying_event->AddCallback(fn);
      } else {
        event->second->callbacks.push_back(std::move(fn));
      }
    }
  }

  xla::Status GetCompiledProgramShape(int64_t op_id,
                                      xla::ProgramShapeProto* program_shape)
      ABSL_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock l(&mu_);

    auto done = [this, op_id]() {
      mu_.AssertHeld();
      return underlying_cph_.contains(op_id);
    };
    mu_.Await(absl::Condition(&done));

    return underlying_cph_[op_id][0]->program_shape(program_shape);
  }

 private:
  const TpuDriverConfig& config_;
  std::shared_ptr<::grpc::ChannelCredentials> creds_;

  absl::flat_hash_map<int32_t, std::unique_ptr<TpuDriver>> drivers_;
  absl::flat_hash_map<int32_t, int32_t> core_to_driver_id_;
  absl::flat_hash_map<int32_t, TpuDriver*> core_to_driver_;
  absl::flat_hash_map<int32_t, int32_t> core_to_driver_core_;
  SystemInfo pod_info_;

  absl::Mutex mu_;

  absl::flat_hash_map<int64_t, std::unique_ptr<BufferHandle>>
      underlying_buffers_ ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t,
                      std::vector<std::unique_ptr<CompiledProgramHandle>>>
      underlying_cph_ ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t, std::unique_ptr<LoadedProgramHandle>>
      underlying_lph_ ABSL_GUARDED_BY(mu_);

  absl::btree_map<int64_t, std::unique_ptr<EventInFlight>> events_
      ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t, Status> abnormal_event_status_
      ABSL_GUARDED_BY(mu_);

  std::atomic<int64_t> operation_id_counter_{0};

  WorkerThread event_thread_;

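  // Operation ids give a single global ordering of operations across all
  // cores and hosts. events_ is a btree_map, so iteration in EventCompleted()
  // visits pending operations in that order.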
  int64_t GetOperationId() { return operation_id_counter_++; }

  absl::flat_hash_set<int64_t> GetDependencyOperationIds(
      absl::Span<Event* const> wait_for) {
    absl::flat_hash_set<int64_t> deps;
    for (auto* event : wait_for) {
      deps.insert(static_cast<PodEvent* const>(event)->operation_id());
    }
    return deps;
  }

  // EventCompleted is executed on the event_thread_ worker thread. We want
  // to propagate the fact that the event is completed to any subsequent
  // events that might depend on this event.
  void EventCompleted(int64_t event_id, Status status)
      ABSL_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock l(&mu_);

    absl::btree_map<int64_t, std::unique_ptr<EventInFlight>>::iterator
        curr_event;
    if (!status.ok()) abnormal_event_status_.insert({event_id, status});
    curr_event = events_.find(event_id);

    DCHECK(curr_event->second->callbacks.empty());
    DCHECK(curr_event->second->incomplete_deps.empty());

    for (auto& event : events_) {
      event.second->incomplete_deps.erase(event_id);
      // Fire the deferred operation once both conditions hold:
      //   - all of its dependencies have completed (incomplete_deps is
      //     empty), and
      //   - the op creating its event has not run yet (create_fn is still
      //     non-null).
      // create_fn creates the underlying event and receives any buffered
      // callbacks; it is then reset to nullptr to mark that it has run.
      if (event.second->incomplete_deps.empty() &&
          event.second->create_fn != nullptr) {
        // We were the last unfilled dependency; all other dependencies are
        // filled. We can now fire the create function.
        event.second->underlying_event = event.second->create_fn();
        for (auto& fn : event.second->callbacks) {
          event.second->underlying_event->AddCallback(std::move(fn));
        }
        event.second->callbacks.clear();
        event.second->create_fn = nullptr;
      }
    }

    // We erase the current event to signal that it has finished.
    events_.erase(curr_event);
  }

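  // Registers `fn` to run under `operation_id`. If every dependency has
  // already completed (i.e. is no longer present in events_), `fn` runs
  // immediately; otherwise it is stored as create_fn and fired by
  // EventCompleted() when the last incomplete dependency finishes. Either
  // way, completion is reported back via event_thread_.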
  void ScheduleRequest(int64_t operation_id,
                       std::function<std::shared_ptr<Event>(void)> fn,
                       const absl::flat_hash_set<int64_t>& deps)
      ABSL_LOCKS_EXCLUDED(mu_) {
    absl::MutexLock l(&mu_);
    absl::btree_map<int64_t, std::unique_ptr<EventInFlight>>::iterator event;
    absl::flat_hash_set<int64_t> incomplete_deps;

    event =
        events_.insert({operation_id, std::make_unique<EventInFlight>()})
            .first;
    for (const auto& dep : deps) {
      if (events_.count(dep) > 0) incomplete_deps.insert(dep);
    }

    if (incomplete_deps.empty()) {
      // All dependencies have been fulfilled, so we execute the request
      // immediately and add a callback to inform the event thread when it is
      // done.
      event->second->create_fn = nullptr;
      event->second->underlying_event = fn();
      event->second->underlying_event->AddCallback(
          [this, operation_id](Status status) {
            event_thread_.Schedule([this, operation_id, status]() {
              EventCompleted(operation_id, status);
            });
          });
    } else {
      // Some dependencies are not yet fulfilled. We attach the request to the
      // event, and EventCompleted() will execute it on the event_thread_
      // worker thread once all of its dependencies are fulfilled.
      event->second->create_fn = std::move(fn);
      event->second->incomplete_deps = std::move(incomplete_deps);
      event->second->callbacks.push_back([this, operation_id](Status status) {
        event_thread_.Schedule([this, operation_id, status]() {
          EventCompleted(operation_id, status);
        });
      });
    }
  }

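  // Returns nullptr when `target_op_id` exists in `container`, or an
  // ErrorEvent describing the missing handle otherwise. Used via the
  // CHECK_EXISTS_OR_RETURN macro inside the scheduled lambdas above.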
  template <typename T>
  std::shared_ptr<Event> CheckHandleExists(
      absl::flat_hash_map<int64_t, T>& container, int64_t target_op_id,
      int64_t operation_id) {
    if (container.count(target_op_id) == 0) {
      return std::make_shared<ErrorEvent>(
          this, operation_id,
          tensorflow::errors::InvalidArgument("Handle ", target_op_id,
                                              " does not exist."));
    }
    return nullptr;
  }
};

xla::Status PodEvent::Await() {
  return driver_->WaitForEvent(operation_id_, absl::InfiniteDuration())
      .value();
}

std::optional<xla::Status> PodEvent::AwaitWithTimeout(
    absl::Duration duration) {
  return driver_->WaitForEvent(operation_id_, duration);
}

void PodEvent::AddCallback(std::function<void(Status)> callback) {
  driver_->AddCallbackForEvent(operation_id_, std::move(callback));
}

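// Minimal usage sketch (the worker addresses are hypothetical):
//
//   TpuDriverConfig config;
//   config.set_worker("grpc+pod://10.0.0.1:8470,10.0.0.2:8470");
//   auto driver = CreatePodTpuDriver(
//       config, ::grpc::InsecureChannelCredentials()).value();
//
// In practice the driver is constructed through the REGISTER_TPU_DRIVER
// entry at the bottom of this file, which matches the kPodTpuDriverPrefix
// scheme.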
xla::StatusOr<std::unique_ptr<TpuDriver>> CreatePodTpuDriver(
    const TpuDriverConfig& config,
    std::shared_ptr<::grpc::ChannelCredentials> creds) {
  return std::unique_ptr<TpuDriver>(new PodTpuDriver(config, creds));
}

xla::Status PodCompiledProgramHandle::program_shape(
    xla::ProgramShapeProto* program_shape) {
  return driver_->GetCompiledProgramShape(operation_id(), program_shape);
}

}  // namespace

REGISTER_TPU_DRIVER(kPodTpuDriverPrefix,
                    [](const TpuDriverConfig& config)
                        -> xla::StatusOr<std::unique_ptr<TpuDriver>> {
                      return CreatePodTpuDriver(
                          config,
                          ::grpc::InsecureChannelCredentials());  // NOLINT
                    });

}  // namespace tpu_driver