xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/internal/sysinfo.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_service.grpc.pb.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/threadpool.h"
28 
/*
 * The RecordingTpuDriver wraps a concrete TpuDriver implementation and records
 * the stream of operations to a log file. This log can later be replayed and
 * analyzed for debugging.
 */
34 
35 namespace tpu_driver {
36 namespace {
37 
// Single monotonically increasing id source shared by every recording event
// and handle created in this file, so their ids never collide.
static std::atomic<int64_t> id_counter{0};
39 
40 using xla::Status;
41 
42 class RecordingTpuDriver;
43 
44 class RecordingEvent : public Event {
45  public:
RecordingEvent(std::shared_ptr<Event> event)46   explicit RecordingEvent(std::shared_ptr<Event> event)
47       : shared_event_(std::move(event)), id_(id_counter++) {}
48 
RecordingEvent(std::shared_ptr<Event> event,int64_t id)49   explicit RecordingEvent(std::shared_ptr<Event> event, int64_t id)
50       : shared_event_(event), id_(id) {}
51 
~RecordingEvent()52   ~RecordingEvent() override {}
53 
Await()54   xla::Status Await() override { return shared_event_->Await(); }
55 
AwaitWithTimeout(absl::Duration duration)56   std::optional<xla::Status> AwaitWithTimeout(
57       absl::Duration duration) override {
58     return shared_event_->AwaitWithTimeout(duration);
59   }
60 
AddCallback(std::function<void (xla::Status)> callback)61   void AddCallback(std::function<void(xla::Status)> callback) override {
62     return shared_event_->AddCallback(callback);
63   }
64 
65  private:
66   std::shared_ptr<Event> shared_event_;
67 
68   int64_t id_;
69   friend class RecordingTpuDriver;
70 };
71 
72 class RecordingBufferHandle : public BufferHandle {
73  public:
RecordingBufferHandle(std::unique_ptr<BufferHandle> handle)74   explicit RecordingBufferHandle(std::unique_ptr<BufferHandle> handle)
75       : handle_(std::move(handle)),
76         id_(id_counter++),
77         event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()78   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()79   int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
shape()80   std::optional<xla::ShapeProto> shape() override { return handle_->shape(); }
81 
82  private:
83   std::unique_ptr<BufferHandle> handle_;
84   int64_t id_;
85   std::shared_ptr<RecordingEvent> event_;
86   friend class RecordingTpuDriver;
87 };
88 
89 class RecordingCompiledProgramHandle : public CompiledProgramHandle {
90  public:
RecordingCompiledProgramHandle(std::unique_ptr<CompiledProgramHandle> handle)91   explicit RecordingCompiledProgramHandle(
92       std::unique_ptr<CompiledProgramHandle> handle)
93       : handle_(std::move(handle)),
94         id_(id_counter++),
95         event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()96   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()97   int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
program_shape(xla::ProgramShapeProto * program_shape)98   xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
99     return handle_->program_shape(program_shape);
100   }
101 
102  private:
103   std::unique_ptr<CompiledProgramHandle> handle_;
104   int64_t id_;
105   std::shared_ptr<RecordingEvent> event_;
106   friend class RecordingTpuDriver;
107 };
108 
109 class RecordingLoadedProgramHandle : public LoadedProgramHandle {
110  public:
RecordingLoadedProgramHandle(std::unique_ptr<LoadedProgramHandle> handle)111   explicit RecordingLoadedProgramHandle(
112       std::unique_ptr<LoadedProgramHandle> handle)
113       : handle_(std::move(handle)),
114         id_(id_counter++),
115         event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()116   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()117   int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
118 
119  private:
120   std::unique_ptr<LoadedProgramHandle> handle_;
121   int64_t id_;
122   std::shared_ptr<RecordingEvent> event_;
123   friend class RecordingTpuDriver;
124 };
125 
126 class RecordingTpuDriver : public TpuDriver {
127  public:
RecordingTpuDriver(std::unique_ptr<TpuDriver> driver,const std::string recording_path,const bool flush)128   explicit RecordingTpuDriver(std::unique_ptr<TpuDriver> driver,
129                               const std::string recording_path,
130                               const bool flush)
131       : driver_(std::move(driver)),
132         recording_path_(recording_path),
133         flush_(flush) {
134     auto file_status = tensorflow::Env::Default()->NewAppendableFile(
135         recording_path_, &log_file_);
136     if (!file_status.ok()) {
137       LOG(FATAL) << "Unable to open " << recording_path_
138                  << " for appending. Error: " << file_status.ToString();
139     }
140   }
~RecordingTpuDriver()141   ~RecordingTpuDriver() override {
142     {
143       log_file_->Flush().IgnoreError();
144       log_file_->Close().IgnoreError();
145       log_file_ = nullptr;
146     }
147   }
148 
QuerySystemInfo(SystemInfo * system_info)149   void QuerySystemInfo(SystemInfo* system_info) override {
150     // TODO(frankchn): Should we even save this event, since it is out-of-band.
151     driver_->QuerySystemInfo(system_info);
152   }
153 
Reset()154   Status Reset() override { return driver_->Reset(); }
155 
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)156   std::unique_ptr<BufferHandle> Allocate(
157       int32_t core_id, MemoryRegion region, int64_t num_bytes,
158       absl::Span<Event* const> wait_for) override {
159     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
160 
161     auto thread_id = GetCurrentThreadId();
162     auto handle =
163         driver_->Allocate(core_id, region, num_bytes, unwrapped_wait_for);
164     auto recording_handle =
165         std::make_unique<RecordingBufferHandle>(std::move(handle));
166     auto handle_id = recording_handle->id_;
167 
168     {
169       StreamRequest::Entry r;
170       r.mutable_alloc()->set_core_id(core_id);
171       r.mutable_alloc()->set_region(region);
172       r.mutable_alloc()->set_num_bytes(num_bytes);
173 
174       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
175     }
176 
177     return recording_handle;
178   }
179 
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)180   std::unique_ptr<BufferHandle> Allocate(
181       int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
182       absl::Span<Event* const> wait_for) override {
183     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
184 
185     auto thread_id = GetCurrentThreadId();
186     auto handle = driver_->Allocate(core_id, region, shape, unwrapped_wait_for);
187     auto recording_handle =
188         std::make_unique<RecordingBufferHandle>(std::move(handle));
189     auto handle_id = recording_handle->id_;
190 
191     {
192       StreamRequest::Entry r;
193       r.mutable_alloc()->set_core_id(core_id);
194       r.mutable_alloc()->set_region(region);
195       *(r.mutable_alloc()->mutable_shape()) = shape;
196 
197       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
198     }
199 
200     return recording_handle;
201   }
202 
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)203   std::unique_ptr<BufferHandle> AllocateTuple(
204       int32_t core_id, MemoryRegion region,
205       absl::Span<BufferHandle* const> children,
206       absl::Span<Event* const> wait_for) override {
207     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
208 
209     std::vector<BufferHandle*> unwrapped_children;
210     std::vector<int64_t> child_ids;
211     const auto children_size = children.size();
212     unwrapped_children.reserve(children_size);
213     child_ids.reserve(children_size);
214     for (auto child : children) {
215       BufferHandle* unwrapped_child =
216           static_cast<const RecordingBufferHandle*>(child)->handle_.get();
217       unwrapped_children.push_back(unwrapped_child);
218       child_ids.push_back(
219           static_cast<const RecordingBufferHandle*>(child)->id_);
220     }
221 
222     auto thread_id = GetCurrentThreadId();
223     auto handle = driver_->AllocateTuple(core_id, region, unwrapped_children,
224                                          unwrapped_wait_for);
225     auto recording_handle =
226         std::make_unique<RecordingBufferHandle>(std::move(handle));
227     auto handle_id = recording_handle->id_;
228 
229     {
230       StreamRequest::Entry r;
231       r.mutable_alloc_tuple()->set_core_id(core_id);
232       r.mutable_alloc_tuple()->set_region(region);
233 
234       for (auto child : child_ids) {
235         r.mutable_alloc_tuple()->add_children(child);
236       }
237 
238       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
239     }
240 
241     return recording_handle;
242   }
243 
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)244   std::shared_ptr<Event> Deallocate(
245       std::unique_ptr<BufferHandle> handle,
246       absl::Span<Event* const> wait_for) override {
247     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
248 
249     auto thread_id = GetCurrentThreadId();
250     auto recording_handle = static_cast<RecordingBufferHandle*>(handle.get());
251     int64_t recording_handle_id = recording_handle->id_;
252     auto event = driver_->Deallocate(std::move(recording_handle->handle_),
253                                      unwrapped_wait_for);
254     auto recording_event = std::make_shared<RecordingEvent>(std::move(event));
255     int64_t event_id = recording_event->id_;
256 
257     {
258       StreamRequest::Entry r;
259       r.mutable_dealloc()->set_handle(recording_handle_id);
260       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
261     }
262 
263     return recording_event;
264   }
265 
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)266   std::shared_ptr<Event> TransferToDevice(
267       const void* src, BufferHandle* dst,
268       absl::Span<Event* const> wait_for) override {
269     int64_t num_bytes = dst->size_in_bytes();
270     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
271 
272     auto thread_id = GetCurrentThreadId();
273     auto recording_handle = static_cast<RecordingBufferHandle*>(dst);
274     int64_t recording_handle_id = recording_handle->id_;
275     auto recording_event =
276         std::make_shared<RecordingEvent>(driver_->TransferToDevice(
277             src, static_cast<RecordingBufferHandle*>(dst)->handle_.get(),
278             unwrapped_wait_for));
279     int64_t event_id = recording_event->id_;
280 
281     {
282       StreamRequest::Entry r;
283       r.mutable_transfer_to()->set_target_handle(recording_handle_id);
284       if (num_bytes > 0) {
285         r.mutable_transfer_to()->mutable_data()->assign(
286             static_cast<const char*>(src), num_bytes);
287       } else {
288         *r.mutable_transfer_to()->mutable_data() = "";
289       }
290       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
291     }
292 
293     return recording_event;
294   }
295 
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)296   std::shared_ptr<Event> TransferFromDevice(
297       const BufferHandle* src, void* dst,
298       absl::Span<Event* const> wait_for) override {
299     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
300 
301     auto thread_id = GetCurrentThreadId();
302     auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
303     auto recording_event =
304         std::make_shared<RecordingEvent>(driver_->TransferFromDevice(
305             static_cast<const RecordingBufferHandle*>(src)->handle_.get(), dst,
306             unwrapped_wait_for));
307     auto event_id = recording_event->id_;
308 
309     {
310       StreamRequest::Entry r;
311       r.mutable_transfer_from()->set_source_handle(src_handle_id);
312       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
313     }
314 
315     return recording_event;
316   }
317 
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)318   std::shared_ptr<Event> TransferFromDeviceToDevice(
319       const BufferHandle* src, BufferHandle* dst,
320       absl::Span<Event* const> wait_for) override {
321     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
322 
323     auto thread_id = GetCurrentThreadId();
324     auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
325     auto dst_handle_id = static_cast<const RecordingBufferHandle*>(dst)->id_;
326     auto recording_event =
327         std::make_shared<RecordingEvent>(driver_->TransferFromDeviceToDevice(
328             static_cast<const RecordingBufferHandle*>(src)->handle_.get(),
329             static_cast<const RecordingBufferHandle*>(dst)->handle_.get(),
330             unwrapped_wait_for));
331     auto event_id = recording_event->id_;
332 
333     {
334       StreamRequest::Entry r;
335       r.mutable_transfer_from_to()->set_source_handle(src_handle_id);
336       r.mutable_transfer_from_to()->set_target_handle(dst_handle_id);
337       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
338     }
339 
340     return recording_event;
341   }
342 
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)343   std::unique_ptr<CompiledProgramHandle> CompileProgram(
344       const xla::HloProto& source, int32_t num_replicas,
345       absl::Span<Event* const> wait_for) override {
346     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
347 
348     auto thread_id = GetCurrentThreadId();
349     auto recording_handle = std::make_unique<RecordingCompiledProgramHandle>(
350         driver_->CompileProgram(source, num_replicas, unwrapped_wait_for));
351     auto handle_id = recording_handle->id_;
352 
353     {
354       StreamRequest::Entry r;
355       *r.mutable_compile()->mutable_hlo_program() = source;
356       r.mutable_compile()->set_num_replicas(num_replicas);
357       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
358     }
359 
360     return recording_handle;
361   }
362 
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)363   std::unique_ptr<LoadedProgramHandle> LoadProgram(
364       int32_t core_id, const CompiledProgramHandle* handle,
365       absl::Span<Event* const> wait_for) override {
366     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
367 
368     auto thread_id = GetCurrentThreadId();
369     auto compiled_handle_id =
370         static_cast<const RecordingCompiledProgramHandle*>(handle)->id_;
371     auto recording_handle =
372         std::make_unique<RecordingLoadedProgramHandle>(driver_->LoadProgram(
373             core_id,
374             static_cast<const RecordingCompiledProgramHandle*>(handle)
375                 ->handle_.get(),
376             unwrapped_wait_for));
377     auto handle_id = recording_handle->id_;
378     {
379       StreamRequest::Entry r;
380       r.mutable_load()->set_core_id(core_id);
381       r.mutable_load()->set_compiled_program_handle(compiled_handle_id);
382       PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
383     }
384 
385     return recording_handle;
386   }
387 
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)388   std::shared_ptr<Event> UnloadProgram(
389       std::unique_ptr<LoadedProgramHandle> handle,
390       absl::Span<Event* const> wait_for) override {
391     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
392 
393     auto thread_id = GetCurrentThreadId();
394     auto loaded_handle_id =
395         static_cast<RecordingLoadedProgramHandle*>(handle.get())->id_;
396     auto recording_event =
397         std::make_shared<RecordingEvent>(driver_->UnloadProgram(
398             std::move(static_cast<RecordingLoadedProgramHandle*>(handle.get())
399                           ->handle_),
400             unwrapped_wait_for));
401     auto event_id = recording_event->id_;
402 
403     {
404       StreamRequest::Entry r;
405       r.mutable_unload()->set_loaded_program_handle(loaded_handle_id);
406       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
407     }
408 
409     return recording_event;
410   }
411 
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)412   std::shared_ptr<Event> ExecuteProgram(
413       LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
414       absl::Span<BufferHandle* const> outputs,
415       const xla::DeviceAssignmentProto& device_assignment,
416       absl::Span<Event* const> wait_for) override {
417     auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
418 
419     auto thread_id = GetCurrentThreadId();
420     auto program_handle_id =
421         static_cast<RecordingLoadedProgramHandle*>(program)->id_;
422 
423     std::vector<BufferHandle*> unwrapped_inputs;
424     std::vector<int64_t> input_ids;
425     const auto inputs_size = inputs.size();
426     unwrapped_inputs.reserve(inputs_size);
427     input_ids.reserve(inputs_size);
428     for (auto input : inputs) {
429       BufferHandle* unwrapped_input =
430           static_cast<const RecordingBufferHandle*>(input)->handle_.get();
431       unwrapped_inputs.push_back(unwrapped_input);
432       input_ids.push_back(
433           static_cast<const RecordingBufferHandle*>(input)->id_);
434     }
435 
436     std::vector<BufferHandle*> unwrapped_outputs;
437     std::vector<int64_t> output_ids;
438     const auto output_size = outputs.size();
439     unwrapped_outputs.reserve(output_size);
440     output_ids.reserve(output_size);
441     for (auto output : outputs) {
442       BufferHandle* unwrapped_output =
443           static_cast<const RecordingBufferHandle*>(output)->handle_.get();
444       unwrapped_outputs.push_back(unwrapped_output);
445       output_ids.push_back(
446           static_cast<const RecordingBufferHandle*>(output)->id_);
447     }
448 
449     auto recording_event =
450         std::make_shared<RecordingEvent>(driver_->ExecuteProgram(
451             static_cast<RecordingLoadedProgramHandle*>(program)->handle_.get(),
452             unwrapped_inputs, unwrapped_outputs, device_assignment,
453             unwrapped_wait_for));
454     auto event_id = recording_event->id_;
455 
456     {
457       StreamRequest::Entry r;
458       r.mutable_execute()->set_loaded_program_handle(program_handle_id);
459       for (auto input_id : input_ids) {
460         r.mutable_execute()->add_input_handle(input_id);
461       }
462       for (auto output_id : output_ids) {
463         r.mutable_execute()->add_output_handle(output_id);
464       }
465       *r.mutable_execute()->mutable_device_assignment() = device_assignment;
466 
467       PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
468     }
469 
470     return recording_event;
471   }
472 
GetLinearizer()473   std::unique_ptr<TpuLinearizer> GetLinearizer() override {
474     return driver_->GetLinearizer();
475   }
476 
477  private:
478   std::unique_ptr<TpuDriver> driver_;
479   const std::string recording_path_;
480   const bool flush_;
481 
482   std::unique_ptr<tensorflow::WritableFile> log_file_;
483 
PopulateAndSaveEntry(StreamRequest::Entry * r,absl::Span<Event * const> wait_for,int64_t handle_id,int64_t thread_id)484   void PopulateAndSaveEntry(StreamRequest::Entry* r,
485                             absl::Span<Event* const> wait_for,
486                             int64_t handle_id, int64_t thread_id) {
487     for (auto event : wait_for) {
488       auto recording_event = static_cast<const RecordingEvent*>(event);
489       r->add_wait_for_id(recording_event->id_);
490     }
491     r->set_operation_id(handle_id);
492     r->set_thread_id(thread_id);
493 
494     uint64_t data_size = r->ByteSizeLong();
495     std::vector<char> buffer;
496     buffer.resize(sizeof(data_size) + data_size);
497     memcpy(buffer.data(), &data_size, sizeof(data_size));
498     r->SerializeToArray(buffer.data() + sizeof(data_size), data_size);
499 
500     {
501       if (log_file_ == nullptr) {
502         LOG(WARNING) << "The TPU driver has been shut down before all logging "
503                         "has been written.";
504         return;
505       }
506 
507       absl::string_view buffer_sp(buffer.data(), buffer.size());
508       auto data_status = log_file_->Append(buffer_sp);
509       if (!data_status.ok()) {
510         LOG(WARNING) << "Unable to write data to log file. File possibly "
511                         "corrupt. Error: "
512                      << data_status.ToString();
513       }
514 
515       if (flush_) {
516         auto flush_status = log_file_->Flush();
517         if (!flush_status.ok()) {
518           LOG(WARNING) << "Unable to flush data to log file. File possibly "
519                           "corrupt. Error: "
520                        << flush_status.ToString();
521         }
522 
523         auto sync_status = log_file_->Sync();
524         if (!sync_status.ok()) {
525           LOG(WARNING) << "Unable to sync log file. File possibly "
526                           "corrupt. Error: "
527                        << sync_status.ToString();
528         }
529       }
530     }
531   }
532 
UnwrapWaitFor(absl::Span<Event * const> wait_for)533   std::vector<Event*> UnwrapWaitFor(absl::Span<Event* const> wait_for) {
534     std::vector<Event*> unwrapped_events;
535     for (auto event : wait_for) {
536       Event* unwrapped_event =
537           static_cast<RecordingEvent*>(event)->shared_event_.get();
538       unwrapped_events.push_back(unwrapped_event);
539     }
540     return unwrapped_events;
541   }
542 
GetCurrentThreadId()543   int64_t GetCurrentThreadId() { return absl::base_internal::GetTID(); }
544 };
545 
RegisterRecordingTpuDriver(const TpuDriverConfig & config)546 xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
547     const TpuDriverConfig& config) {
548   std::vector<std::string> configs = absl::StrSplit(config.worker(), '|');
549 
550   std::string file;
551   std::string worker;
552   bool flush = false;
553 
554   for (const auto& config : configs) {
555     std::vector<std::string> kv =
556         absl::StrSplit(config, absl::MaxSplits('=', 1));
557     if (kv[0] == "file") {
558       file = kv[1];
559     }
560     if (kv[0] == "worker") {
561       worker = kv[1];
562     }
563     if (kv[0] == "flush") {
564       if (kv[1] == "true" || kv[1] == "1") {
565         flush = true;
566       }
567     }
568   }
569 
570   TpuDriverConfig worker_config;
571   worker_config.set_worker(worker);
572 
573   auto driver_status = TpuDriverRegistry::Open(worker_config);
574   if (!driver_status.ok()) return driver_status.status();
575   return std::unique_ptr<TpuDriver>(
576       new RecordingTpuDriver(std::move(driver_status).value(), file, flush));
577 }
578 
// To record a sequence of operations, set the worker configuration string to
// record://|file=<filename>|worker=grpc://1.2.3.4:8470 (for GRPC). An optional
// |flush=true (or |flush=1) segment is also accepted by
// RegisterRecordingTpuDriver.
REGISTER_TPU_DRIVER("record://", RegisterRecordingTpuDriver);
582 
583 }  // namespace
584 }  // namespace tpu_driver
585