xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/py_executable.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/py_executable.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/xla/pjrt/host_callback.h"
23 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
24 #include "tensorflow/core/platform/fingerprint.h"
25 
26 namespace xla {
27 
28 namespace py = pybind11;
29 
Await()30 Status PyToken::Await() {
31   CHECK(future_.IsValid());
32   py::gil_scoped_release gil_release;
33   return future_.Await();
34 }
35 
Await()36 Status PyShardedToken::Await() {
37   py::gil_scoped_release gil_release;
38   Status status = OkStatus();
39   for (auto& future : futures_) {
40     auto s = future.Await();
41     if (!s.ok()) status = std::move(s);
42   }
43   return status;
44 }
45 
PyExecutable(std::shared_ptr<PyClient> client,std::unique_ptr<PjRtLoadedExecutable> executable,std::shared_ptr<Traceback> traceback,std::optional<std::string> fingerprint,std::vector<pybind11::capsule> host_callbacks)46 PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
47                            std::unique_ptr<PjRtLoadedExecutable> executable,
48                            std::shared_ptr<Traceback> traceback,
49                            std::optional<std::string> fingerprint,
50                            std::vector<pybind11::capsule> host_callbacks)
51     : client_(std::move(client)),
52       executable_(std::move(executable)),
53       traceback_(std::move(traceback)),
54       fingerprint_(std::move(fingerprint)),
55       host_callbacks_(std::move(host_callbacks)) {
56   CHECK(PyGILState_Check());
57   next_ = client_->executables_;
58   client_->executables_ = this;
59   prev_ = nullptr;
60   if (next_) {
61     next_->prev_ = this;
62   }
63   options_.untuple_result = true;
64   if (fingerprint_) {
65     options_.launch_id = tensorflow::Fingerprint32(*fingerprint_);
66     VLOG(1) << "Fingerprint for executable " << executable_->name() << ": "
67             << *fingerprint_;
68   }
69 }
70 
~PyExecutable()71 PyExecutable::~PyExecutable() {
72   CHECK(PyGILState_Check());
73   if (client_->executables_ == this) {
74     client_->executables_ = next_;
75   }
76   if (prev_) {
77     prev_->next_ = next_;
78   }
79   if (next_) {
80     next_->prev_ = prev_;
81   }
82 }
83 
AddressableDevices() const84 std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::AddressableDevices() const {
85   std::vector<ClientAndPtr<PjRtDevice>> devices;
86   devices.reserve(executable_->addressable_devices().size());
87   for (PjRtDevice* device : executable_->addressable_devices()) {
88     devices.push_back(WrapWithClient(client_, device));
89   }
90   return devices;
91 }
92 
93 StatusOr<std::pair<std::vector<PyBuffer::object>, PyToken>>
ExecuteInternal(absl::Span<PyBuffer::object const> args,PjRtDevice * device,std::optional<std::vector<PjRtFuture<Status>>> & returned_futures)94 PyExecutable::ExecuteInternal(
95     absl::Span<PyBuffer::object const> args, PjRtDevice* device,
96     std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
97   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
98   {
99     auto options = options_;
100     std::shared_ptr<HostCallbackStates> host_callback_states;
101 
102     if (!host_callbacks_.empty()) {
103       auto* host_memory_for_device_manager =
104           client()->pjrt_client()->GetPjRtHostMemoryForDeviceManager();
105       if (host_memory_for_device_manager == nullptr) {
106         return InternalError("Host callback not supported for runtime type: %s",
107                              client()->runtime_type());
108       }
109 
110       returned_futures.emplace();
111 
112       host_callback_states = std::make_shared<HostCallbackStates>();
113       auto& contexts = host_callback_states->contexts.emplace_back();
114       auto& send_callbacks =
115           host_callback_states->send_callbacks.emplace_back();
116       auto& recv_callbacks =
117           host_callback_states->recv_callbacks.emplace_back();
118 
119       for (const py::capsule& host_callback : host_callbacks_) {
120         contexts.push_back(CreateHostCallbackStateAndAppendSendRecvCallbacks(
121             *host_callback.get_pointer<HostCallback>(),
122             host_memory_for_device_manager, send_callbacks, recv_callbacks));
123       }
124       options.send_callbacks = host_callback_states->send_callbacks;
125       options.recv_callbacks = host_callback_states->recv_callbacks;
126     }
127 
128     py::gil_scoped_release gil_release;
129     std::vector<PjRtBuffer*> arg_buffers(args.size());
130     absl::c_transform(
131         args, arg_buffers.begin(),
132         [](const PyBuffer::object& buf) { return buf.buf()->buffer(); });
133     if (device) {
134       std::optional<PjRtFuture<Status>> future;
135       output_buffers.resize(1);
136       TF_ASSIGN_OR_RETURN(
137           output_buffers[0],
138           executable_->ExecutePortable(arg_buffers, device, options, future,
139                                        returned_futures.has_value()));
140       if (future) {
141         returned_futures->emplace_back(std::move(*future));
142       }
143     } else {
144       TF_ASSIGN_OR_RETURN(
145           output_buffers,
146           executable_->Execute({arg_buffers}, options, returned_futures));
147     }
148 
149     if (!host_callbacks_.empty()) {
150       // For host callbacks to work, `returned_futures` must not be nullopt.
151       returned_futures->at(0).OnReady([host_callback_states](Status) mutable {
152         host_callback_states.reset();
153       });
154     }
155   }
156   auto traceback = Traceback::Get();
157   std::vector<PyBuffer::object> outputs;
158   outputs.reserve(output_buffers[0].size());
159   for (auto& buffer : output_buffers[0]) {
160     outputs.push_back(PyBuffer::Make(client_, std::move(buffer), traceback));
161   }
162 
163   // TODO(b/240696624): Although the PjRt interface require `returned_futures`
164   // to be resized correctly if it is not nullopt, some implementation does not
165   // implement this. So we have to check whether returned_futures is empty.
166   // Remove this check once the implementation is fixed.
167   if (!returned_futures.has_value()) {
168     return std::pair<std::vector<PyBuffer::object>, PyToken>(
169         std::move(outputs), PyToken::ReadyPyToken());
170   }
171   return std::pair<std::vector<PyBuffer::object>, PyToken>(
172       std::move(outputs), PyToken(std::move(returned_futures->at(0))));
173 }
174 
175 StatusOr<std::pair<std::vector<PyBuffer::object>, PyToken>>
ExecuteWithToken(absl::Span<PyBuffer::object const> args,PjRtDevice * device)176 PyExecutable::ExecuteWithToken(absl::Span<PyBuffer::object const> args,
177                                PjRtDevice* device) {
178   std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
179   if (executable_->IsReturnedFutureSupported()) returned_futures.emplace();
180   return ExecuteInternal(args, device, returned_futures);
181 }
182 
Execute(absl::Span<PyBuffer::object const> args,PjRtDevice * device)183 StatusOr<std::vector<PyBuffer::object>> PyExecutable::Execute(
184     absl::Span<PyBuffer::object const> args, PjRtDevice* device) {
185   std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
186   TF_ASSIGN_OR_RETURN(auto outputs_and_token,
187                       ExecuteInternal(args, device, returned_futures));
188   return std::move(outputs_and_token.first);
189 }
190 
191 StatusOr<std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>>
ExecuteShardedOnLocalDevicesInternal(absl::Span<const std::vector<PyBuffer::object>> args,std::optional<std::vector<PjRtFuture<Status>>> & returned_futures)192 PyExecutable::ExecuteShardedOnLocalDevicesInternal(
193     absl::Span<const std::vector<PyBuffer::object>> args,
194     std::optional<std::vector<PjRtFuture<Status>>>& returned_futures) {
195   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
196   int num_computations = executable_->addressable_devices().size();
197   {
198     auto options = options_;
199     std::shared_ptr<HostCallbackStates> host_callback_states;
200     if (!host_callbacks_.empty()) {
201       auto* host_memory_for_device_manager =
202           client()->pjrt_client()->GetPjRtHostMemoryForDeviceManager();
203       if (host_memory_for_device_manager == nullptr) {
204         return InternalError("Host callback not supported for runtime type: %s",
205                              client()->runtime_type());
206       }
207       returned_futures.emplace();
208 
209       host_callback_states = std::make_shared<HostCallbackStates>();
210 
211       for (int i = 0; i < num_computations; ++i) {
212         auto& contexts = host_callback_states->contexts.emplace_back();
213         auto& send_callbacks =
214             host_callback_states->send_callbacks.emplace_back();
215         auto& recv_callbacks =
216             host_callback_states->recv_callbacks.emplace_back();
217 
218         for (const py::capsule& host_callback : host_callbacks_) {
219           contexts.push_back(CreateHostCallbackStateAndAppendSendRecvCallbacks(
220               *host_callback.get_pointer<HostCallback>(),
221               host_memory_for_device_manager, send_callbacks, recv_callbacks));
222         }
223       }
224       options.send_callbacks = host_callback_states->send_callbacks;
225       options.recv_callbacks = host_callback_states->recv_callbacks;
226     }
227 
228     py::gil_scoped_release gil_release;
229     for (const auto& arg : args) {
230       if (arg.size() != num_computations) {
231         return xla::InvalidArgument(
232             "Expected args to execute_sharded_on_local_devices to have %d "
233             "shards, got: [%s]",
234             num_computations,
235             absl::StrJoin(
236                 args, ", ",
237                 [](std::string* out, const std::vector<PyBuffer::object>& arg) {
238                   out->append(std::to_string(arg.size()));
239                 }));
240       }
241     }
242     std::vector<std::vector<PjRtBuffer*>> arg_buffers(num_computations);
243     const int num_args = args.size();
244     for (int computation = 0; computation < num_computations; ++computation) {
245       arg_buffers[computation].resize(num_args);
246       absl::c_transform(args, arg_buffers[computation].begin(),
247                         [&](const std::vector<PyBuffer::object>& arg) {
248                           return arg[computation].buf()->buffer();
249                         });
250     }
251     TF_ASSIGN_OR_RETURN(
252         output_buffers,
253         executable_->Execute(arg_buffers, options, returned_futures));
254 
255     if (!host_callbacks_.empty()) {
256       // For host callbacks to work, `returned_futures` must not be nullopt.
257       for (int i = 0; i < num_computations; ++i) {
258         returned_futures.value().at(i).OnReady(
259             [host_callback_states](Status) mutable {
260               host_callback_states.reset();
261             });
262       }
263     }
264   }
265   auto traceback = Traceback::Get();
266   int num_output_buffers = output_buffers[0].size();
267   std::vector<std::vector<PyBuffer::object>> outputs;
268   outputs.resize(num_output_buffers);
269   for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
270     outputs[buffer_id].reserve(num_computations);
271     for (int computation = 0; computation < num_computations; ++computation) {
272       outputs[buffer_id].push_back(PyBuffer::Make(
273           client_, std::move(output_buffers[computation][buffer_id]),
274           traceback));
275     }
276   }
277 
278   // TODO(b/240696624): Although the PjRt interface require `returned_futures`
279   // to be resized correctly if it is not nullopt, some implementation does not
280   // implement this. So we have to check whether returned_futures is empty.
281   // Remove this check once the implementation is fixed.
282   if (!returned_futures.has_value()) {
283     return std::pair<std::vector<std::vector<PyBuffer::object>>,
284                      PyShardedToken>(std::move(outputs), PyShardedToken());
285   }
286 
287   PyShardedToken py_sharded_token(std::move(*returned_futures));
288 
289   return std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>(
290       std::move(outputs), std::move(py_sharded_token));
291 }
292 
293 StatusOr<std::vector<std::vector<PyBuffer::object>>>
ExecuteShardedOnLocalDevices(absl::Span<const std::vector<PyBuffer::object>> args)294 PyExecutable::ExecuteShardedOnLocalDevices(
295     absl::Span<const std::vector<PyBuffer::object>> args) {
296   std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
297   TF_ASSIGN_OR_RETURN(
298       auto outputs_and_tokens,
299       ExecuteShardedOnLocalDevicesInternal(args, returned_futures));
300   return std::move(outputs_and_tokens.first);
301 }
302 
303 StatusOr<std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>>
ExecuteShardedOnLocalDevicesWithTokens(absl::Span<const std::vector<PyBuffer::object>> args)304 PyExecutable::ExecuteShardedOnLocalDevicesWithTokens(
305     absl::Span<const std::vector<PyBuffer::object>> args) {
306   std::optional<std::vector<PjRtFuture<Status>>> returned_futures;
307   if (executable_->IsReturnedFutureSupported()) returned_futures.emplace();
308   return ExecuteShardedOnLocalDevicesInternal(args, returned_futures);
309 }
310 
HloModules() const311 StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
312     const {
313   return executable_->GetHloModules();
314 }
315 
KeepAlive(py::object obj)316 void PyExecutable::KeepAlive(py::object obj) {
317   keepalives_.push_back(std::move(obj));
318 }
319 
320 }  // namespace xla
321