xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/outfeed_receiver.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/outfeed_receiver.h"

#include <sys/types.h>

#include <memory>
#include <queue>
#include <sstream>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/profiler/lib/traceme.h"

// Implementation notes:
//
// Startup:
// -------
//
// The startup is initiated by a call from Python to StartOutfeedReceiver. For
// each local device there is one thread for listening for outfeeds from the
// device, one queue of received outfeeds, and one thread for invoking the
// Python callbacks.
//
// Framing protocol
// ----------------
//
// The outfeed mechanism has a single channel, and the receiver must know
// exactly the shape and number of outfeed operations issued by the compiled
// code. This makes it hard to use outfeed in conditionals and loops,
// especially when outfeeding different-shaped data.
//
// To address this, when we compile the code we capture the shape of the
// data being outfed, and we generate a consumer ID (uint32_t) that is unique
// across the lifetime of the program for each combination of: the Python
// callable to call back to, the shape of the arguments, and the keyword
// arguments to pass to the callable.
// Each outfeed payload is preceded by a header (of shape u32[2]) with a
// special first value and the consumer ID. We maintain a registry of shapes
// keyed by consumer ID. When receiving, we look up the shape by consumer ID
// and then read the payload.
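//
// For example (illustrative values): an outfeed registered under consumer
// ID 42 for an f32[3] payload arrives on the channel as two consecutive
// transfers:
//
//   u32[2] = {271828, 42}     // header: kOutfeedHeaderStart, consumer ID
//   f32[3] = {0.0, 1.0, 2.0}  // payload; its shape is looked up in the
//                             // registry under consumer ID 42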
//
// Back pressure:
// --------------
//
// We maintain a sum of the bytes from all the data waiting in the callback
// queues. The listening threads will wait for the sum to drop below a
// configurable threshold (256 MB by default). While the listening thread is
// waiting, on CPU and GPU the next outfeed operation from the device will
// block. On TPU there is a buffer, but eventually the TPU will also block.
//
// Shutdown:
// ---------
//
// The shutdown is initiated automatically when the last reference to the
// outfeed receiver object is dropped and the Python garbage collector invokes
// the destructor.
//
// The shutdown sequence is implemented as follows:
// * we enqueue on all devices a computation that outfeeds a special header
//   with consumer ID kOutfeedCidShutdown.
// * when each listening thread gets the shutdown header, it decrements the
//   counter of listening threads and enqueues a special shutdown callback.
// * when each callback thread gets the shutdown callback marker, it terminates.
// * the shutdown code waits until all threads terminate.
//
// Since we currently keep the shape registry in the OutfeedReceiver, it is
// not safe to replace the OutfeedReceiver instance during the lifetime of
// the JAX program, or else previously cached jitted computations may refer
// to consumer IDs whose shapes are registered only in the old instance.
// This can be solved, but for now we disallow replacing the OutfeedReceiver,
// and do not provide a Shutdown API to the Python program.
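//
// Usage sketch (illustrative; in practice this class is driven by the Python
// bindings, and the exact Callback signature is declared in
// outfeed_receiver.h):
//
//   OutfeedReceiver::Callback callback = ...;  // invoked with (device,
//                                              // consumer_id, literal)
//   auto receiver = std::make_unique<OutfeedReceiver>(
//       callback, clients,
//       /*max_callback_queue_size_bytes=*/256 * 1024 * 1024);
//   receiver->Start();
//   // While building a computation that should outfeed to the callback:
//   TF_ASSIGN_OR_RETURN(
//       XlaOp token,
//       receiver->AddOutfeedToBuilder(&builder, CreateToken(&builder),
//                                     consumer_id, arrays));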

namespace xla {

// The header contains:
// 0. kOutfeedHeaderStart
// 1. consumer id
int constexpr kOutfeedHeaderWords = 2;
uint32_t constexpr kOutfeedHeaderStart = 271828;
// Special consumer IDs, without outfeed payload.
uint32_t constexpr kOutfeedCidShutdown = 0;

// Encapsulates data received from a device outfeed.
class OutfeedData {
 public:
  OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape)
      : device_(device),
        consumer_id_(consumer_id),
        shape_(shape),
        literal_(nullptr),
        literal_size_bytes_(0) {}

  PjRtDevice* device() { return device_; }
  uint32_t consumer_id() const { return consumer_id_; }
  Shape shape() const { return shape_; }
  std::unique_ptr<Literal> literal() {
    CHECK(literal_);
    return std::move(literal_);
  }

  void SetLiteral(std::unique_ptr<Literal> literal);

  ssize_t literal_size_bytes() const { return literal_size_bytes_; }

  std::string DebugString() const;

 private:
  PjRtDevice* device_;
  uint32_t consumer_id_;
  Shape shape_;
  std::unique_ptr<Literal> literal_;
  ssize_t literal_size_bytes_;
};

void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
  literal_ = std::move(literal);
  shape_ = literal_->shape();
  int total_size_bytes = 0;
  ShapeUtil::ForEachSubshape(
      shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) {
        if (!literal_subshape.IsTuple()) {
          total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8);
        }
      });
  literal_size_bytes_ = total_size_bytes;
}

std::string OutfeedData::DebugString() const {
  return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
                         consumer_id_, shape_.ToString());
}

class OutfeedReceiverImpl {
 public:
  OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
                      absl::Span<PjRtClient* const> clients,
                      ssize_t max_callback_queue_size_bytes);

  OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
  OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete;

  // Blocks until all data has been received from devices and all data
  // in the queue has been passed to Python.
  ~OutfeedReceiverImpl();

  void Start();

  StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
                                      uint32_t consumer_id,
                                      std::vector<XlaOp> arrays);

 private:
  bool CallbackQueueHasSpace() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return callback_queue_size_bytes_ < max_callback_queue_size_bytes_;
  }

  bool ShutdownDone() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0);
  }

  void CallbackThreadLoop(int device_idx);
  void DeviceListenerThreadLoop(int device_idx);

  // Enqueues to a device an outfeed operation with a shutdown consumer ID.
  Status SendShutdownOutfeedHeader(int device_idx);

  // Receives a raw Literal from a device outfeed.
  StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(PjRtDevice* device,
                                                           const Shape& shape);

  // Enqueues received data in the callback queue.
  void EnqueueReceivedData(uint32_t device_idx,
                           std::unique_ptr<OutfeedData> received)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Shuts down the threads. See implementation notes at top of file.
  // It is not safe to restart an OutfeedReceiver after shutting down one.
  void Shutdown();

  OutfeedReceiver::Callback callback_;
  // The devices on which we are listening.
  std::vector<PjRtDevice*> devices_;
  // Maximum bytes capacity of the ensemble of callback queues.
  uint64_t max_callback_queue_size_bytes_;

  absl::Mutex mu_;
  // Registered shapes by consumer id.
  // The shape registry must be alive as long as the program exists.
  // Right now we tell the user to never restart after Shutdown.
  absl::flat_hash_map<uint32_t, Shape> shape_registry_ ABSL_GUARDED_BY(mu_);
  // How many bytes of Literal are in the ensemble of callback queues.
  uint64_t callback_queue_size_bytes_ ABSL_GUARDED_BY(mu_);
  // Threads listening.
  int num_listening_threads_ ABSL_GUARDED_BY(mu_);
  bool shutdown_started_ ABSL_GUARDED_BY(mu_);

  // How many callback threads are still working. Used for shutdown.
  int num_working_callback_threads_ ABSL_GUARDED_BY(mu_);

  std::vector<std::queue<std::unique_ptr<OutfeedData>>> callback_queues_
      ABSL_GUARDED_BY(mu_);
  // The threadpool must come last to ensure the queue exists
  // when the pool destructor is called.
  std::unique_ptr<tensorflow::thread::ThreadPool> threads_;
};

OutfeedReceiverImpl::OutfeedReceiverImpl(
    OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
    ssize_t max_callback_queue_size_bytes) {
  callback_ = callback;
  max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
  for (const auto& client : clients) {
    for (auto device : client->addressable_devices()) {
      devices_.push_back(device);
    }
  }
  CHECK_GT(devices_.size(), 0);
  callback_queues_ =
      std::vector<std::queue<std::unique_ptr<OutfeedData>>>(devices_.size());

  callback_queue_size_bytes_ = 0;
  num_listening_threads_ = 0;
  num_working_callback_threads_ = 0;
  shutdown_started_ = false;
}

void OutfeedReceiverImpl::Start() {
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
  }

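  // One listener thread and one callback thread per device: the listener
  // pulls outfeeds from the device, and the callback thread drains that
  // device's queue and invokes the Python callback.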
  int num_threads = 2 * devices_.size();
  threads_ = std::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "outfeed_receiver", num_threads);
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    threads_->Schedule(
        [this, device_idx]() { DeviceListenerThreadLoop(device_idx); });
    threads_->Schedule(
        [this, device_idx]() { CallbackThreadLoop(device_idx); });
  }
}

void OutfeedReceiverImpl::Shutdown() {
  VLOG(2) << "Shutdown start";
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
    shutdown_started_ = true;
  }
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    CHECK(SendShutdownOutfeedHeader(device_idx).ok());
  }
  VLOG(2) << "Shutdown waiting for listening and callback threads to stop";
  absl::MutexLock lock(&mu_);
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone));
  VLOG(2) << "Shutdown done";
}

OutfeedReceiverImpl::~OutfeedReceiverImpl() {
  VLOG(2) << "~OutfeedReceiverImpl";
  Shutdown();
}

void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
  {
    absl::MutexLock lock(&mu_);
    ++num_listening_threads_;
  }
  PjRtDevice* device = devices_[device_idx];
  while (true) {
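    // Each transfer begins with a u32[2] header: kOutfeedHeaderStart followed
    // by the consumer ID that identifies the payload shape in the registry.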
    Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
    std::unique_ptr<Literal> header =
        ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
    absl::Span<uint32_t> header_data = header->data<uint32_t>();
    CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
    CHECK_EQ(header_data[0], kOutfeedHeaderStart);
    uint32_t consumer_id = header_data[1];
    Shape shape;
    {
      absl::MutexLock lock(&mu_);
      auto registered_shape = shape_registry_.find(consumer_id);
      if (registered_shape == shape_registry_.end()) {
        LOG(FATAL)
            << "[" << device->DebugString()
            << "] Cannot find registered shape for consumer ID " << consumer_id
            << ". Perhaps the code was compiled with a different instance "
            << "of OutfeedReceiver.";
      }
      shape = registered_shape->second;
    }
    auto received = std::make_unique<OutfeedData>(device, consumer_id, shape);
    VLOG(2) << "Listener received header " << received->DebugString();
    if (consumer_id == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Listener received shutdown header";
      absl::MutexLock lock(&mu_);
      --num_listening_threads_;
      VLOG(2) << "[" << device->DebugString() << "] Enqueue shutdown callback";
      EnqueueReceivedData(device_idx, std::move(received));
      return;
    }
    std::unique_ptr<Literal> data =
        ReceiveRawFromOutfeed(device, shape).ValueOrDie();
    received->SetLiteral(std::move(data));
    absl::MutexLock lock(&mu_);
    EnqueueReceivedData(device_idx, std::move(received));
  }
}

void OutfeedReceiverImpl::EnqueueReceivedData(
    uint32_t device_idx, std::unique_ptr<OutfeedData> received)
    ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
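  // Back pressure: block this listening thread until the total size of queued
  // callbacks drops below max_callback_queue_size_bytes_.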
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace));
  ssize_t literal_size_bytes = received->literal_size_bytes();
  callback_queue_size_bytes_ += literal_size_bytes;
  VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size "
          << literal_size_bytes << " bytes; "
          << (1 + callback_queues_[device_idx].size())
          << " callbacks in queue of total size " << callback_queue_size_bytes_
          << " bytes.\n";
  callback_queues_[device_idx].push(std::move(received));
}

StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
    PjRtDevice* device, const Shape& shape) {
  auto literal = std::make_unique<Literal>(shape);
  TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
  return literal;
}

void OutfeedReceiverImpl::CallbackThreadLoop(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  {
    absl::MutexLock lock(&mu_);
    num_working_callback_threads_++;
  }
  while (true) {
    std::unique_ptr<OutfeedData> received;
    {
      absl::MutexLock lock(&mu_);
      mu_.Await(absl::Condition(
          +[](std::queue<std::unique_ptr<OutfeedData>>* queue) {
            return !queue->empty();
          },
          &callback_queues_[device_idx]));
      received = std::move(callback_queues_[device_idx].front());
      callback_queues_[device_idx].pop();
      callback_queue_size_bytes_ -= received->literal_size_bytes();
      VLOG(2) << "[" << device->DebugString() << "] Dequeued callback for "
              << received->DebugString() << "; "
              << callback_queues_[device_idx].size()
              << " callbacks in queue of total size "
              << callback_queue_size_bytes_ << " bytes.\n";
    }
    if (received->consumer_id() == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Callback loop received shutdown signal";
      {
        absl::MutexLock lock(&mu_);
        CHECK(callback_queues_[device_idx].empty());
        --num_working_callback_threads_;
      }
      VLOG(2) << "[" << device->DebugString() << "] Callback loop done";
      return;
    }
    {
      tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
      callback_(received->device(), received->consumer_id(),
                received->literal());
    }
  }
}

Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  constexpr int consumer_id = kOutfeedCidShutdown;
  VLOG(2) << "[" << device->DebugString()
          << "] SendSpecialHeader cons=" << consumer_id;
  XlaBuilder builder(
      absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx));
  XlaOp send =
      AddOutfeedToBuilder(&builder, CreateToken(&builder), consumer_id, {})
          .ValueOrDie();
  XlaComputation computation = builder.Build(send).ValueOrDie();

  CompileOptions compile_options;
  compile_options.executable_build_options.set_num_replicas(1);
  compile_options.executable_build_options.set_num_partitions(1);
  DeviceAssignment device_assignment(1, 1);
  device_assignment(0, 0) = device->id();
  compile_options.executable_build_options.set_device_assignment(
      device_assignment);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
                      devices_[device_idx]->client()->Compile(
                          computation, std::move(compile_options)));
  ExecuteOptions execute_options;
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
      executable->Execute({{}}, execute_options));
  return OkStatus();
}

StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  XlaOp data = Tuple(builder, std::move(arrays));
  Shape shape_with_layout = builder->GetShape(data).ValueOrDie();
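  // Give every subshape a default layout if it lacks one, so the shape we
  // register (and later compare against) is fully specified.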
  ShapeUtil::ForEachMutableSubshape(
      &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
        if (!subshape->has_layout()) {
          LayoutUtil::SetToDefaultLayout(subshape);
        }
      });
  VLOG(2) << "RegisterShape cons=" << consumer_id
          << "; shape=" << shape_with_layout.ToString();
  {
    absl::MutexLock lock(&mu_);
    auto found = shape_registry_.find(consumer_id);
    if (found != shape_registry_.end()) {
      if (!ShapeUtil::Equal(shape_with_layout, found->second)) {
        return InvalidArgument(
            "Shape %s does not match previous shape %s used "
            "for consumer id %d",
            shape_with_layout.DebugString(), found->second.DebugString(),
            consumer_id);
      }
    } else {
      shape_registry_.insert({consumer_id, shape_with_layout});
    }
  }

  std::vector<uint32_t> header{kOutfeedHeaderStart, consumer_id};
  XlaOp header_op = ConstantR1<uint32_t>(builder, header);
  // We assign the outfeed to the first device. This must match the sharding
  // for the paired infeed.
  builder->SetSharding(sharding_builder::AssignDevice(0));
  token = OutfeedWithToken(
      header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), "");
  if (consumer_id != kOutfeedCidShutdown) {
    token = OutfeedWithToken(data, token, shape_with_layout, "");
  }
  builder->ClearSharding();
  return token;
}

OutfeedReceiver::OutfeedReceiver(Callback callback,
                                 absl::Span<PjRtClient* const> clients,
                                 ssize_t max_callback_queue_size_bytes) {
  p_impl_ = std::make_unique<OutfeedReceiverImpl>(
      callback, clients, max_callback_queue_size_bytes);
}

OutfeedReceiver::~OutfeedReceiver() {}

void OutfeedReceiver::Start() { p_impl_->Start(); }

StatusOr<XlaOp> OutfeedReceiver::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  if (consumer_id == kOutfeedCidShutdown) {
    return InvalidArgument("Consumer ID cannot be a reserved value: %d",
                           consumer_id);
  }
  return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays);
}

}  // namespace xla