xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/py_executable.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_
18 
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
31 
32 namespace xla {
33 
34 class PyToken {
35  public:
36   PyToken() = default;
PyToken(PjRtFuture<Status> future)37   explicit PyToken(PjRtFuture<Status> future) : future_(std::move(future)) {}
38 
ReadyPyToken()39   static PyToken ReadyPyToken() {
40     return PyToken(PjRtFuture<Status>(OkStatus()));
41   }
42 
43   Status Await();
44 
45  private:
46   PjRtFuture<Status> future_;
47 };
48 
49 // PyShardedToken contains a PyToken for each device's execution.
50 class PyShardedToken {
51  public:
52   // Default construction creates a always-ready token.
53   PyShardedToken() = default;
PyShardedToken(std::vector<PjRtFuture<Status>> futures)54   explicit PyShardedToken(std::vector<PjRtFuture<Status>> futures)
55       : futures_(std::move(futures)) {}
56 
GetPyToken(int device_id)57   PyToken GetPyToken(int device_id) const {
58     if (futures_.empty()) return PyToken::ReadyPyToken();
59     return PyToken(futures_.at(device_id));
60   }
61 
62   Status Await();
63 
64  private:
65   std::vector<PjRtFuture<Status>> futures_;
66 };
67 
68 // Python wrapper around PjRtExecutable. We use a wrapper class:
69 // a) to keep the PyClient alive via a std::shared_ptr<>
70 // b) to add Python-specific functionality.
71 class PyExecutable : public std::enable_shared_from_this<PyExecutable> {
72  public:
73   PyExecutable(std::shared_ptr<PyClient> client,
74                std::unique_ptr<PjRtLoadedExecutable> executable,
75                std::shared_ptr<Traceback> traceback,
76                std::optional<std::string> fingerprint,
77                std::vector<pybind11::capsule> host_callbacks);
78   ~PyExecutable();
79 
client()80   std::shared_ptr<PyClient> client() const { return client_; }
executable()81   std::shared_ptr<PjRtLoadedExecutable> executable() const {
82     return executable_;
83   }
84 
85   absl::Span<const PjRtLoadedExecutable::LogicalDeviceIds>
addressable_device_logical_ids()86   addressable_device_logical_ids() const {
87     return executable_->addressable_device_logical_ids();
88   }
89 
90   std::vector<ClientAndPtr<PjRtDevice>> AddressableDevices() const;
91 
SizeOfGeneratedCodeInBytes()92   int64_t SizeOfGeneratedCodeInBytes() const {
93     return executable_->SizeOfGeneratedCodeInBytes();
94   }
95 
GetCompiledMemoryStats()96   StatusOr<CompiledMemoryStats> GetCompiledMemoryStats() const {
97     return executable_->GetCompiledMemoryStats();
98   }
99 
Delete()100   void Delete() { return executable_->Delete(); }
101 
is_deleted()102   bool is_deleted() { return executable_->IsDeleted(); }
103 
104   StatusOr<std::vector<PyBuffer::object>> Execute(
105       absl::Span<PyBuffer::object const> args, PjRtDevice* device);
106 
107   StatusOr<std::pair<std::vector<PyBuffer::object>, PyToken>> ExecuteWithToken(
108       absl::Span<PyBuffer::object const> args, PjRtDevice* device);
109 
110   // Takes args indexed by argid then deviceid, transposes them, and passes to
111   // PjRtExecutable::Execute. The result is similarly transposed back into the
112   // argid,deviceid format.
113   // args is [num_args x num_devices].
114   StatusOr<std::vector<std::vector<PyBuffer::object>>>
115   ExecuteShardedOnLocalDevices(
116       absl::Span<const std::vector<PyBuffer::object>> args);
117 
118   StatusOr<
119       std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>>
120   ExecuteShardedOnLocalDevicesWithTokens(
121       absl::Span<const std::vector<PyBuffer::object>> args);
122 
123   StatusOr<std::vector<std::shared_ptr<HloModule>>> HloModules() const;
124 
traceback()125   Traceback* traceback() { return traceback_.get(); }
126 
pjrt_executable()127   const PjRtLoadedExecutable& pjrt_executable() const { return *executable_; }
128 
mutable_pjrt_executable()129   PjRtLoadedExecutable* mutable_pjrt_executable() const {
130     return executable_.get();
131   }
options()132   const ExecuteOptions& options() const { return options_; }
fingerprint()133   const std::optional<std::string>& fingerprint() const { return fingerprint_; }
134 
135   // Keep `obj` alive as long as PyExecutable.
136   void KeepAlive(pybind11::object obj);
137 
138  private:
139   StatusOr<std::pair<std::vector<PyBuffer::object>, PyToken>> ExecuteInternal(
140       absl::Span<PyBuffer::object const> args, PjRtDevice* device,
141       std::optional<std::vector<PjRtFuture<Status>>>& returned_futures);
142   StatusOr<
143       std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>>
144   ExecuteShardedOnLocalDevicesInternal(
145       absl::Span<const std::vector<PyBuffer::object>> args,
146       std::optional<std::vector<PjRtFuture<Status>>>& returned_futures);
147 
148   friend class PyClient;
149 
150   std::shared_ptr<PyClient> client_;
151   std::shared_ptr<PjRtLoadedExecutable> executable_;
152   std::shared_ptr<Traceback> traceback_;
153 
154   // Identical executables (i.e. representing the same program) will have the
155   // same fingerprint. nullopt on platforms or executables where fingerprints
156   // aren't implemented.
157   std::optional<std::string> fingerprint_;
158 
159   // The python callbacks implemented using send/recv support.
160   std::vector<pybind11::capsule> host_callbacks_;
161 
162   // The options to pass to `executable_.Execute`.
163   ExecuteOptions options_;
164 
165   // Python objects to keep alive as requested by user.
166   std::vector<pybind11::object> keepalives_;
167 
168   // Doubly-linked list of all executables known to the client. Protected by the
169   // GIL.
170   PyExecutable* next_;
171   PyExecutable* prev_;
172 };
173 
174 }  // namespace xla
175 
176 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_
177