1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/py_executable.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/xla/pjrt/host_callback.h"
23 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
24 #include "tensorflow/core/platform/fingerprint.h"
25
26 namespace xla {
27
28 namespace py = pybind11;
29
// Blocks the calling thread until this token's underlying future resolves,
// returning the execution's final status. The GIL is released for the
// duration of the wait so other Python threads can make progress.
Status PyToken::Await() {
  // A token must carry a valid future; awaiting a default-constructed token
  // is a programming error.
  CHECK(future_.IsValid());
  py::gil_scoped_release gil_release;
  return future_.Await();
}
35
Await()36 Status PyShardedToken::Await() {
37 py::gil_scoped_release gil_release;
38 Status status = OkStatus();
39 for (auto& future : futures_) {
40 auto s = future.Await();
41 if (!s.ok()) status = std::move(s);
42 }
43 return status;
44 }
45
// Wraps a loaded PjRt executable for use from Python.
//
// The new object is linked at the head of the owning client's intrusive
// doubly-linked list of live executables (client_->executables_); the GIL
// guards that list, hence the PyGILState_Check. If a fingerprint was
// provided, a deterministic launch id is derived from it via Fingerprint32.
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
                           std::unique_ptr<PjRtLoadedExecutable> executable,
                           std::shared_ptr<Traceback> traceback,
                           std::optional<std::string> fingerprint,
                           std::vector<pybind11::capsule> host_callbacks)
    : client_(std::move(client)),
      executable_(std::move(executable)),
      traceback_(std::move(traceback)),
      fingerprint_(std::move(fingerprint)),
      host_callbacks_(std::move(host_callbacks)) {
  CHECK(PyGILState_Check());
  // Splice this object onto the front of the client's executable list.
  next_ = client_->executables_;
  client_->executables_ = this;
  prev_ = nullptr;
  if (next_) {
    next_->prev_ = this;
  }
  // Tuple results are always flattened into individual buffers.
  options_.untuple_result = true;
  if (fingerprint_) {
    options_.launch_id = tensorflow::Fingerprint32(*fingerprint_);
    VLOG(1) << "Fingerprint for executable " << executable_->name() << ": "
            << *fingerprint_;
  }
}
70
~PyExecutable()71 PyExecutable::~PyExecutable() {
72 CHECK(PyGILState_Check());
73 if (client_->executables_ == this) {
74 client_->executables_ = next_;
75 }
76 if (prev_) {
77 prev_->next_ = next_;
78 }
79 if (next_) {
80 next_->prev_ = prev_;
81 }
82 }
83
AddressableDevices() const84 std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::AddressableDevices() const {
85 std::vector<ClientAndPtr<PjRtDevice>> devices;
86 devices.reserve(executable_->addressable_devices().size());
87 for (PjRtDevice* device : executable_->addressable_devices()) {
88 devices.push_back(WrapWithClient(client_, device));
89 }
90 return devices;
91 }
92
// Runs the executable once on `args`.
//
// If `device` is non-null the executable is run portably on that device via
// ExecutePortable; otherwise Execute is called with a single argument list.
// On entry `returned_futures` is engaged iff the caller wants a completion
// future; this function force-engages it when host callbacks are present,
// because callback state must stay alive until execution completes.
//
// Returns the output buffers paired with a PyToken for the execution's
// completion future, or a ready token when no future was produced.
StatusOr<std::pair<std::vector<PyBuffer::object>, PyToken>>
PyExecutable::ExecuteInternal(
    absl::Span<PyBuffer::object const> args, PjRtDevice* device,
    std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
  std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
  {
    // Copy the options so per-call send/recv callbacks don't leak into
    // subsequent executions.
    auto options = options_;
    std::shared_ptr<HostCallbackStates> host_callback_states;

    if (!host_callbacks_.empty()) {
      auto* host_memory_for_device_manager =
          client()->pjrt_client()->GetPjRtHostMemoryForDeviceManager();
      if (host_memory_for_device_manager == nullptr) {
        return InternalError("Host callback not supported for runtime type: %s",
                             client()->runtime_type());
      }

      // Host callbacks need a completion future (see OnReady below), so
      // request one even if the caller did not.
      returned_futures.emplace();

      // Single computation: one set of contexts/send/recv callbacks.
      host_callback_states = std::make_shared<HostCallbackStates>();
      auto& contexts = host_callback_states->contexts.emplace_back();
      auto& send_callbacks =
          host_callback_states->send_callbacks.emplace_back();
      auto& recv_callbacks =
          host_callback_states->recv_callbacks.emplace_back();

      for (const py::capsule& host_callback : host_callbacks_) {
        contexts.push_back(CreateHostCallbackStateAndAppendSendRecvCallbacks(
            *host_callback.get_pointer<HostCallback>(),
            host_memory_for_device_manager, send_callbacks, recv_callbacks));
      }
      options.send_callbacks = host_callback_states->send_callbacks;
      options.recv_callbacks = host_callback_states->recv_callbacks;
    }

    // Device execution may block; do not hold the GIL while it runs.
    py::gil_scoped_release gil_release;
    std::vector<PjRtBuffer*> arg_buffers(args.size());
    absl::c_transform(
        args, arg_buffers.begin(),
        [](const PyBuffer::object& buf) { return buf.buf()->buffer(); });
    if (device) {
      std::optional<PjRtFuture<Status>> future;
      output_buffers.resize(1);
      TF_ASSIGN_OR_RETURN(
          output_buffers[0],
          executable_->ExecutePortable(arg_buffers, device, options, future,
                                       returned_futures.has_value()));
      if (future) {
        returned_futures->emplace_back(std::move(*future));
      }
    } else {
      TF_ASSIGN_OR_RETURN(
          output_buffers,
          executable_->Execute({arg_buffers}, options, returned_futures));
    }

    if (!host_callbacks_.empty()) {
      // For host callbacks to work, `returned_futures` must not be nullopt.
      // The OnReady callback drops the last reference to the callback state
      // once execution has completed.
      returned_futures->at(0).OnReady([host_callback_states](Status) mutable {
        host_callback_states.reset();
      });
    }
  }
  auto traceback = Traceback::Get();
  std::vector<PyBuffer::object> outputs;
  outputs.reserve(output_buffers[0].size());
  for (auto& buffer : output_buffers[0]) {
    outputs.push_back(PyBuffer::Make(client_, std::move(buffer), traceback));
  }

  // TODO(b/240696624): Although the PjRt interface require `returned_futures`
  // to be resized correctly if it is not nullopt, some implementation does not
  // implement this. So we have to check whether returned_futures is empty.
  // Remove this check once the implementation is fixed.
  if (!returned_futures.has_value()) {
    return std::pair<std::vector<PyBuffer::object>, PyToken>(
        std::move(outputs), PyToken::ReadyPyToken());
  }
  return std::pair<std::vector<PyBuffer::object>, PyToken>(
      std::move(outputs), PyToken(std::move(returned_futures->at(0))));
}
174
175 StatusOr<std::pair<std::vector<PyBuffer::object>, PyToken>>
ExecuteWithToken(absl::Span<PyBuffer::object const> args,PjRtDevice * device)176 PyExecutable::ExecuteWithToken(absl::Span<PyBuffer::object const> args,
177 PjRtDevice* device) {
178 std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
179 if (executable_->IsReturnedFutureSupported()) returned_futures.emplace();
180 return ExecuteInternal(args, device, returned_futures);
181 }
182
Execute(absl::Span<PyBuffer::object const> args,PjRtDevice * device)183 StatusOr<std::vector<PyBuffer::object>> PyExecutable::Execute(
184 absl::Span<PyBuffer::object const> args, PjRtDevice* device) {
185 std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
186 TF_ASSIGN_OR_RETURN(auto outputs_and_token,
187 ExecuteInternal(args, device, returned_futures));
188 return std::move(outputs_and_token.first);
189 }
190
191 StatusOr<std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>>
ExecuteShardedOnLocalDevicesInternal(absl::Span<const std::vector<PyBuffer::object>> args,std::optional<std::vector<PjRtFuture<Status>>> & returned_futures)192 PyExecutable::ExecuteShardedOnLocalDevicesInternal(
193 absl::Span<const std::vector<PyBuffer::object>> args,
194 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
195 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
196 int num_computations = executable_->addressable_devices().size();
197 {
198 auto options = options_;
199 std::shared_ptr<HostCallbackStates> host_callback_states;
200 if (!host_callbacks_.empty()) {
201 auto* host_memory_for_device_manager =
202 client()->pjrt_client()->GetPjRtHostMemoryForDeviceManager();
203 if (host_memory_for_device_manager == nullptr) {
204 return InternalError("Host callback not supported for runtime type: %s",
205 client()->runtime_type());
206 }
207 returned_futures.emplace();
208
209 host_callback_states = std::make_shared<HostCallbackStates>();
210
211 for (int i = 0; i < num_computations; ++i) {
212 auto& contexts = host_callback_states->contexts.emplace_back();
213 auto& send_callbacks =
214 host_callback_states->send_callbacks.emplace_back();
215 auto& recv_callbacks =
216 host_callback_states->recv_callbacks.emplace_back();
217
218 for (const py::capsule& host_callback : host_callbacks_) {
219 contexts.push_back(CreateHostCallbackStateAndAppendSendRecvCallbacks(
220 *host_callback.get_pointer<HostCallback>(),
221 host_memory_for_device_manager, send_callbacks, recv_callbacks));
222 }
223 }
224 options.send_callbacks = host_callback_states->send_callbacks;
225 options.recv_callbacks = host_callback_states->recv_callbacks;
226 }
227
228 py::gil_scoped_release gil_release;
229 for (const auto& arg : args) {
230 if (arg.size() != num_computations) {
231 return xla::InvalidArgument(
232 "Expected args to execute_sharded_on_local_devices to have %d "
233 "shards, got: [%s]",
234 num_computations,
235 absl::StrJoin(
236 args, ", ",
237 [](std::string* out, const std::vector<PyBuffer::object>& arg) {
238 out->append(std::to_string(arg.size()));
239 }));
240 }
241 }
242 std::vector<std::vector<PjRtBuffer*>> arg_buffers(num_computations);
243 const int num_args = args.size();
244 for (int computation = 0; computation < num_computations; ++computation) {
245 arg_buffers[computation].resize(num_args);
246 absl::c_transform(args, arg_buffers[computation].begin(),
247 [&](const std::vector<PyBuffer::object>& arg) {
248 return arg[computation].buf()->buffer();
249 });
250 }
251 TF_ASSIGN_OR_RETURN(
252 output_buffers,
253 executable_->Execute(arg_buffers, options, returned_futures));
254
255 if (!host_callbacks_.empty()) {
256 // For host callbacks to work, `returned_futures` must not be nullopt.
257 for (int i = 0; i < num_computations; ++i) {
258 returned_futures.value().at(i).OnReady(
259 [host_callback_states](Status) mutable {
260 host_callback_states.reset();
261 });
262 }
263 }
264 }
265 auto traceback = Traceback::Get();
266 int num_output_buffers = output_buffers[0].size();
267 std::vector<std::vector<PyBuffer::object>> outputs;
268 outputs.resize(num_output_buffers);
269 for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
270 outputs[buffer_id].reserve(num_computations);
271 for (int computation = 0; computation < num_computations; ++computation) {
272 outputs[buffer_id].push_back(PyBuffer::Make(
273 client_, std::move(output_buffers[computation][buffer_id]),
274 traceback));
275 }
276 }
277
278 // TODO(b/240696624): Although the PjRt interface require `returned_futures`
279 // to be resized correctly if it is not nullopt, some implementation does not
280 // implement this. So we have to check whether returned_futures is empty.
281 // Remove this check once the implementation is fixed.
282 if (!returned_futures.has_value()) {
283 return std::pair<std::vector<std::vector<PyBuffer::object>>,
284 PyShardedToken>(std::move(outputs), PyShardedToken());
285 }
286
287 PyShardedToken py_sharded_token(std::move(*returned_futures));
288
289 return std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>(
290 std::move(outputs), std::move(py_sharded_token));
291 }
292
293 StatusOr<std::vector<std::vector<PyBuffer::object>>>
ExecuteShardedOnLocalDevices(absl::Span<const std::vector<PyBuffer::object>> args)294 PyExecutable::ExecuteShardedOnLocalDevices(
295 absl::Span<const std::vector<PyBuffer::object>> args) {
296 std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
297 TF_ASSIGN_OR_RETURN(
298 auto outputs_and_tokens,
299 ExecuteShardedOnLocalDevicesInternal(args, returned_futures));
300 return std::move(outputs_and_tokens.first);
301 }
302
303 StatusOr<std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>>
ExecuteShardedOnLocalDevicesWithTokens(absl::Span<const std::vector<PyBuffer::object>> args)304 PyExecutable::ExecuteShardedOnLocalDevicesWithTokens(
305 absl::Span<const std::vector<PyBuffer::object>> args) {
306 std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
307 if (executable_->IsReturnedFutureSupported()) returned_futures.emplace();
308 return ExecuteShardedOnLocalDevicesInternal(args, returned_futures);
309 }
310
// Returns the HLO modules backing this executable, as reported by the
// underlying PjRt executable.
StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
    const {
  return executable_->GetHloModules();
}
315
// Ties `obj`'s lifetime to this executable: the Python object is held in
// keepalives_ and thus not released before the executable is destroyed.
void PyExecutable::KeepAlive(py::object obj) {
  keepalives_.push_back(std::move(obj));
}
319
320 } // namespace xla
321