1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
18
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "pybind11/pybind11.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/types.h"
30
31 namespace xla {
32
33 class PyBuffer;
34 class PyClient;
35 class PyExecutable;
36
// Custom holder types.
//
// We must keep the PyClient object alive as long as any of the runtime
// objects are alive. Since we don't have much control over Python destructor
// ordering, we keep the PyClient object in a std::shared_ptr<> and ensure that
// each Python runtime object holds a reference to the PyClient. An alternative
// design would be to keep a single global singleton PyClient, although this
// seems less flexible, especially for writing tests.
//
// To maintain PyClient references, we define pybind11 holder classes: custom
// smart pointers that also keep a reference to a PyClient. pybind11 has a
// `keep_alive` feature with a similar goal, but it doesn't seem flexible
// enough to describe ownership relationships where the ownership doesn't
// pertain to a direct argument or return value of a function. Another
// alternative to the holder classes would be to create proxy objects that
// contain both a reference and a runtime class; holder classes seem less
// tedious to define.
55
56 // A pair of a PyClient reference and an unowned pointer to T.
57 template <typename T>
58 struct ClientAndPtr {
59 ClientAndPtr() = default;
60 // pybind11 requires that we define a constructor that takes a raw pointer,
61 // but it should be unreachable.
ClientAndPtrClientAndPtr62 explicit ClientAndPtr(T*) {
63 LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
64 }
65
66 ClientAndPtr(const ClientAndPtr&) = default;
67 ClientAndPtr(ClientAndPtr&&) = default;
68 ClientAndPtr& operator=(const ClientAndPtr&) = default;
69 ClientAndPtr& operator=(ClientAndPtr&&) = default;
70
71 std::shared_ptr<PyClient> client;
72 T* contents;
73
getClientAndPtr74 T* get() const { return contents; }
75 T* operator->() const { return contents; }
76 T& operator*() const { return *contents; }
77 };
78
79 // By defining a templated helper function, we can use return type deduction
80 // and avoid specifying types at the caller.
81 template <typename T>
WrapWithClient(std::shared_ptr<PyClient> client,T * contents)82 ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
83 ClientAndPtr<T> result;
84 result.client = std::move(client);
85 result.contents = contents;
86 return result;
87 }
88
89 // Python wrapper around PjRtClient.
90 // We use a wrapper class to add Python-specific functionality.
91 class PyClient : public std::enable_shared_from_this<PyClient> {
92 public:
93 explicit PyClient(std::unique_ptr<PjRtClient> pjrt_client);
94 explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
95 virtual ~PyClient();
96
pjrt_client()97 PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
shared_pjrt_client()98 std::shared_ptr<PjRtClient> shared_pjrt_client() { return pjrt_client_; }
99
platform_name()100 absl::string_view platform_name() const {
101 return pjrt_client_->platform_name();
102 }
platform_version()103 absl::string_view platform_version() const {
104 return pjrt_client_->platform_version();
105 }
runtime_type()106 absl::string_view runtime_type() const {
107 return PjRtRuntimeTypeString(pjrt_client_->runtime_type());
108 }
addressable_device_count()109 int addressable_device_count() const {
110 return pjrt_client_->addressable_device_count();
111 }
device_count()112 int device_count() const { return pjrt_client_->device_count(); }
process_index()113 int process_index() const { return pjrt_client_->process_index(); }
114
115 std::vector<ClientAndPtr<PjRtDevice>> Devices();
116 std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();
117
118 // Returns a vector of live PyBuffer objects. PyBuffer objects may share
119 // PjRtBuffers, so there may be duplicates of the same underlying device
120 // buffer.
121 std::vector<pybind11::object> LiveBuffers();
122 std::vector<pybind11::object> LiveBuffersOnDevice(PjRtDevice* device);
123
124 // Returns a vector of live PyExecutable objects.
125 // note: must return std::shared_ptr instead of raw ptrs
126 // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#std-shared-ptr
127 std::vector<std::shared_ptr<PyExecutable>> LiveExecutables();
128
129 // TODO(zhangqiaorjc): Remove when we have transparent defragmentation.
130 Status Defragment();
131
132 StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
133 GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
134
135 // TODO(skye): delete after all callers can handle 2D output
136 StatusOr<std::vector<ClientAndPtr<PjRtDevice>>> GetDefaultDeviceAssignment1D(
137 int num_replicas);
138
CreateChannelHandle()139 StatusOr<ChannelHandle> CreateChannelHandle() { return ChannelHandle(); }
CreateDeviceToHostChannelHandle()140 StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
141 return pjrt_client_->CreateDeviceToHostChannelHandle();
142 }
CreateHostToDeviceChannelHandle()143 StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
144 return pjrt_client_->CreateHostToDeviceChannelHandle();
145 }
146
147 StatusOr<std::vector<std::pair<pybind11::bytes, pybind11::object>>>
148 MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
149 PjRtDevice* device);
150
151 StatusOr<pybind11::object> BufferFromPyval(
152 pybind11::handle argument, PjRtDevice* device, bool force_copy,
153 PjRtClient::HostBufferSemantics host_buffer_semantics);
154
155 StatusOr<std::shared_ptr<PyExecutable>> Compile(
156 const XlaComputation& computation, CompileOptions options,
157 std::vector<pybind11::capsule> host_callbacks);
158 StatusOr<std::shared_ptr<PyExecutable>> CompileMlir(
159 std::string mlir_module, CompileOptions options,
160 std::vector<pybind11::capsule> host_callbacks);
161
162 StatusOr<pybind11::bytes> SerializeExecutable(
163 const PyExecutable& executable) const;
164 StatusOr<std::shared_ptr<PyExecutable>> DeserializeExecutable(
165 const std::string& serialized, CompileOptions options,
166 std::vector<pybind11::capsule> host_callbacks);
167
168 // TODO(skyewm): remove when jax stop providing hlo_module
DeserializeExecutable(const std::string & serialized,std::shared_ptr<HloModule> hlo_module,CompileOptions options,std::vector<pybind11::capsule> host_callbacks)169 StatusOr<std::shared_ptr<PyExecutable>> DeserializeExecutable(
170 const std::string& serialized, std::shared_ptr<HloModule> hlo_module,
171 CompileOptions options, std::vector<pybind11::capsule> host_callbacks) {
172 return DeserializeExecutable(serialized, options,
173 std::move(host_callbacks));
174 }
175
176 StatusOr<pybind11::bytes> HeapProfile();
177
178 // `GetEmitPythonCallbackDescriptor` takes in an input Python callable that
179 // takes in arguments of shapes `operand_shapes` and returns values of shapes
180 // `result_shapes`. It returns a pair of a `uint64_t` descriptor and a Python
181 // object whose reference will keep the Python callback alive. The descriptor
182 // should be passed into a 'xla_cpu_python_callback' CustomCall as its first
183 // argument. Typically the callback may be kept alive by attaching the
184 // keep-alive object to the executable built from this computation.
185 //
186 // The callable receives as arguments NumPy arrays for arguments with array
187 // types, and None for Token argument. The callable must return a tuple of
188 // either arrays or None values.
189 //
190 // This is a method of PyClient since different platforms may implement this
191 // functionality in different ways.
192 StatusOr<std::pair<uint64_t, pybind11::object>>
193 GetEmitPythonCallbackDescriptor(pybind11::function callable,
194 absl::Span<Shape const> operand_shapes,
195 absl::Span<Shape const> result_shapes);
196 // Deprecated; please switch to emitting an MHLO `CustomCallOp` directly.
197 StatusOr<XlaOp> EmitPythonCallbackFromDescriptor(
198 XlaBuilder& builder, uint64_t descriptor,
199 absl::Span<XlaOp const> operands, absl::Span<Shape const> result_shapes,
200 std::optional<std::vector<Shape>> operand_layouts, bool has_side_effect);
201 // Deprecated; please switch to using `GetEmitPythonCallbackDescriptor`
202 // and then emitting a `CustomCall` op instead.
203 StatusOr<std::pair<XlaOp, pybind11::object>> EmitPythonCallback(
204 pybind11::function callable, XlaBuilder& builder,
205 absl::Span<XlaOp const> operands, absl::Span<Shape const> result_shapes,
206 std::optional<std::vector<Shape>> operand_layouts, bool has_side_effect);
207
208 // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable
209 // that takes in arguments of shapes `operand_shapes` and returns results of
210 // shapes `result_shapes`. The arguments correspond to Send ops in the HLO
211 // program through `send_channel_ids` and the results correspond to Recv ops
212 // through `recv_channel_ids`. It returns the host callback as an opaque
213 // object whose reference will keep the Python callback alive. The host
214 // callback can be passed to PyExecutable::Execute() so that the corresponding
215 // Send/Recv ops can trigger the execution of this host callback.
216 StatusOr<pybind11::object> MakePythonCallbackUsingHostSendAndRecv(
217 pybind11::function callable, absl::Span<Shape const> operand_shapes,
218 absl::Span<Shape const> result_shapes,
219 absl::Span<uint16_t const> send_channel_ids,
220 absl::Span<uint16_t const> recv_channel_ids);
221
222 private:
223 friend class PyBuffer;
224 friend class PyExecutable;
225
226 std::shared_ptr<PjRtClient> pjrt_client_;
227
228 // Pointers to intrusive doubly-linked lists of buffers and executables, used
229 // to iterate over all known objects when heap profiling. The list structure
230 // is protected by the GIL.
231
232 // buffers_ is a per-device list, indexed by device->id().
233 std::vector<PyBuffer*> buffers_;
234 PyExecutable* executables_ = nullptr;
235 };
236
237 } // namespace xla
238
// Registers xla::ClientAndPtr<T> with pybind11 as a custom holder type, so
// bound runtime classes can be held by a ClientAndPtr, which keeps a reference
// to their PyClient. Invoked at global scope (outside namespace xla), hence
// the fully qualified template name.
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
240
241 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
242