/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/outfeed_receiver.h"

#include <sys/types.h>

#include <memory>
#include <queue>
#include <sstream>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/profiler/lib/traceme.h"

// Implementation notes:
//
// Startup:
// --------
//
// The startup is initiated by a call from Python to StartOutfeedReceiver. For
// each local device there is one thread that listens for outfeeds from the
// device, one queue of received outfeeds, and one thread that invokes the
// Python callbacks.
//
// Framing protocol
// ----------------
//
// The outfeed mechanism has a single channel, so the receiver must know
// exactly the shape and number of the outfeed operations issued by the
// compiled code. This makes it hard to use outfeed in conditionals and loops,
// and especially when outfeeding different-shaped data.
//
// To address this, when we compile the code we capture the shape of the data
// being outfed, and we generate a consumer ID (uint32_t), unique across the
// lifetime of the program, that identifies: the Python callable to call back,
// the shape of the arguments, and the keyword arguments to pass to the
// callable. Each outfeed payload is preceded by a header (of shape u32[2])
// containing a special first value and the consumer ID. We maintain a registry
// of shapes by consumer ID. When receiving, we look up the shape by consumer
// ID and then read the payload.
//
// Back pressure:
// --------------
//
// We maintain a sum of the bytes from all the data waiting in the callback
// queues. The listening threads wait for this sum to drop below a configurable
// threshold (256 MB by default). While a listening thread is waiting, the next
// outfeed operation from the device will block on CPU and GPU. On TPU there is
// a buffer, but eventually the TPU will also block.
//
// Shutdown:
// ---------
//
// The shutdown is initiated automatically when the last reference to the
// outfeed receiver object is dropped and the Python garbage collector invokes
// the destructor.
//
// The shutdown sequence is implemented as follows:
// * we enqueue on all devices a computation that outfeeds a special header
//   with consumer ID kOutfeedCidShutdown.
// * when a listening thread gets the shutdown header, it decrements the
//   counter of listening threads and enqueues a special shutdown callback.
// * when a callback thread gets the shutdown callback marker, it terminates.
// * the shutdown code waits until all threads terminate.
//
// Since we currently keep the shape registry in the OutfeedReceiver, it is not
// safe to replace the OutfeedReceiver instance during the lifetime of the JAX
// program, or else previously cached jitted computations may refer to
// previously cached shapes. This can be solved, but for now we disallow
// replacing the OutfeedReceiver and do not provide a Shutdown API to the
// Python program.
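//
// Example usage (a minimal sketch of the C++ API that the Python bindings
// wrap; `client`, `builder`, `token`, and `arrays` are assumed to exist, and
// the consumer ID is arbitrary as long as it is not kOutfeedCidShutdown):
//
//   OutfeedReceiver receiver(
//       [](auto* device, uint32_t consumer_id, auto literal) {
//         // Deliver the literal to the Python callable registered for
//         // consumer_id.
//       },
//       {client.get()}, /*max_callback_queue_size_bytes=*/256 * 1024 * 1024);
//   receiver.Start();
//   // While building a computation to be compiled:
//   TF_ASSIGN_OR_RETURN(
//       token, receiver.AddOutfeedToBuilder(&builder, token,
//                                           /*consumer_id=*/1, arrays));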

namespace xla {

// The header contains:
// 0. kOutfeedHeaderStart
// 1. consumer id
int constexpr kOutfeedHeaderWords = 2;
uint32_t constexpr kOutfeedHeaderStart = 271828;
// Special consumer IDs, without outfeed payload.
uint32_t constexpr kOutfeedCidShutdown = 0;
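
// As an illustration of the framing protocol (the consumer ID and payload
// shape below are hypothetical), an outfeed for consumer ID 42 carrying a
// single f32[2,3] array appears on the channel as two consecutive transfers:
//   u32[2] = {271828, 42}  // kOutfeedHeaderStart, consumer ID
//   (f32[2,3])             // payload tuple, in the registered shape
// A shutdown header uses kOutfeedCidShutdown and is not followed by a payload.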

// Encapsulates data received from a device outfeed.
class OutfeedData {
 public:
  OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape)
      : device_(device),
        consumer_id_(consumer_id),
        shape_(shape),
        literal_(nullptr),
        literal_size_bytes_(0) {}

  PjRtDevice* device() { return device_; }
  uint32_t consumer_id() const { return consumer_id_; }
  Shape shape() const { return shape_; }
  std::unique_ptr<Literal> literal() {
    CHECK(literal_);
    return std::move(literal_);
  }

  void SetLiteral(std::unique_ptr<Literal> literal);

  ssize_t literal_size_bytes() const { return literal_size_bytes_; }

  std::string DebugString() const;

 private:
  PjRtDevice* device_;
  uint32_t consumer_id_;
  Shape shape_;
  std::unique_ptr<Literal> literal_;
  ssize_t literal_size_bytes_;
};

void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
  literal_ = std::move(literal);
  shape_ = literal_->shape();
  int total_size_bytes = 0;
  ShapeUtil::ForEachSubshape(
      shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) {
        if (!literal_subshape.IsTuple()) {
          total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8);
        }
      });
  literal_size_bytes_ = total_size_bytes;
}

std::string OutfeedData::DebugString() const {
  return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
                         consumer_id_, shape_.ToString());
}

class OutfeedReceiverImpl {
 public:
  OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
                      absl::Span<PjRtClient* const> clients,
                      ssize_t max_callback_queue_size_bytes);

  OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
  OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete;

  // Blocks until all data has been received from devices and all data
  // in the queue has been passed to Python.
  ~OutfeedReceiverImpl();

  void Start();

  StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
                                      uint32_t consumer_id,
                                      std::vector<XlaOp> arrays);

 private:
  bool CallbackQueueHasSpace() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return callback_queue_size_bytes_ < max_callback_queue_size_bytes_;
  }

  bool ShutdownDone() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0);
  }

  void CallbackThreadLoop(int device_idx);
  void DeviceListenerThreadLoop(int device_idx);

  // Enqueues to a device an outfeed operation with a shutdown consumer ID.
  Status SendShutdownOutfeedHeader(int device_idx);

  // Receives a raw Literal from a device outfeed.
  StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(PjRtDevice* device,
                                                           const Shape& shape);

  // Enqueues received data in the callback queue.
  void EnqueueReceivedData(uint32_t device_idx,
                           std::unique_ptr<OutfeedData> received)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Shuts down the threads. See implementation notes at top of file.
  // It is not safe to restart an OutfeedReceiver after shutting down one.
  void Shutdown();

  OutfeedReceiver::Callback callback_;
  // The devices on which we are listening.
  std::vector<PjRtDevice*> devices_;
  // Maximum bytes capacity of the ensemble of callback queues.
  uint64_t max_callback_queue_size_bytes_;

  absl::Mutex mu_;
  // Registered shapes by consumer id.
  // The shape registry must be alive as long as the program exists.
  // Right now we tell the user to never restart after Shutdown.
  absl::flat_hash_map<uint32_t, Shape> shape_registry_ ABSL_GUARDED_BY(mu_);
  // How many bytes of Literal are in the ensemble of callback queues.
  uint64_t callback_queue_size_bytes_ ABSL_GUARDED_BY(mu_);
  // Threads listening.
  int num_listening_threads_ ABSL_GUARDED_BY(mu_);
  bool shutdown_started_ ABSL_GUARDED_BY(mu_);

  // How many callback threads are still working. Used for shutdown.
  int num_working_callback_threads_ ABSL_GUARDED_BY(mu_);

  std::vector<std::queue<std::unique_ptr<OutfeedData>>> callback_queues_
      ABSL_GUARDED_BY(mu_);
  // The threadpool must come last to ensure the queue exists
  // when the pool destructor is called.
  std::unique_ptr<tensorflow::thread::ThreadPool> threads_;
};

OutfeedReceiverImpl::OutfeedReceiverImpl(
    OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
    ssize_t max_callback_queue_size_bytes) {
  callback_ = callback;
  max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
  for (const auto& client : clients) {
    for (auto device : client->addressable_devices()) {
      devices_.push_back(device);
    }
  }
  CHECK_GT(devices_.size(), 0);
  callback_queues_ =
      std::vector<std::queue<std::unique_ptr<OutfeedData>>>(devices_.size());

  callback_queue_size_bytes_ = 0;
  num_listening_threads_ = 0;
  num_working_callback_threads_ = 0;
  shutdown_started_ = false;
}

void OutfeedReceiverImpl::Start() {
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
  }

  int num_threads = 2 * devices_.size();
  threads_ = std::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "outfeed_receiver", num_threads);
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    threads_->Schedule(
        [this, device_idx]() { DeviceListenerThreadLoop(device_idx); });
    threads_->Schedule(
        [this, device_idx]() { CallbackThreadLoop(device_idx); });
  }
}

void OutfeedReceiverImpl::Shutdown() {
  VLOG(2) << "Shutdown start";
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
    shutdown_started_ = true;
  }
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    CHECK(SendShutdownOutfeedHeader(device_idx).ok());
  }
  VLOG(2) << "Shutdown waiting for listening and callback threads to stop";
  absl::MutexLock lock(&mu_);
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone));
  VLOG(2) << "Shutdown done";
}

OutfeedReceiverImpl::~OutfeedReceiverImpl() {
  VLOG(2) << "~OutfeedReceiverImpl";
  Shutdown();
}

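// Listener thread for one device: reads the u32[2] header from the outfeed,
// looks up the registered shape for the consumer ID, reads the payload, and
// enqueues it for the callback thread. Returns when it receives the shutdown
// header.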
void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
  {
    absl::MutexLock lock(&mu_);
    ++num_listening_threads_;
  }
  PjRtDevice* device = devices_[device_idx];
  while (true) {
    Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
    std::unique_ptr<Literal> header =
        ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
    absl::Span<uint32_t> header_data = header->data<uint32_t>();
    CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
    CHECK_EQ(header_data[0], kOutfeedHeaderStart);
    uint32_t consumer_id = header_data[1];
    Shape shape;
    {
      absl::MutexLock lock(&mu_);
      auto registered_shape = shape_registry_.find(consumer_id);
      if (registered_shape == shape_registry_.end()) {
        LOG(FATAL)
            << "[" << device->DebugString()
            << "] Cannot find registered shape for consumer ID " << consumer_id
            << ". Perhaps the code was compiled with a different instance "
            << "of OutfeedReceiver.";
      }
      shape = registered_shape->second;
    }
    auto received = std::make_unique<OutfeedData>(device, consumer_id, shape);
    VLOG(2) << "Listener received header " << received->DebugString();
    if (consumer_id == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Listener received shutdown header";
      absl::MutexLock lock(&mu_);
      --num_listening_threads_;
      VLOG(2) << "[" << device->DebugString() << "] Enqueue shutdown callback";
      EnqueueReceivedData(device_idx, std::move(received));
      return;
    }
    std::unique_ptr<Literal> data =
        ReceiveRawFromOutfeed(device, shape).ValueOrDie();
    received->SetLiteral(std::move(data));
    absl::MutexLock lock(&mu_);
    EnqueueReceivedData(device_idx, std::move(received));
  }
}

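// Implements the back pressure described in the notes at the top of the file:
// waits on mu_ until the total bytes queued for callbacks drop below
// max_callback_queue_size_bytes_, then appends `received` to the device's
// queue.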
void OutfeedReceiverImpl::EnqueueReceivedData(
    uint32_t device_idx, std::unique_ptr<OutfeedData> received)
    ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace));
  ssize_t literal_size_bytes = received->literal_size_bytes();
  callback_queue_size_bytes_ += literal_size_bytes;
  VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size "
          << literal_size_bytes << " bytes; "
          << (1 + callback_queues_[device_idx].size())
          << " callbacks in queue of total size " << callback_queue_size_bytes_
          << " bytes.\n";
  callback_queues_[device_idx].push(std::move(received));
}

StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
    PjRtDevice* device, const Shape& shape) {
  auto literal = std::make_unique<Literal>(shape);
  TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
  return literal;
}

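// Callback thread for one device: dequeues received outfeeds and invokes the
// registered callback outside the lock. Returns when it dequeues the shutdown
// marker.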
void OutfeedReceiverImpl::CallbackThreadLoop(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  {
    absl::MutexLock lock(&mu_);
    num_working_callback_threads_++;
  }
  while (true) {
    std::unique_ptr<OutfeedData> received;
    {
      absl::MutexLock lock(&mu_);
      mu_.Await(absl::Condition(
          +[](std::queue<std::unique_ptr<OutfeedData>>* queue) {
            return !queue->empty();
          },
          &callback_queues_[device_idx]));
      received = std::move(callback_queues_[device_idx].front());
      callback_queues_[device_idx].pop();
      callback_queue_size_bytes_ -= received->literal_size_bytes();
      VLOG(2) << "[" << device->DebugString() << "] Dequeued callback for "
              << received->DebugString() << "; "
              << callback_queues_[device_idx].size()
              << " callbacks in queue of total size "
              << callback_queue_size_bytes_ << " bytes.\n";
    }
    if (received->consumer_id() == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Callback loop received shutdown signal";
      {
        absl::MutexLock lock(&mu_);
        CHECK(callback_queues_[device_idx].empty());
        --num_working_callback_threads_;
      }
      VLOG(2) << "[" << device->DebugString() << "] Callback loop done";
      return;
    }
    {
      tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
      callback_(received->device(), received->consumer_id(),
                received->literal());
    }
  }
}

Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  constexpr int consumer_id = kOutfeedCidShutdown;
  VLOG(2) << "[" << device->DebugString()
          << "] SendSpecialHeader cons=" << consumer_id;
  XlaBuilder builder(
      absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx));
  XlaOp send =
      AddOutfeedToBuilder(&builder, CreateToken(&builder), consumer_id, {})
          .ValueOrDie();
  XlaComputation computation = builder.Build(send).ValueOrDie();

  CompileOptions compile_options;
  compile_options.executable_build_options.set_num_replicas(1);
  compile_options.executable_build_options.set_num_partitions(1);
  DeviceAssignment device_assignment(1, 1);
  device_assignment(0, 0) = device->id();
  compile_options.executable_build_options.set_device_assignment(
      device_assignment);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
                      devices_[device_idx]->client()->Compile(
                          computation, std::move(compile_options)));
  ExecuteOptions execute_options;
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
      executable->Execute({{}}, execute_options));
  return OkStatus();
}

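// Registers (or checks) the shape for `consumer_id`, then appends outfeed ops
// to the builder, assigned to device 0 of the computation: the u32[2] header,
// followed by the data tuple unless this is the shutdown header.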
StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  XlaOp data = Tuple(builder, std::move(arrays));
  Shape shape_with_layout = builder->GetShape(data).ValueOrDie();
  ShapeUtil::ForEachMutableSubshape(
      &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
        if (!subshape->has_layout()) {
          LayoutUtil::SetToDefaultLayout(subshape);
        }
      });
  VLOG(2) << "RegisterShape cons=" << consumer_id
          << "; shape=" << shape_with_layout.ToString();
  {
    absl::MutexLock lock(&mu_);
    auto found = shape_registry_.find(consumer_id);
    if (found != shape_registry_.end()) {
      if (!ShapeUtil::Equal(shape_with_layout, found->second)) {
        return InvalidArgument(
            "Shape %s does not match previous shape %s used "
            "for consumer id %d",
            shape_with_layout.DebugString(), found->second.DebugString(),
            consumer_id);
      }
    } else {
      shape_registry_.insert({consumer_id, shape_with_layout});
    }
  }

  std::vector<uint32_t> header{kOutfeedHeaderStart, consumer_id};
  XlaOp header_op = ConstantR1<uint32_t>(builder, header);
  // We assign the outfeed to the first device. This must match the sharding
  // for the paired infeed.
  builder->SetSharding(sharding_builder::AssignDevice(0));
  token = OutfeedWithToken(
      header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), "");
  if (consumer_id != kOutfeedCidShutdown) {
    token = OutfeedWithToken(data, token, shape_with_layout, "");
  }
  builder->ClearSharding();
  return token;
}

OutfeedReceiver::OutfeedReceiver(Callback callback,
                                 absl::Span<PjRtClient* const> clients,
                                 ssize_t max_callback_queue_size_bytes) {
  p_impl_ = std::make_unique<OutfeedReceiverImpl>(
      callback, clients, max_callback_queue_size_bytes);
}

OutfeedReceiver::~OutfeedReceiver() {}

void OutfeedReceiver::Start() { p_impl_->Start(); }

StatusOr<XlaOp> OutfeedReceiver::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  if (consumer_id == kOutfeedCidShutdown) {
    return InvalidArgument("Consumer ID cannot be a reserved value: %d",
                           consumer_id);
  }
  return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays);
}

}  // namespace xla