1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/internal/sysinfo.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_service.grpc.pb.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/threadpool.h"
28
/*
 * The RecordingTpuDriver wraps a concrete TpuDriver implementation and records
 * the stream of operations to a log file. This log can later be replayed and
 * analyzed for debugging.
 */
34
35 namespace tpu_driver {
36 namespace {
37
// Process-wide source of recording ids. Each wrapped handle, and each event
// that is not tied to a handle, takes the next value via an atomic increment.
// (Already internal linkage: this lives in an anonymous namespace, so no
// `static` is needed.)
std::atomic<int64_t> id_counter{0};
39
// Unqualified `Status` in this file refers to xla::Status.
using xla::Status;

// Forward declaration of the driver defined below; the wrapper classes that
// follow name it as a friend so it can reach their private `handle_`/`id_`.
class RecordingTpuDriver;
43
44 class RecordingEvent : public Event {
45 public:
RecordingEvent(std::shared_ptr<Event> event)46 explicit RecordingEvent(std::shared_ptr<Event> event)
47 : shared_event_(std::move(event)), id_(id_counter++) {}
48
RecordingEvent(std::shared_ptr<Event> event,int64_t id)49 explicit RecordingEvent(std::shared_ptr<Event> event, int64_t id)
50 : shared_event_(event), id_(id) {}
51
~RecordingEvent()52 ~RecordingEvent() override {}
53
Await()54 xla::Status Await() override { return shared_event_->Await(); }
55
AwaitWithTimeout(absl::Duration duration)56 std::optional<xla::Status> AwaitWithTimeout(
57 absl::Duration duration) override {
58 return shared_event_->AwaitWithTimeout(duration);
59 }
60
AddCallback(std::function<void (xla::Status)> callback)61 void AddCallback(std::function<void(xla::Status)> callback) override {
62 return shared_event_->AddCallback(callback);
63 }
64
65 private:
66 std::shared_ptr<Event> shared_event_;
67
68 int64_t id_;
69 friend class RecordingTpuDriver;
70 };
71
72 class RecordingBufferHandle : public BufferHandle {
73 public:
RecordingBufferHandle(std::unique_ptr<BufferHandle> handle)74 explicit RecordingBufferHandle(std::unique_ptr<BufferHandle> handle)
75 : handle_(std::move(handle)),
76 id_(id_counter++),
77 event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()78 std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()79 int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
shape()80 std::optional<xla::ShapeProto> shape() override { return handle_->shape(); }
81
82 private:
83 std::unique_ptr<BufferHandle> handle_;
84 int64_t id_;
85 std::shared_ptr<RecordingEvent> event_;
86 friend class RecordingTpuDriver;
87 };
88
89 class RecordingCompiledProgramHandle : public CompiledProgramHandle {
90 public:
RecordingCompiledProgramHandle(std::unique_ptr<CompiledProgramHandle> handle)91 explicit RecordingCompiledProgramHandle(
92 std::unique_ptr<CompiledProgramHandle> handle)
93 : handle_(std::move(handle)),
94 id_(id_counter++),
95 event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()96 std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()97 int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
program_shape(xla::ProgramShapeProto * program_shape)98 xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
99 return handle_->program_shape(program_shape);
100 }
101
102 private:
103 std::unique_ptr<CompiledProgramHandle> handle_;
104 int64_t id_;
105 std::shared_ptr<RecordingEvent> event_;
106 friend class RecordingTpuDriver;
107 };
108
109 class RecordingLoadedProgramHandle : public LoadedProgramHandle {
110 public:
RecordingLoadedProgramHandle(std::unique_ptr<LoadedProgramHandle> handle)111 explicit RecordingLoadedProgramHandle(
112 std::unique_ptr<LoadedProgramHandle> handle)
113 : handle_(std::move(handle)),
114 id_(id_counter++),
115 event_(std::make_shared<RecordingEvent>(handle_->OnReady(), id_)) {}
OnReady()116 std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()117 int64_t size_in_bytes() override { return handle_->size_in_bytes(); }
118
119 private:
120 std::unique_ptr<LoadedProgramHandle> handle_;
121 int64_t id_;
122 std::shared_ptr<RecordingEvent> event_;
123 friend class RecordingTpuDriver;
124 };
125
126 class RecordingTpuDriver : public TpuDriver {
127 public:
RecordingTpuDriver(std::unique_ptr<TpuDriver> driver,const std::string recording_path,const bool flush)128 explicit RecordingTpuDriver(std::unique_ptr<TpuDriver> driver,
129 const std::string recording_path,
130 const bool flush)
131 : driver_(std::move(driver)),
132 recording_path_(recording_path),
133 flush_(flush) {
134 auto file_status = tensorflow::Env::Default()->NewAppendableFile(
135 recording_path_, &log_file_);
136 if (!file_status.ok()) {
137 LOG(FATAL) << "Unable to open " << recording_path_
138 << " for appending. Error: " << file_status.ToString();
139 }
140 }
~RecordingTpuDriver()141 ~RecordingTpuDriver() override {
142 {
143 log_file_->Flush().IgnoreError();
144 log_file_->Close().IgnoreError();
145 log_file_ = nullptr;
146 }
147 }
148
QuerySystemInfo(SystemInfo * system_info)149 void QuerySystemInfo(SystemInfo* system_info) override {
150 // TODO(frankchn): Should we even save this event, since it is out-of-band.
151 driver_->QuerySystemInfo(system_info);
152 }
153
Reset()154 Status Reset() override { return driver_->Reset(); }
155
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)156 std::unique_ptr<BufferHandle> Allocate(
157 int32_t core_id, MemoryRegion region, int64_t num_bytes,
158 absl::Span<Event* const> wait_for) override {
159 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
160
161 auto thread_id = GetCurrentThreadId();
162 auto handle =
163 driver_->Allocate(core_id, region, num_bytes, unwrapped_wait_for);
164 auto recording_handle =
165 std::make_unique<RecordingBufferHandle>(std::move(handle));
166 auto handle_id = recording_handle->id_;
167
168 {
169 StreamRequest::Entry r;
170 r.mutable_alloc()->set_core_id(core_id);
171 r.mutable_alloc()->set_region(region);
172 r.mutable_alloc()->set_num_bytes(num_bytes);
173
174 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
175 }
176
177 return recording_handle;
178 }
179
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)180 std::unique_ptr<BufferHandle> Allocate(
181 int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
182 absl::Span<Event* const> wait_for) override {
183 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
184
185 auto thread_id = GetCurrentThreadId();
186 auto handle = driver_->Allocate(core_id, region, shape, unwrapped_wait_for);
187 auto recording_handle =
188 std::make_unique<RecordingBufferHandle>(std::move(handle));
189 auto handle_id = recording_handle->id_;
190
191 {
192 StreamRequest::Entry r;
193 r.mutable_alloc()->set_core_id(core_id);
194 r.mutable_alloc()->set_region(region);
195 *(r.mutable_alloc()->mutable_shape()) = shape;
196
197 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
198 }
199
200 return recording_handle;
201 }
202
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)203 std::unique_ptr<BufferHandle> AllocateTuple(
204 int32_t core_id, MemoryRegion region,
205 absl::Span<BufferHandle* const> children,
206 absl::Span<Event* const> wait_for) override {
207 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
208
209 std::vector<BufferHandle*> unwrapped_children;
210 std::vector<int64_t> child_ids;
211 const auto children_size = children.size();
212 unwrapped_children.reserve(children_size);
213 child_ids.reserve(children_size);
214 for (auto child : children) {
215 BufferHandle* unwrapped_child =
216 static_cast<const RecordingBufferHandle*>(child)->handle_.get();
217 unwrapped_children.push_back(unwrapped_child);
218 child_ids.push_back(
219 static_cast<const RecordingBufferHandle*>(child)->id_);
220 }
221
222 auto thread_id = GetCurrentThreadId();
223 auto handle = driver_->AllocateTuple(core_id, region, unwrapped_children,
224 unwrapped_wait_for);
225 auto recording_handle =
226 std::make_unique<RecordingBufferHandle>(std::move(handle));
227 auto handle_id = recording_handle->id_;
228
229 {
230 StreamRequest::Entry r;
231 r.mutable_alloc_tuple()->set_core_id(core_id);
232 r.mutable_alloc_tuple()->set_region(region);
233
234 for (auto child : child_ids) {
235 r.mutable_alloc_tuple()->add_children(child);
236 }
237
238 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
239 }
240
241 return recording_handle;
242 }
243
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)244 std::shared_ptr<Event> Deallocate(
245 std::unique_ptr<BufferHandle> handle,
246 absl::Span<Event* const> wait_for) override {
247 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
248
249 auto thread_id = GetCurrentThreadId();
250 auto recording_handle = static_cast<RecordingBufferHandle*>(handle.get());
251 int64_t recording_handle_id = recording_handle->id_;
252 auto event = driver_->Deallocate(std::move(recording_handle->handle_),
253 unwrapped_wait_for);
254 auto recording_event = std::make_shared<RecordingEvent>(std::move(event));
255 int64_t event_id = recording_event->id_;
256
257 {
258 StreamRequest::Entry r;
259 r.mutable_dealloc()->set_handle(recording_handle_id);
260 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
261 }
262
263 return recording_event;
264 }
265
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)266 std::shared_ptr<Event> TransferToDevice(
267 const void* src, BufferHandle* dst,
268 absl::Span<Event* const> wait_for) override {
269 int64_t num_bytes = dst->size_in_bytes();
270 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
271
272 auto thread_id = GetCurrentThreadId();
273 auto recording_handle = static_cast<RecordingBufferHandle*>(dst);
274 int64_t recording_handle_id = recording_handle->id_;
275 auto recording_event =
276 std::make_shared<RecordingEvent>(driver_->TransferToDevice(
277 src, static_cast<RecordingBufferHandle*>(dst)->handle_.get(),
278 unwrapped_wait_for));
279 int64_t event_id = recording_event->id_;
280
281 {
282 StreamRequest::Entry r;
283 r.mutable_transfer_to()->set_target_handle(recording_handle_id);
284 if (num_bytes > 0) {
285 r.mutable_transfer_to()->mutable_data()->assign(
286 static_cast<const char*>(src), num_bytes);
287 } else {
288 *r.mutable_transfer_to()->mutable_data() = "";
289 }
290 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
291 }
292
293 return recording_event;
294 }
295
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)296 std::shared_ptr<Event> TransferFromDevice(
297 const BufferHandle* src, void* dst,
298 absl::Span<Event* const> wait_for) override {
299 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
300
301 auto thread_id = GetCurrentThreadId();
302 auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
303 auto recording_event =
304 std::make_shared<RecordingEvent>(driver_->TransferFromDevice(
305 static_cast<const RecordingBufferHandle*>(src)->handle_.get(), dst,
306 unwrapped_wait_for));
307 auto event_id = recording_event->id_;
308
309 {
310 StreamRequest::Entry r;
311 r.mutable_transfer_from()->set_source_handle(src_handle_id);
312 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
313 }
314
315 return recording_event;
316 }
317
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)318 std::shared_ptr<Event> TransferFromDeviceToDevice(
319 const BufferHandle* src, BufferHandle* dst,
320 absl::Span<Event* const> wait_for) override {
321 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
322
323 auto thread_id = GetCurrentThreadId();
324 auto src_handle_id = static_cast<const RecordingBufferHandle*>(src)->id_;
325 auto dst_handle_id = static_cast<const RecordingBufferHandle*>(dst)->id_;
326 auto recording_event =
327 std::make_shared<RecordingEvent>(driver_->TransferFromDeviceToDevice(
328 static_cast<const RecordingBufferHandle*>(src)->handle_.get(),
329 static_cast<const RecordingBufferHandle*>(dst)->handle_.get(),
330 unwrapped_wait_for));
331 auto event_id = recording_event->id_;
332
333 {
334 StreamRequest::Entry r;
335 r.mutable_transfer_from_to()->set_source_handle(src_handle_id);
336 r.mutable_transfer_from_to()->set_target_handle(dst_handle_id);
337 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
338 }
339
340 return recording_event;
341 }
342
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)343 std::unique_ptr<CompiledProgramHandle> CompileProgram(
344 const xla::HloProto& source, int32_t num_replicas,
345 absl::Span<Event* const> wait_for) override {
346 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
347
348 auto thread_id = GetCurrentThreadId();
349 auto recording_handle = std::make_unique<RecordingCompiledProgramHandle>(
350 driver_->CompileProgram(source, num_replicas, unwrapped_wait_for));
351 auto handle_id = recording_handle->id_;
352
353 {
354 StreamRequest::Entry r;
355 *r.mutable_compile()->mutable_hlo_program() = source;
356 r.mutable_compile()->set_num_replicas(num_replicas);
357 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
358 }
359
360 return recording_handle;
361 }
362
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)363 std::unique_ptr<LoadedProgramHandle> LoadProgram(
364 int32_t core_id, const CompiledProgramHandle* handle,
365 absl::Span<Event* const> wait_for) override {
366 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
367
368 auto thread_id = GetCurrentThreadId();
369 auto compiled_handle_id =
370 static_cast<const RecordingCompiledProgramHandle*>(handle)->id_;
371 auto recording_handle =
372 std::make_unique<RecordingLoadedProgramHandle>(driver_->LoadProgram(
373 core_id,
374 static_cast<const RecordingCompiledProgramHandle*>(handle)
375 ->handle_.get(),
376 unwrapped_wait_for));
377 auto handle_id = recording_handle->id_;
378 {
379 StreamRequest::Entry r;
380 r.mutable_load()->set_core_id(core_id);
381 r.mutable_load()->set_compiled_program_handle(compiled_handle_id);
382 PopulateAndSaveEntry(&r, wait_for, handle_id, thread_id);
383 }
384
385 return recording_handle;
386 }
387
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)388 std::shared_ptr<Event> UnloadProgram(
389 std::unique_ptr<LoadedProgramHandle> handle,
390 absl::Span<Event* const> wait_for) override {
391 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
392
393 auto thread_id = GetCurrentThreadId();
394 auto loaded_handle_id =
395 static_cast<RecordingLoadedProgramHandle*>(handle.get())->id_;
396 auto recording_event =
397 std::make_shared<RecordingEvent>(driver_->UnloadProgram(
398 std::move(static_cast<RecordingLoadedProgramHandle*>(handle.get())
399 ->handle_),
400 unwrapped_wait_for));
401 auto event_id = recording_event->id_;
402
403 {
404 StreamRequest::Entry r;
405 r.mutable_unload()->set_loaded_program_handle(loaded_handle_id);
406 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
407 }
408
409 return recording_event;
410 }
411
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)412 std::shared_ptr<Event> ExecuteProgram(
413 LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
414 absl::Span<BufferHandle* const> outputs,
415 const xla::DeviceAssignmentProto& device_assignment,
416 absl::Span<Event* const> wait_for) override {
417 auto unwrapped_wait_for = UnwrapWaitFor(wait_for);
418
419 auto thread_id = GetCurrentThreadId();
420 auto program_handle_id =
421 static_cast<RecordingLoadedProgramHandle*>(program)->id_;
422
423 std::vector<BufferHandle*> unwrapped_inputs;
424 std::vector<int64_t> input_ids;
425 const auto inputs_size = inputs.size();
426 unwrapped_inputs.reserve(inputs_size);
427 input_ids.reserve(inputs_size);
428 for (auto input : inputs) {
429 BufferHandle* unwrapped_input =
430 static_cast<const RecordingBufferHandle*>(input)->handle_.get();
431 unwrapped_inputs.push_back(unwrapped_input);
432 input_ids.push_back(
433 static_cast<const RecordingBufferHandle*>(input)->id_);
434 }
435
436 std::vector<BufferHandle*> unwrapped_outputs;
437 std::vector<int64_t> output_ids;
438 const auto output_size = outputs.size();
439 unwrapped_outputs.reserve(output_size);
440 output_ids.reserve(output_size);
441 for (auto output : outputs) {
442 BufferHandle* unwrapped_output =
443 static_cast<const RecordingBufferHandle*>(output)->handle_.get();
444 unwrapped_outputs.push_back(unwrapped_output);
445 output_ids.push_back(
446 static_cast<const RecordingBufferHandle*>(output)->id_);
447 }
448
449 auto recording_event =
450 std::make_shared<RecordingEvent>(driver_->ExecuteProgram(
451 static_cast<RecordingLoadedProgramHandle*>(program)->handle_.get(),
452 unwrapped_inputs, unwrapped_outputs, device_assignment,
453 unwrapped_wait_for));
454 auto event_id = recording_event->id_;
455
456 {
457 StreamRequest::Entry r;
458 r.mutable_execute()->set_loaded_program_handle(program_handle_id);
459 for (auto input_id : input_ids) {
460 r.mutable_execute()->add_input_handle(input_id);
461 }
462 for (auto output_id : output_ids) {
463 r.mutable_execute()->add_output_handle(output_id);
464 }
465 *r.mutable_execute()->mutable_device_assignment() = device_assignment;
466
467 PopulateAndSaveEntry(&r, wait_for, event_id, thread_id);
468 }
469
470 return recording_event;
471 }
472
GetLinearizer()473 std::unique_ptr<TpuLinearizer> GetLinearizer() override {
474 return driver_->GetLinearizer();
475 }
476
477 private:
478 std::unique_ptr<TpuDriver> driver_;
479 const std::string recording_path_;
480 const bool flush_;
481
482 std::unique_ptr<tensorflow::WritableFile> log_file_;
483
PopulateAndSaveEntry(StreamRequest::Entry * r,absl::Span<Event * const> wait_for,int64_t handle_id,int64_t thread_id)484 void PopulateAndSaveEntry(StreamRequest::Entry* r,
485 absl::Span<Event* const> wait_for,
486 int64_t handle_id, int64_t thread_id) {
487 for (auto event : wait_for) {
488 auto recording_event = static_cast<const RecordingEvent*>(event);
489 r->add_wait_for_id(recording_event->id_);
490 }
491 r->set_operation_id(handle_id);
492 r->set_thread_id(thread_id);
493
494 uint64_t data_size = r->ByteSizeLong();
495 std::vector<char> buffer;
496 buffer.resize(sizeof(data_size) + data_size);
497 memcpy(buffer.data(), &data_size, sizeof(data_size));
498 r->SerializeToArray(buffer.data() + sizeof(data_size), data_size);
499
500 {
501 if (log_file_ == nullptr) {
502 LOG(WARNING) << "The TPU driver has been shut down before all logging "
503 "has been written.";
504 return;
505 }
506
507 absl::string_view buffer_sp(buffer.data(), buffer.size());
508 auto data_status = log_file_->Append(buffer_sp);
509 if (!data_status.ok()) {
510 LOG(WARNING) << "Unable to write data to log file. File possibly "
511 "corrupt. Error: "
512 << data_status.ToString();
513 }
514
515 if (flush_) {
516 auto flush_status = log_file_->Flush();
517 if (!flush_status.ok()) {
518 LOG(WARNING) << "Unable to flush data to log file. File possibly "
519 "corrupt. Error: "
520 << flush_status.ToString();
521 }
522
523 auto sync_status = log_file_->Sync();
524 if (!sync_status.ok()) {
525 LOG(WARNING) << "Unable to sync log file. File possibly "
526 "corrupt. Error: "
527 << sync_status.ToString();
528 }
529 }
530 }
531 }
532
UnwrapWaitFor(absl::Span<Event * const> wait_for)533 std::vector<Event*> UnwrapWaitFor(absl::Span<Event* const> wait_for) {
534 std::vector<Event*> unwrapped_events;
535 for (auto event : wait_for) {
536 Event* unwrapped_event =
537 static_cast<RecordingEvent*>(event)->shared_event_.get();
538 unwrapped_events.push_back(unwrapped_event);
539 }
540 return unwrapped_events;
541 }
542
GetCurrentThreadId()543 int64_t GetCurrentThreadId() { return absl::base_internal::GetTID(); }
544 };
545
RegisterRecordingTpuDriver(const TpuDriverConfig & config)546 xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
547 const TpuDriverConfig& config) {
548 std::vector<std::string> configs = absl::StrSplit(config.worker(), '|');
549
550 std::string file;
551 std::string worker;
552 bool flush = false;
553
554 for (const auto& config : configs) {
555 std::vector<std::string> kv =
556 absl::StrSplit(config, absl::MaxSplits('=', 1));
557 if (kv[0] == "file") {
558 file = kv[1];
559 }
560 if (kv[0] == "worker") {
561 worker = kv[1];
562 }
563 if (kv[0] == "flush") {
564 if (kv[1] == "true" || kv[1] == "1") {
565 flush = true;
566 }
567 }
568 }
569
570 TpuDriverConfig worker_config;
571 worker_config.set_worker(worker);
572
573 auto driver_status = TpuDriverRegistry::Open(worker_config);
574 if (!driver_status.ok()) return driver_status.status();
575 return std::unique_ptr<TpuDriver>(
576 new RecordingTpuDriver(std::move(driver_status).value(), file, flush));
577 }
578
// To record a sequence of operations, set the worker configuration string to
// record://|file=<filename>|worker=grpc://1.2.3.4:8470 (for GRPC).
// Optionally append |flush=true (or |flush=1) to flush and sync the log file
// after every recorded operation.
REGISTER_TPU_DRIVER("record://", RegisterRecordingTpuDriver);
582
583 } // namespace
584 } // namespace tpu_driver
585