xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 // Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 
16 #include "absl/container/btree_map.h"
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_split.h"
20 #include "absl/synchronization/mutex.h"
21 #include "tensorflow/compiler/xla/pjrt/semaphore.h"
22 #include "tensorflow/compiler/xla/pjrt/worker_thread.h"
23 #include "tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h"
24 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
25 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/errors.h"
28 
29 namespace tpu_driver {
30 namespace {
31 
// Calls CheckHandleExists(container, target_op_id, operation_id) and, when the
// result is non-null, returns that result from the *enclosing* function.
// Otherwise execution continues after the macro.
//
// The body is wrapped in do { ... } while (0) so that an invocation followed
// by its semicolon is exactly one statement; this keeps the macro safe to use
// as the unbraced body of an if/else.
#define CHECK_EXISTS_OR_RETURN(container, target_op_id, operation_id)    \
  do {                                                                   \
    auto p = CheckHandleExists(container, target_op_id, operation_id);   \
    if (p != nullptr) return p;                                          \
  } while (0)
37 
38 using xla::OkStatus;
39 using xla::Status;
40 using xla::WorkerThread;
41 
42 const char kPodTpuDriverPrefix[] = "grpc+pod://";
43 
44 class PodTpuDriver;
45 
46 class PodEvent : public Event {
47  public:
PodEvent(PodTpuDriver * driver,int64_t operation_id)48   explicit PodEvent(PodTpuDriver* driver, int64_t operation_id)
49       : driver_(driver), operation_id_(operation_id) {}
operation_id() const50   int64_t operation_id() const { return operation_id_; }
51 
52   xla::Status Await() override;
53 
54   std::optional<xla::Status> AwaitWithTimeout(absl::Duration duration) override;
55 
56   void AddCallback(std::function<void(Status)> callback) override;
57 
58  private:
59   PodTpuDriver* driver_;
60   const int64_t operation_id_;
61 };
62 
63 class ErrorEvent : public PodEvent {
64  public:
ErrorEvent(PodTpuDriver * driver,int64_t operation_id,Status status)65   explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status)
66       : PodEvent(driver, operation_id) {
67     status_ = status;
68   }
69 
Await()70   xla::Status Await() override { return status_; }
AwaitWithTimeout(absl::Duration duration)71   std::optional<xla::Status> AwaitWithTimeout(
72       absl::Duration duration) override {
73     return status_;
74   }
AddCallback(std::function<void (Status)> callback)75   void AddCallback(std::function<void(Status)> callback) override {
76     callback(status_);
77   }
78 
79  private:
80   Status status_;
81 };
82 
83 class CombinedEvent : public PodEvent {
84  public:
CombinedEvent(PodTpuDriver * driver,int64_t operation_id,std::vector<std::shared_ptr<Event>> events)85   explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
86                          std::vector<std::shared_ptr<Event>> events)
87       : PodEvent(driver, operation_id), events_(events) {
88     for (auto& event : events_) {
89       event->AddCallback([this](Status s) { IncrementAndCheckComplete(s); });
90     }
91   }
92 
Await()93   xla::Status Await() override {
94     for (auto& event : events_) {
95       TF_RETURN_IF_ERROR(event->Await());
96     }
97     return OkStatus();
98   }
99 
AwaitWithTimeout(absl::Duration duration)100   std::optional<xla::Status> AwaitWithTimeout(
101       absl::Duration duration) override {
102     for (auto& event : events_) {
103       auto start_time = absl::Now();
104       auto status = event->AwaitWithTimeout(duration);
105       duration -= absl::Now() - start_time;
106       if (status == std::nullopt) {
107         return std::nullopt;
108       } else {
109         TF_RETURN_IF_ERROR(status.value());
110       }
111     }
112     return OkStatus();
113   }
114 
AddCallback(std::function<void (Status)> callback)115   void AddCallback(std::function<void(Status)> callback)
116       ABSL_LOCKS_EXCLUDED(mu_) override {
117     bool all_events_completed = false;
118     {
119       absl::MutexLock l(&mu_);
120       all_events_completed = events_completed_ == events_.size();
121     }
122     if (all_events_completed) {
123       callback(event_status_);
124     } else {
125       absl::MutexLock l(&mu_);
126       callbacks_.push_back(std::move(callback));
127     }
128   }
129 
130  private:
IncrementAndCheckComplete(Status s)131   void IncrementAndCheckComplete(Status s) ABSL_LOCKS_EXCLUDED(mu_) {
132     std::vector<std::function<void(Status)>> callbacks;
133     {
134       absl::MutexLock l(&mu_);
135 
136       event_status_ = s;
137       events_completed_++;
138       if (events_completed_ == events_.size()) {
139         // Copy callbacks to a temporary to be invoked outside the mutex.
140         callbacks.assign(callbacks_.begin(), callbacks_.end());
141         callbacks_.clear();
142       } else {
143         return;
144       }
145     }
146 
147     for (const auto& callback : callbacks) {
148       callback(event_status_);
149     }
150   }
151 
152   absl::Mutex mu_;
153   std::vector<std::shared_ptr<Event>> events_;
154   std::vector<std::function<void(Status)>> callbacks_ ABSL_GUARDED_BY(mu_);
155   int64_t events_completed_ ABSL_GUARDED_BY(mu_) = 0;
156   Status event_status_;
157 };
158 
159 class PodBufferHandle : public BufferHandle {
160  public:
PodBufferHandle(PodTpuDriver * driver,int64_t operation_id,int64_t size_in_bytes,std::optional<xla::ShapeProto> shape,int64_t core_id)161   explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id,
162                            int64_t size_in_bytes,
163                            std::optional<xla::ShapeProto> shape,
164                            int64_t core_id)
165       : driver_(driver),
166         operation_id_(operation_id),
167         size_in_bytes_(size_in_bytes),
168         shape_(shape),
169         event_(std::make_shared<PodEvent>(driver_, operation_id_)),
170         core_id_(core_id) {}
171 
OnReady()172   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()173   int64_t size_in_bytes() override { return size_in_bytes_; }
shape()174   std::optional<xla::ShapeProto> shape() override { return shape_; }
175 
operation_id() const176   int64_t operation_id() const { return operation_id_; }
core_id() const177   int64_t core_id() const { return core_id_; }
178 
179  private:
180   PodTpuDriver* driver_;
181   const int64_t operation_id_;
182   const int64_t size_in_bytes_;
183   const std::optional<xla::ShapeProto> shape_;
184   std::shared_ptr<PodEvent> event_;
185   const int64_t core_id_;
186 };
187 
188 class PodCompiledProgramHandle : public CompiledProgramHandle {
189  public:
PodCompiledProgramHandle(PodTpuDriver * driver,int64_t operation_id)190   explicit PodCompiledProgramHandle(PodTpuDriver* driver, int64_t operation_id)
191       : driver_(driver),
192         operation_id_(operation_id),
193         event_(std::make_shared<PodEvent>(driver_, operation_id_)) {}
194 
OnReady()195   std::shared_ptr<Event> OnReady() override { return event_; }
196 
197   xla::Status program_shape(xla::ProgramShapeProto* program_shape) override;
198 
operation_id() const199   int64_t operation_id() const { return operation_id_; }
200 
201  private:
202   PodTpuDriver* driver_;
203   const int64_t operation_id_;
204   std::shared_ptr<PodEvent> event_;
205 };
206 
207 class PodLoadedProgramHandle : public LoadedProgramHandle {
208  public:
PodLoadedProgramHandle(PodTpuDriver * driver,int64_t operation_id,int64_t core_id)209   explicit PodLoadedProgramHandle(PodTpuDriver* driver, int64_t operation_id,
210                                   int64_t core_id)
211       : driver_(driver),
212         operation_id_(operation_id),
213         core_id_(core_id),
214         event_(std::make_shared<PodEvent>(driver_, operation_id_)) {}
215 
OnReady()216   std::shared_ptr<Event> OnReady() override { return event_; }
217 
operation_id() const218   int64_t operation_id() const { return operation_id_; }
core_id() const219   int64_t core_id() const { return core_id_; }
220 
221  private:
222   PodTpuDriver* driver_;
223   const int64_t operation_id_;
224   const int64_t core_id_;
225   std::shared_ptr<PodEvent> event_;
226 };
227 
228 struct EventInFlight {
EventInFlighttpu_driver::__anonf243d9070111::EventInFlight229   EventInFlight()
230       : underlying_event(nullptr),
231         create_fn(nullptr),
232         incomplete_deps(),
233         callbacks() {}
234 
235   std::shared_ptr<Event> underlying_event;
236   std::function<std::shared_ptr<Event>(void)> create_fn;
237 
238   absl::flat_hash_set<int64_t> incomplete_deps;
239   std::vector<std::function<void(Status)>> callbacks;
240 };
241 
242 class PodTpuDriver : public TpuDriver {
243  public:
PodTpuDriver(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds)244   explicit PodTpuDriver(const TpuDriverConfig& config,
245                         std::shared_ptr<::grpc::ChannelCredentials> creds)
246       : config_(config),
247         creds_(creds),
248         event_thread_(tensorflow::Env::Default(), "grpc_pod_event_thread") {
249     std::vector<std::string> workers = absl::StrSplit(
250         absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ',');
251 
252     int worker_count = 0;
253 
254     // Flag for environments where local core # == all cores in TPU system #,
255     // which means that we are connecting to separate TPU systems or we are in
256     // a test environment.
257     bool in_local_core_environment = false;
258 
259     for (const auto& worker : workers) {
260       TpuDriverConfig worker_config(config_);
261       *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker);
262       auto tpu_driver = CreateGrpcTpuDriver(worker_config, creds_).value();
263 
264       SystemInfo driver_info;
265       tpu_driver->QuerySystemInfo(&driver_info);
266 
267       if (driver_info.core_count() == driver_info.local_core_size()) {
268         drivers_.insert({worker_count, std::move(tpu_driver)});
269         in_local_core_environment = true;
270       } else {
271         drivers_.insert({driver_info.host_id(), std::move(tpu_driver)});
272       }
273 
274       worker_count++;
275     }
276 
277     absl::flat_hash_set<std::tuple<int, int, int>> processed_chips;
278 
279     for (int driver_num = 0; driver_num < workers.size(); ++driver_num) {
280       SystemInfo driver_info;
281       drivers_[driver_num]->QuerySystemInfo(&driver_info);
282 
283       for (const auto& tpu_chip : driver_info.tpu_chip()) {
284         std::tuple<int, int, int> coord{tpu_chip.chip_coord().x(),
285                                         tpu_chip.chip_coord().y(),
286                                         tpu_chip.chip_coord().z()};
287         // We only want to add chips that we have not seen before if we are in a
288         // TPU pod slice, or we are only seeing local cores (e.g. we are
289         // connected to individual TPUs or we are in a test environment).
290         if (!processed_chips.contains(coord) ||
291             driver_info.core_count() == driver_info.local_core_size()) {
292           *(pod_info_.add_tpu_chip()) = tpu_chip;
293           processed_chips.insert(coord);
294         }
295       }
296 
297       *(pod_info_.mutable_cpu()) = driver_info.cpu();
298     }
299 
300     // Process all the unique chips that we have seen.
301     int core_count = 0;
302     for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) {
303       for (auto& tpu_core : *tpu_chip.mutable_core()) {
304         int current_core = tpu_core.id();
305         if (in_local_core_environment) {
306           current_core = core_count;
307         }
308 
309         core_to_driver_.insert(
310             {current_core, drivers_[tpu_chip.host_id()].get()});
311         core_to_driver_id_.insert({current_core, tpu_chip.host_id()});
312         core_to_driver_core_.insert({current_core, tpu_core.id()});
313 
314         tpu_core.set_id(current_core);
315         tpu_core.set_core_on_host_index(current_core);
316         *(pod_info_.add_local_core()) = tpu_core;
317 
318         core_count++;
319       }
320 
321       // We are setting host_id to zero because we want this to look like one
322       // host with many cores from the perspective of tpu_client.cc.
323       tpu_chip.set_host_id(0);
324     }
325 
326     pod_info_.set_chip_count(pod_info_.tpu_chip_size());
327     pod_info_.set_core_count(pod_info_.local_core_size());
328 
329     // We want this to look like one host with many TPU chips/cores connected.
330     pod_info_.set_host_count(1);
331     pod_info_.set_host_id(0);
332   }
333 
// Currently a no-op: nothing is unloaded or awaited on destruction.
// TODO(frankchn): Unload all handles, and wait for all events to finish.
~PodTpuDriver() override {}
337 
QuerySystemInfo(SystemInfo * system_info)338   void QuerySystemInfo(SystemInfo* system_info) override {
339     *system_info = pod_info_;
340   }
341 
Reset()342   xla::Status Reset() override {
343     for (auto& driver : drivers_) {
344       TF_RETURN_IF_ERROR(driver.second->Reset());
345     }
346     return OkStatus();
347   }
348 
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)349   std::unique_ptr<BufferHandle> Allocate(
350       int32_t core_id, MemoryRegion region, int64_t num_bytes,
351       absl::Span<Event* const> wait_for) override {
352     int64_t operation_id = GetOperationId();
353     auto deps = GetDependencyOperationIds(wait_for);
354 
355     ScheduleRequest(
356         operation_id,
357         [this, core_id, region, num_bytes, operation_id]()
358             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
359               underlying_buffers_.insert(
360                   {operation_id,
361                    core_to_driver_[core_id]->Allocate(
362                        core_to_driver_core_[core_id], region, num_bytes, {})});
363               return underlying_buffers_[operation_id]->OnReady();
364             },
365         deps);
366 
367     return std::make_unique<PodBufferHandle>(this, operation_id, num_bytes,
368                                              std::nullopt, core_id);
369   }
370 
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)371   std::unique_ptr<BufferHandle> Allocate(
372       int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
373       absl::Span<Event* const> wait_for) override {
374     int64_t operation_id = GetOperationId();
375     auto deps = GetDependencyOperationIds(wait_for);
376 
377     ScheduleRequest(
378         operation_id,
379         [this, core_id, region, shape, operation_id]()
380             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
381               underlying_buffers_.insert(
382                   {operation_id,
383                    core_to_driver_[core_id]->Allocate(
384                        core_to_driver_core_[core_id], region, shape, {})});
385               return underlying_buffers_[operation_id]->OnReady();
386             },
387         deps);
388 
389     return std::make_unique<PodBufferHandle>(
390         this, operation_id, ComputeBytesFromShape(shape), shape, core_id);
391   }
392 
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)393   std::unique_ptr<BufferHandle> AllocateTuple(
394       int32_t core_id, MemoryRegion region,
395       absl::Span<BufferHandle* const> children,
396       absl::Span<Event* const> wait_for) override {
397     int64_t operation_id = GetOperationId();
398     auto deps = GetDependencyOperationIds(wait_for);
399 
400     std::vector<int64_t> children_ids;
401     const size_t children_ids_size = children.size();
402     children_ids.reserve(children_ids_size);
403     for (size_t i = 0; i < children_ids_size; ++i) {
404       auto child_op_id =
405           static_cast<PodBufferHandle* const>(children[i])->operation_id();
406       deps.insert(child_op_id);
407       children_ids.push_back(child_op_id);
408     }
409 
410     ScheduleRequest(
411         operation_id,
412         [this, core_id, region, children_ids, operation_id]()
413             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
414               std::vector<BufferHandle*> child_buffers;
415               child_buffers.reserve(children_ids.size());
416               for (size_t i = 0; i < children_ids.size(); ++i) {
417                 CHECK_EXISTS_OR_RETURN(underlying_buffers_, children_ids[i],
418                                        operation_id);
419                 child_buffers.push_back(
420                     underlying_buffers_[children_ids[i]].get());
421               }
422 
423               underlying_buffers_.insert(
424                   {operation_id, core_to_driver_[core_id]->AllocateTuple(
425                                      core_to_driver_core_[core_id], region,
426                                      child_buffers, {})});
427               return underlying_buffers_[operation_id]->OnReady();
428             },
429         deps);
430 
431     return std::make_unique<PodBufferHandle>(this, operation_id, 0,
432                                              std::nullopt, core_id);
433   }
434 
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)435   std::shared_ptr<Event> Deallocate(
436       std::unique_ptr<BufferHandle> handle,
437       absl::Span<Event* const> wait_for) override {
438     int64_t operation_id = GetOperationId();
439     auto deps = GetDependencyOperationIds(wait_for);
440     deps.insert(static_cast<PodBufferHandle*>(handle.get())->operation_id());
441 
442     auto op_id = static_cast<PodBufferHandle*>(handle.get())->operation_id();
443     auto core_id = static_cast<PodBufferHandle*>(handle.get())->core_id();
444 
445     ScheduleRequest(
446         operation_id,
447         [this, operation_id, op_id, core_id]()
448             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
449               CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);
450 
451               auto buf_iter = underlying_buffers_.find(op_id);
452               auto underlying_hn = std::move(buf_iter->second);
453               underlying_buffers_.erase(buf_iter);
454 
455               return core_to_driver_[core_id]->Deallocate(
456                   std::move(underlying_hn), {});
457             },
458         deps);
459 
460     return std::make_shared<PodEvent>(this, operation_id);
461   }
462 
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)463   std::shared_ptr<Event> TransferToDevice(
464       const void* src, BufferHandle* dst,
465       absl::Span<Event* const> wait_for) override {
466     int64_t operation_id = GetOperationId();
467     auto deps = GetDependencyOperationIds(wait_for);
468     deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());
469 
470     auto op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
471     auto core_id = static_cast<PodBufferHandle*>(dst)->core_id();
472 
473     ScheduleRequest(
474         operation_id,
475         [this, src, operation_id, op_id, core_id]()
476             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
477               CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);
478 
479               auto buf_iter = underlying_buffers_.find(op_id);
480               return core_to_driver_[core_id]->TransferToDevice(
481                   src, buf_iter->second.get(), {});
482             },
483         deps);
484 
485     return std::make_shared<PodEvent>(this, operation_id);
486   }
487 
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)488   std::shared_ptr<Event> TransferFromDevice(
489       const BufferHandle* src, void* dst,
490       absl::Span<Event* const> wait_for) override {
491     int64_t operation_id = GetOperationId();
492     auto deps = GetDependencyOperationIds(wait_for);
493     deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
494 
495     auto op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
496     auto core_id = static_cast<const PodBufferHandle*>(src)->core_id();
497 
498     ScheduleRequest(
499         operation_id,
500         [this, dst, operation_id, op_id, core_id]()
501             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
502               CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);
503               auto buf_iter = underlying_buffers_.find(op_id);
504               return core_to_driver_[core_id]->TransferFromDevice(
505                   buf_iter->second.get(), dst, {});
506             },
507         deps);
508 
509     return std::make_shared<PodEvent>(this, operation_id);
510   }
511 
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)512   std::shared_ptr<Event> TransferFromDeviceToDevice(
513       const BufferHandle* src, BufferHandle* dst,
514       absl::Span<Event* const> wait_for) override {
515     auto src_core_id = static_cast<const PodBufferHandle*>(src)->core_id();
516     auto dst_core_id = static_cast<PodBufferHandle*>(dst)->core_id();
517 
518     auto src_driver_id = core_to_driver_id_[src_core_id];
519     auto dst_driver_id = core_to_driver_id_[dst_core_id];
520 
521     if (src_driver_id == dst_driver_id) {
522       // They are in the same host, we can schedule it normally
523       int64_t operation_id = GetOperationId();
524       auto deps = GetDependencyOperationIds(wait_for);
525       deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
526       deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());
527 
528       auto src_op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
529       auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
530 
531       ScheduleRequest(
532           operation_id,
533           [this, operation_id, src_op_id, dst_op_id, dst_core_id]()
534               ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
535                 CHECK_EXISTS_OR_RETURN(underlying_buffers_, src_op_id,
536                                        operation_id);
537                 CHECK_EXISTS_OR_RETURN(underlying_buffers_, dst_op_id,
538                                        operation_id);
539 
540                 auto src_iter = underlying_buffers_.find(src_op_id);
541                 auto dst_iter = underlying_buffers_.find(dst_op_id);
542                 return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice(
543                     src_iter->second.get(), dst_iter->second.get(), {});
544               },
545           deps);
546       return std::make_shared<PodEvent>(this, operation_id);
547     } else {
548       // src and dst are on different hosts, we have to bounce through us.
549       auto dst_size = dst->size_in_bytes();
550       char* host_buf = new char[dst_size];
551 
552       auto src_event = TransferFromDevice(src, host_buf, wait_for);
553       auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()});
554       dst_event->AddCallback(
555           [src_event, host_buf](xla::Status status) { delete[] host_buf; });
556       return dst_event;
557     }
558   }
559 
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)560   std::unique_ptr<CompiledProgramHandle> CompileProgram(
561       const xla::HloProto& source, int32_t num_replicas,
562       absl::Span<Event* const> wait_for) override {
563     int64_t operation_id = GetOperationId();
564     auto deps = GetDependencyOperationIds(wait_for);
565 
566     ScheduleRequest(
567         operation_id,
568         [this, operation_id, source,
569          num_replicas]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
570           auto cph_iterator =
571               underlying_cph_
572                   .insert(
573                       {operation_id,
574                        std::vector<std::unique_ptr<CompiledProgramHandle>>()})
575                   .first;
576 
577           std::vector<std::shared_ptr<Event>> collected_events;
578           for (int i = 0; i < drivers_.size(); ++i) {
579             auto current_cph =
580                 drivers_[i]->CompileProgram(source, num_replicas, {});
581             cph_iterator->second.push_back(std::move(current_cph));
582             collected_events.push_back(cph_iterator->second[i]->OnReady());
583           }
584           return std::make_shared<CombinedEvent>(this, operation_id,
585                                                  collected_events);
586         },
587         deps);
588 
589     return std::make_unique<PodCompiledProgramHandle>(this, operation_id);
590   }
591 
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)592   std::unique_ptr<LoadedProgramHandle> LoadProgram(
593       int32_t core_id, const CompiledProgramHandle* handle,
594       absl::Span<Event* const> wait_for) override {
595     int64_t operation_id = GetOperationId();
596     auto deps = GetDependencyOperationIds(wait_for);
597     deps.insert(
598         static_cast<const PodCompiledProgramHandle*>(handle)->operation_id());
599     auto cph_op_id =
600         static_cast<const PodCompiledProgramHandle*>(handle)->operation_id();
601 
602     ScheduleRequest(
603         operation_id,
604         [this, operation_id, cph_op_id, core_id]()
605             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
606               CHECK_EXISTS_OR_RETURN(underlying_cph_, cph_op_id, operation_id);
607               auto cph_iter = underlying_cph_.find(cph_op_id);
608 
609               underlying_lph_.insert(
610                   {operation_id,
611                    core_to_driver_[core_id]->LoadProgram(
612                        core_to_driver_core_[core_id],
613                        cph_iter->second[core_to_driver_id_[core_id]].get(),
614                        {})});
615 
616               return underlying_lph_[operation_id]->OnReady();
617             },
618         deps);
619 
620     return std::make_unique<PodLoadedProgramHandle>(this, operation_id,
621                                                     core_id);
622   }
623 
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)624   std::shared_ptr<Event> UnloadProgram(
625       std::unique_ptr<LoadedProgramHandle> handle,
626       absl::Span<Event* const> wait_for) override {
627     int64_t operation_id = GetOperationId();
628     auto deps = GetDependencyOperationIds(wait_for);
629     deps.insert(
630         static_cast<PodLoadedProgramHandle*>(handle.get())->operation_id());
631     auto op_id =
632         static_cast<PodLoadedProgramHandle*>(handle.get())->operation_id();
633     auto core_id =
634         static_cast<PodLoadedProgramHandle*>(handle.get())->core_id();
635 
636     ScheduleRequest(
637         operation_id,
638         [this, operation_id, op_id, core_id]()
639             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
640               CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id);
641               auto lph_iter = underlying_lph_.find(op_id);
642               auto event = core_to_driver_[core_id]->UnloadProgram(
643                   std::move(lph_iter->second), {});
644               underlying_lph_.erase(lph_iter);
645 
646               return event;
647             },
648         deps);
649 
650     return std::make_shared<PodEvent>(this, operation_id);
651   }
652 
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)653   std::shared_ptr<Event> ExecuteProgram(
654       LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
655       absl::Span<BufferHandle* const> outputs,
656       const xla::DeviceAssignmentProto& device_assignment,
657       absl::Span<Event* const> wait_for) override {
658     int64_t operation_id = GetOperationId();
659 
660     auto deps = GetDependencyOperationIds(wait_for);
661     deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id());
662 
663     auto op_id = static_cast<PodLoadedProgramHandle*>(program)->operation_id();
664     auto core_id = static_cast<PodLoadedProgramHandle*>(program)->core_id();
665 
666     std::vector<int64_t> input_op_ids;
667     std::vector<int64_t> output_op_ids;
668     input_op_ids.reserve(inputs.size());
669     output_op_ids.reserve(outputs.size());
670 
671     for (auto* input : inputs) {
672       auto input_dep =
673           static_cast<PodBufferHandle* const>(input)->operation_id();
674       input_op_ids.push_back(input_dep);
675       deps.insert(input_dep);
676     }
677     for (auto* output : outputs) {
678       auto output_dep =
679           static_cast<PodBufferHandle* const>(output)->operation_id();
680       output_op_ids.push_back(output_dep);
681       deps.insert(output_dep);
682     }
683 
684     ScheduleRequest(
685         operation_id,
686         [this, operation_id, core_id, op_id, input_op_ids, output_op_ids,
687          device_assignment]()
688             ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
689               std::vector<BufferHandle*> underlying_inputs;
690               std::vector<BufferHandle*> underlying_outputs;
691 
692               underlying_inputs.reserve(input_op_ids.size());
693               for (auto input_op_id : input_op_ids) {
694                 CHECK_EXISTS_OR_RETURN(underlying_buffers_, input_op_id,
695                                        operation_id);
696                 underlying_inputs.push_back(
697                     underlying_buffers_[input_op_id].get());
698               }
699               underlying_outputs.reserve(output_op_ids.size());
700               for (auto output_op_id : output_op_ids) {
701                 CHECK_EXISTS_OR_RETURN(underlying_buffers_, output_op_id,
702                                        operation_id);
703                 underlying_outputs.push_back(
704                     underlying_buffers_[output_op_id].get());
705               }
706 
707               CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id);
708               LoadedProgramHandle* handle = underlying_lph_[op_id].get();
709               return core_to_driver_[core_id]->ExecuteProgram(
710                   handle, underlying_inputs, underlying_outputs,
711                   device_assignment, {});
712             },
713         deps);
714 
715     return std::make_shared<PodEvent>(this, operation_id);
716   }
717 
GetLinearizer()718   std::unique_ptr<TpuLinearizer> GetLinearizer() override {
719     return drivers_[0]->GetLinearizer();
720   }
721 
722   // Helper methods for Event scheduling
723 
WaitForEvent(int64_t event_id,absl::Duration duration)724   std::optional<Status> WaitForEvent(int64_t event_id, absl::Duration duration)
725       ABSL_LOCKS_EXCLUDED(mu_) {
726     std::shared_ptr<Event> underlying_event;
727 
728     {
729       absl::MutexLock l(&mu_);
730       auto event = events_.find(event_id);
731 
732       if (event == events_.end()) {
733         auto event_status = abnormal_event_status_.find(event_id);
734         if (event_status == abnormal_event_status_.end()) {
735           return OkStatus();
736         } else {
737           return event_status->second;
738         }
739       }
740 
741       auto done = [this, event_id]() {
742         mu_.AssertHeld();
743         // The event was either completed and erased from the map or we have
744         // an underlying event available to us.
745         return events_.count(event_id) == 0 ||
746                (events_[event_id]->underlying_event != nullptr &&
747                 events_[event_id]->underlying_event.use_count() != 0);
748       };
749 
750       auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration);
751       if (!status) {
752         return std::nullopt;
753       }
754 
755       if (events_.count(event_id) > 0) {
756         underlying_event = events_[event_id]->underlying_event;
757       } else {
758         underlying_event = nullptr;
759       }
760     }
761 
762     // Wait for the underlying event without holding on to the event_lock_, or
763     // else incoming events will not be processed.
764     if (underlying_event != nullptr) {
765       return underlying_event->AwaitWithTimeout(duration);
766     } else {
767       absl::MutexLock l(&mu_);
768       auto event_status = abnormal_event_status_.find(event_id);
769       if (event_status == abnormal_event_status_.end()) {
770         return OkStatus();
771       } else {
772         return event_status->second;
773       }
774     }
775   }
776 
AddCallbackForEvent(int64_t event_id,std::function<void (Status)> fn)777   void AddCallbackForEvent(int64_t event_id, std::function<void(Status)> fn)
778       ABSL_LOCKS_EXCLUDED(mu_) {
779     absl::MutexLock l(&mu_);
780     auto event = events_.find(event_id);
781 
782     if (event == events_.end()) {
783       auto event_status = abnormal_event_status_.find(event_id);
784       if (event_status == abnormal_event_status_.end()) {
785         fn(OkStatus());
786       } else {
787         fn(event_status->second);
788       }
789     } else {
790       if (event->second->underlying_event != nullptr &&
791           event->second->underlying_event.use_count() != 0) {
792         event->second->underlying_event->AddCallback(fn);
793       } else {
794         event->second->callbacks.push_back(std::move(fn));
795       }
796     }
797   }
798 
GetCompiledProgramShape(int64_t op_id,xla::ProgramShapeProto * program_shape)799   xla::Status GetCompiledProgramShape(int64_t op_id,
800                                       xla::ProgramShapeProto* program_shape)
801       ABSL_LOCKS_EXCLUDED(mu_) {
802     absl::MutexLock l(&mu_);
803 
804     auto done = [this, op_id]() {
805       mu_.AssertHeld();
806       return underlying_cph_.contains(op_id);
807     };
808     mu_.Await(absl::Condition(&done));
809 
810     return underlying_cph_[op_id][0]->program_shape(program_shape);
811   }
812 
813  private:
814   const TpuDriverConfig& config_;
815   std::shared_ptr<::grpc::ChannelCredentials> creds_;
816 
817   absl::flat_hash_map<int32_t, std::unique_ptr<TpuDriver>> drivers_;
818   absl::flat_hash_map<int32_t, int32_t> core_to_driver_id_;
819   absl::flat_hash_map<int32_t, TpuDriver*> core_to_driver_;
820   absl::flat_hash_map<int32_t, int32_t> core_to_driver_core_;
821   SystemInfo pod_info_;
822 
823   absl::Mutex mu_;
824 
825   absl::flat_hash_map<int64_t, std::unique_ptr<BufferHandle>>
826       underlying_buffers_ ABSL_GUARDED_BY(mu_);
827   absl::flat_hash_map<int64_t,
828                       std::vector<std::unique_ptr<CompiledProgramHandle>>>
829       underlying_cph_ ABSL_GUARDED_BY(mu_);
830   absl::flat_hash_map<int64_t, std::unique_ptr<LoadedProgramHandle>>
831       underlying_lph_ ABSL_GUARDED_BY(mu_);
832 
833   absl::btree_map<int64_t, std::unique_ptr<EventInFlight>> events_
834       ABSL_GUARDED_BY(mu_);
835   absl::flat_hash_map<int64_t, Status> abnormal_event_status_
836       ABSL_GUARDED_BY(mu_);
837 
838   std::atomic<int64_t> operation_id_counter_{0};
839 
840   WorkerThread event_thread_;
841 
GetOperationId()842   int64_t GetOperationId() { return operation_id_counter_++; }
843 
GetDependencyOperationIds(absl::Span<Event * const> wait_for)844   absl::flat_hash_set<int64_t> GetDependencyOperationIds(
845       absl::Span<Event* const> wait_for) {
846     absl::flat_hash_set<int64_t> deps;
847     for (auto* event : wait_for) {
848       deps.insert(static_cast<PodEvent* const>(event)->operation_id());
849     }
850     return deps;
851   }
852 
853   // EventCompleted is executed on the event_thread_ worker thread. We want
854   // to propagate the fact that the event is completed to any subsequent events
855   // that might depend on this event.
EventCompleted(int64_t event_id,Status status)856   void EventCompleted(int64_t event_id, Status status)
857       ABSL_LOCKS_EXCLUDED(mu_) {
858     absl::MutexLock l(&mu_);
859 
860     absl::btree_map<int64_t, std::unique_ptr<EventInFlight>>::iterator
861         curr_event;
862     if (!status.ok()) abnormal_event_status_.insert({event_id, status});
863     curr_event = events_.find(event_id);
864 
865     DCHECK(curr_event->second->callbacks.empty());
866     DCHECK(curr_event->second->incomplete_deps.empty());
867 
868     for (auto& event : events_) {
869       event.second->incomplete_deps.erase(event_id);
870       // The if statement conditions on both
871       //  - all previous events have completed (incomplete_deps.empty())
872       //  - the op creating this event has not been called yet
873       //    (event.second.create_fn != nullptr)
874       // We call the create_fn that creates the event and adds any relevant
875       // callbacks to the actual event, before setting create_fn to nullptr
876       // to indicate that it has already been called
877       if (event.second->incomplete_deps.empty() &&
878           event.second->create_fn != nullptr) {
879         // We were the last unfilled dependency, all other dependencies are
880         // filled. We can now fire the create function.
881         event.second->underlying_event = event.second->create_fn();
882         for (auto& fn : event.second->callbacks) {
883           event.second->underlying_event->AddCallback(std::move(fn));
884         }
885         event.second->callbacks.clear();
886         event.second->create_fn = nullptr;
887       }
888     }
889 
890     // We erase the current event to signal that it has finished.
891     events_.erase(curr_event);
892   }
893 
ScheduleRequest(int64_t operation_id,std::function<std::shared_ptr<Event> (void)> fn,const absl::flat_hash_set<int64_t> & deps)894   void ScheduleRequest(int64_t operation_id,
895                        std::function<std::shared_ptr<Event>(void)> fn,
896                        const absl::flat_hash_set<int64_t>& deps)
897       ABSL_LOCKS_EXCLUDED(mu_) {
898     absl::MutexLock l(&mu_);
899     absl::btree_map<int64_t, std::unique_ptr<EventInFlight>>::iterator event;
900     absl::flat_hash_set<int64_t> incomplete_deps;
901 
902     event =
903         events_.insert({operation_id, std::make_unique<EventInFlight>()}).first;
904     for (const auto& dep : deps) {
905       if (events_.count(dep) > 0) incomplete_deps.insert(dep);
906     }
907 
908     if (incomplete_deps.empty()) {
909       // All dependencies have been fulfilled, we execute the request
910       // immediately and add a callback to inform our event fulfilled thread
911       // when it is done.
912       event->second->create_fn = nullptr;
913       event->second->underlying_event = fn();
914       event->second->underlying_event->AddCallback(
915           [this, operation_id](Status status) {
916             event_thread_.Schedule([this, operation_id, status]() {
917               EventCompleted(operation_id, status);
918             });
919           });
920     } else {
921       // There are some dependencies that are not yet fulfilled. We attach
922       // the request to the event, and will execute it in the EventFulfilled
923       // worker thread when all its dependencies are fulfilled.
924       event->second->create_fn = std::move(fn);
925       event->second->incomplete_deps = std::move(incomplete_deps);
926       event->second->callbacks.push_back([this, operation_id](Status status) {
927         event_thread_.Schedule([this, operation_id, status]() {
928           EventCompleted(operation_id, status);
929         });
930       });
931     }
932   }
933 
934   template <typename T>
CheckHandleExists(absl::flat_hash_map<int64_t,T> & container,int64_t target_op_id,int64_t operation_id)935   std::shared_ptr<Event> CheckHandleExists(
936       absl::flat_hash_map<int64_t, T>& container, int64_t target_op_id,
937       int64_t operation_id) {
938     if (container.count(target_op_id) == 0) {
939       return std::make_shared<ErrorEvent>(
940           this, operation_id,
941           tensorflow::errors::InvalidArgument("Handle ", target_op_id,
942                                               " does not exist."));
943     }
944     return nullptr;
945   }
946 };
947 
Await()948 xla::Status PodEvent::Await() {
949   return driver_->WaitForEvent(operation_id_, absl::InfiniteDuration()).value();
950 }
951 
AwaitWithTimeout(absl::Duration duration)952 std::optional<xla::Status> PodEvent::AwaitWithTimeout(absl::Duration duration) {
953   return driver_->WaitForEvent(operation_id_, duration);
954 }
955 
AddCallback(std::function<void (Status)> callback)956 void PodEvent::AddCallback(std::function<void(Status)> callback) {
957   driver_->AddCallbackForEvent(operation_id_, std::move(callback));
958 }
959 
CreatePodTpuDriver(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds)960 xla::StatusOr<std::unique_ptr<TpuDriver>> CreatePodTpuDriver(
961     const TpuDriverConfig& config,
962     std::shared_ptr<::grpc::ChannelCredentials> creds) {
963   return std::unique_ptr<TpuDriver>(new PodTpuDriver(config, creds));
964 }
965 
program_shape(xla::ProgramShapeProto * program_shape)966 xla::Status PodCompiledProgramHandle::program_shape(
967     xla::ProgramShapeProto* program_shape) {
968   return driver_->GetCompiledProgramShape(operation_id(), program_shape);
969 }
970 
971 }  // namespace
972 
973 REGISTER_TPU_DRIVER(kPodTpuDriverPrefix,
974                     [](const TpuDriverConfig& config)
__anonf243d9071202(const TpuDriverConfig& config) 975                         -> xla::StatusOr<std::unique_ptr<TpuDriver>> {
976                       return CreatePodTpuDriver(
977                           config,
978                           ::grpc::InsecureChannelCredentials());  // NOLINT
979                     });
980 
981 }  // namespace tpu_driver
982