xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/outfeed_receiver_py.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
17 
18 #include <memory>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/synchronization/mutex.h"
22 #include "pybind11/functional.h"
23 #include "pybind11/pybind11.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
26 #include "tensorflow/compiler/xla/python/outfeed_receiver.h"
27 #include "tensorflow/compiler/xla/python/py_client.h"
28 #include "tensorflow/compiler/xla/python/types.h"
29 
30 namespace xla {
31 
32 namespace py = pybind11;
33 
34 namespace {
35 
36 // A wrapper for OutfeedReceiver for use from Python, useful for ensuring
37 // that the GIL is released before destroying the OutfeedReceiver.
38 class OutfeedReceiverForPython {
39  public:
40   // A callback to Python takes: consumer id, received literal.
41   using CallbackToPython =
42       std::function<void(ClientAndPtr<PjRtDevice>, uint32_t, pybind11::object)>;
43 
OutfeedReceiverForPython(CallbackToPython callback_python,std::vector<std::shared_ptr<PyClient>> clients,ssize_t max_callback_queue_size_bytes)44   OutfeedReceiverForPython(CallbackToPython callback_python,
45                            std::vector<std::shared_ptr<PyClient>> clients,
46                            ssize_t max_callback_queue_size_bytes)
47       : callback_python_(std::move(callback_python)),
48         clients_(std::move(clients)) {
49     OutfeedReceiver::Callback callback =
50         [this](PjRtDevice* device, uint32_t consumer_id,
51                std::shared_ptr<Literal> literal) {
52           this->Callback(device, consumer_id, std::move(literal));
53         };
54     std::vector<PjRtClient*> client_ptrs(clients_.size());
55     absl::c_transform(clients_, client_ptrs.begin(),
56                       [](const std::shared_ptr<PyClient>& client) {
57                         return client->pjrt_client();
58                       });
59     outfeed_receiver_ = std::make_unique<OutfeedReceiver>(
60         callback, client_ptrs, max_callback_queue_size_bytes);
61   }
62   OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
63   OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;
64 
~OutfeedReceiverForPython()65   ~OutfeedReceiverForPython() {
66     // This destructor is called from the Python GC. Release it for the duration
67     // of the destruction, including the destruction of the OutfeedReceiver,
68     // when we may actually have to wait for threads to end. During this time
69     // we do not callback to Python (sometimes we get an exception
70     // "std::runtime_error: scoped_acquire::dec_ref(): thread state must
71     // be current!"").
72     {
73       absl::MutexLock lock(&mu_);
74       outfeed_receiver_shutting_down_ = true;
75     }
76     py::gil_scoped_release gil_release;
77     outfeed_receiver_ = nullptr;  // Shutdown the outfeed receiver.
78   }
79 
Start()80   void Start() { outfeed_receiver_->Start(); }
81 
AddOutfeed(XlaBuilder * builder,XlaOp token,uint32_t consumer_id,std::vector<XlaOp> arrays)82   StatusOr<XlaOp> AddOutfeed(XlaBuilder* builder, XlaOp token,
83                              uint32_t consumer_id, std::vector<XlaOp> arrays) {
84     return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id,
85                                                   arrays);
86   }
87 
Callback(PjRtDevice * device,uint32_t consumer_id,std::shared_ptr<Literal> literal)88   void Callback(PjRtDevice* device, uint32_t consumer_id,
89                 std::shared_ptr<Literal> literal) {
90     {
91       absl::MutexLock lock(&mu_);
92       if (outfeed_receiver_shutting_down_) {
93         VLOG(2) << "Ignoring unsafe callback to Python during shutdown";
94         return;
95       }
96     }
97     // We expect the number of clients to be small, so an O(n) search is fine.
98     auto it = absl::c_find_if(
99         clients_, [device](const std::shared_ptr<PyClient>& client) {
100           return client->pjrt_client() == device->client();
101         });
102     CHECK(it != clients_.end());
103     py::gil_scoped_acquire gil_acquire;  // Need GIL also for LiteralToPython
104     py::object literal_python =
105         LiteralToPython(std::move(literal)).ValueOrDie();
106     // The callback_ should handle all exceptions in user-code. If we get
107     // an exception here, it is a bug in the callback and we should stop.
108     callback_python_(WrapWithClient<PjRtDevice>(*it, device), consumer_id,
109                      std::move(literal_python));
110   }
111 
112  private:
113   CallbackToPython callback_python_;
114   absl::Mutex mu_;
115   bool outfeed_receiver_shutting_down_ ABSL_GUARDED_BY(mu_) = false;
116   std::vector<std::shared_ptr<PyClient>> clients_;
117   std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
118 };
119 
120 }  // namespace
121 
BuildOutfeedReceiverSubmodule(py::module * m)122 void BuildOutfeedReceiverSubmodule(py::module* m) {
123   py::module outfeed_receiver =
124       m->def_submodule("outfeed_receiver", "Outfeed receiver");
125   outfeed_receiver.def(
126       "start",
127       [](OutfeedReceiverForPython::CallbackToPython callback_to_python,
128          std::vector<std::shared_ptr<PyClient>> clients,
129          ssize_t max_callback_queue_size_bytes)
130           -> std::unique_ptr<OutfeedReceiverForPython> {
131         auto server = std::make_unique<OutfeedReceiverForPython>(
132             callback_to_python, clients, max_callback_queue_size_bytes);
133         server->Start();
134         return server;
135       },
136       py::arg("callback_to_python"), py::arg("backends"),
137       py::arg("max_queue_size_bytes") = 256 * 1024 * 1024,
138       R"(Starts a multithreaded outfeed receiver.
139 
140       There is one thread for each of the specified devices. When Python
141       drops the last reference to the returned object, the receiver is shut
142       down. The destructor will block until all data is received from
143       devices.
144 
145       Args:
146         * callback_to_python: a Python callback to call, with <consumer_id>
147           and the data received.
148         * backends: the list of backends to listen on.
149         * max_queue_size_bytes: an optional integer to bound the maximum size
150             of arrays in the callback queue. When this limit is reached the
151             device listener pauses.
152       )",
153       py::call_guard<py::gil_scoped_release>());
154 
155   py::class_<OutfeedReceiverForPython> outfeed_receiver_class(
156       outfeed_receiver, "OutfeedReceiverForPython");
157 
158   outfeed_receiver_class.def(
159       "add_outfeed", &OutfeedReceiverForPython::AddOutfeed, py::arg("builder"),
160       py::arg("token"), py::arg("consumer_id"), py::arg("arrays"),
161       R"(Adds an outfeed into the given computation builder.
162 
163       Has the side-effect of registering the sent shape along with the consumer
164       ID. Returns error if the outfeed shape is not compatible with previously
165       used shape for the same consumer ID.)",
166       py::call_guard<py::gil_scoped_release>());
167 }
168 
169 }  // namespace xla
170