1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
17
18 #include <memory>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/synchronization/mutex.h"
22 #include "pybind11/functional.h"
23 #include "pybind11/pybind11.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
26 #include "tensorflow/compiler/xla/python/outfeed_receiver.h"
27 #include "tensorflow/compiler/xla/python/py_client.h"
28 #include "tensorflow/compiler/xla/python/types.h"
29
30 namespace xla {
31
32 namespace py = pybind11;
33
34 namespace {
35
36 // A wrapper for OutfeedReceiver for use from Python, useful for ensuring
37 // that the GIL is released before destroying the OutfeedReceiver.
38 class OutfeedReceiverForPython {
39 public:
40 // A callback to Python takes: consumer id, received literal.
41 using CallbackToPython =
42 std::function<void(ClientAndPtr<PjRtDevice>, uint32_t, pybind11::object)>;
43
OutfeedReceiverForPython(CallbackToPython callback_python,std::vector<std::shared_ptr<PyClient>> clients,ssize_t max_callback_queue_size_bytes)44 OutfeedReceiverForPython(CallbackToPython callback_python,
45 std::vector<std::shared_ptr<PyClient>> clients,
46 ssize_t max_callback_queue_size_bytes)
47 : callback_python_(std::move(callback_python)),
48 clients_(std::move(clients)) {
49 OutfeedReceiver::Callback callback =
50 [this](PjRtDevice* device, uint32_t consumer_id,
51 std::shared_ptr<Literal> literal) {
52 this->Callback(device, consumer_id, std::move(literal));
53 };
54 std::vector<PjRtClient*> client_ptrs(clients_.size());
55 absl::c_transform(clients_, client_ptrs.begin(),
56 [](const std::shared_ptr<PyClient>& client) {
57 return client->pjrt_client();
58 });
59 outfeed_receiver_ = std::make_unique<OutfeedReceiver>(
60 callback, client_ptrs, max_callback_queue_size_bytes);
61 }
62 OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
63 OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;
64
~OutfeedReceiverForPython()65 ~OutfeedReceiverForPython() {
66 // This destructor is called from the Python GC. Release it for the duration
67 // of the destruction, including the destruction of the OutfeedReceiver,
68 // when we may actually have to wait for threads to end. During this time
69 // we do not callback to Python (sometimes we get an exception
70 // "std::runtime_error: scoped_acquire::dec_ref(): thread state must
71 // be current!"").
72 {
73 absl::MutexLock lock(&mu_);
74 outfeed_receiver_shutting_down_ = true;
75 }
76 py::gil_scoped_release gil_release;
77 outfeed_receiver_ = nullptr; // Shutdown the outfeed receiver.
78 }
79
Start()80 void Start() { outfeed_receiver_->Start(); }
81
AddOutfeed(XlaBuilder * builder,XlaOp token,uint32_t consumer_id,std::vector<XlaOp> arrays)82 StatusOr<XlaOp> AddOutfeed(XlaBuilder* builder, XlaOp token,
83 uint32_t consumer_id, std::vector<XlaOp> arrays) {
84 return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id,
85 arrays);
86 }
87
Callback(PjRtDevice * device,uint32_t consumer_id,std::shared_ptr<Literal> literal)88 void Callback(PjRtDevice* device, uint32_t consumer_id,
89 std::shared_ptr<Literal> literal) {
90 {
91 absl::MutexLock lock(&mu_);
92 if (outfeed_receiver_shutting_down_) {
93 VLOG(2) << "Ignoring unsafe callback to Python during shutdown";
94 return;
95 }
96 }
97 // We expect the number of clients to be small, so an O(n) search is fine.
98 auto it = absl::c_find_if(
99 clients_, [device](const std::shared_ptr<PyClient>& client) {
100 return client->pjrt_client() == device->client();
101 });
102 CHECK(it != clients_.end());
103 py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
104 py::object literal_python =
105 LiteralToPython(std::move(literal)).ValueOrDie();
106 // The callback_ should handle all exceptions in user-code. If we get
107 // an exception here, it is a bug in the callback and we should stop.
108 callback_python_(WrapWithClient<PjRtDevice>(*it, device), consumer_id,
109 std::move(literal_python));
110 }
111
112 private:
113 CallbackToPython callback_python_;
114 absl::Mutex mu_;
115 bool outfeed_receiver_shutting_down_ ABSL_GUARDED_BY(mu_) = false;
116 std::vector<std::shared_ptr<PyClient>> clients_;
117 std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
118 };
119
120 } // namespace
121
BuildOutfeedReceiverSubmodule(py::module * m)122 void BuildOutfeedReceiverSubmodule(py::module* m) {
123 py::module outfeed_receiver =
124 m->def_submodule("outfeed_receiver", "Outfeed receiver");
125 outfeed_receiver.def(
126 "start",
127 [](OutfeedReceiverForPython::CallbackToPython callback_to_python,
128 std::vector<std::shared_ptr<PyClient>> clients,
129 ssize_t max_callback_queue_size_bytes)
130 -> std::unique_ptr<OutfeedReceiverForPython> {
131 auto server = std::make_unique<OutfeedReceiverForPython>(
132 callback_to_python, clients, max_callback_queue_size_bytes);
133 server->Start();
134 return server;
135 },
136 py::arg("callback_to_python"), py::arg("backends"),
137 py::arg("max_queue_size_bytes") = 256 * 1024 * 1024,
138 R"(Starts a multithreaded outfeed receiver.
139
140 There is one thread for each of the specified devices. When Python
141 drops the last reference to the returned object, the receiver is shut
142 down. The destructor will block until all data is received from
143 devices.
144
145 Args:
146 * callback_to_python: a Python callback to call, with <consumer_id>
147 and the data received.
148 * backends: the list of backends to listen on.
149 * max_queue_size_bytes: an optional integer to bound the maximum size
150 of arrays in the callback queue. When this limit is reached the
151 device listener pauses.
152 )",
153 py::call_guard<py::gil_scoped_release>());
154
155 py::class_<OutfeedReceiverForPython> outfeed_receiver_class(
156 outfeed_receiver, "OutfeedReceiverForPython");
157
158 outfeed_receiver_class.def(
159 "add_outfeed", &OutfeedReceiverForPython::AddOutfeed, py::arg("builder"),
160 py::arg("token"), py::arg("consumer_id"), py::arg("arrays"),
161 R"(Adds an outfeed into the given computation builder.
162
163 Has the side-effect of registering the sent shape along with the consumer
164 ID. Returns error if the outfeed shape is not compatible with previously
165 used shape for the same consumer ID.)",
166 py::call_guard<py::gil_scoped_release>());
167 }
168
169 } // namespace xla
170