/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/pmap_lib.h"

#include <algorithm>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "pybind11/cast.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil
#include "tensorflow/compiler/xla/python/exceptions.h"
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/py_values.h"
#include "tensorflow/compiler/xla/python/python_utils.h"
#include "tensorflow/compiler/xla/python/sharded_device_array.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace jax {

namespace py = pybind11;

namespace {

// Specifies how to shard the inputs. Even though everything could be computed
// from `sharding_specs` and the argument shape, we cache derived computations
// for performance.
struct InputSpec {
  InputSpec(ShardingSpec sharding_spec, py::object indices)
      : sharding_spec(std::move(sharding_spec)), indices(std::move(indices)) {}
  ShardingSpec sharding_spec;
  py::object indices;
};

// An object containing the arguments to create a ShardedDeviceArray from the
// output buffers.
struct ResultSpec {
 public:
  ResultSpec(py::object aval, ShardingSpec out_spec, py::object out_indices)
      : out_aval(std::move(aval)),
        weak_type(py::cast<bool>(out_aval.attr("weak_type"))),
        out_spec(std::move(out_spec)),
        out_indices(std::move(out_indices)) {}
  py::object out_aval;
  bool weak_type;
  ShardingSpec out_spec;
  py::object out_indices;
};

// The result of `ShardArg`.
struct ShardArgResult {
  // Points to the on-device buffers. Not owned.
  // Size `num_devices`.
  std::vector<xla::PjRtBuffer*> per_device_buffers;

  // The Python argument will always be copied to `owning_sda`.
  // If we need to copy data to a device, the newly created buffers will be
  // added to `owned_buffers`.
  std::vector<std::unique_ptr<xla::PjRtBuffer>> owned_buffers;
  py::object owning_sda;
};

// Shards a single argument over devices.
//
// We currently handle fully in C++ only arguments that are already C++
// ShardedDeviceArray objects. For everything else, we call a Python fallback
// function returning a C++ ShardedDeviceArray, which is then cast back to the
// C++ object.
//
// This function is not usable for JAX extensions that do not comply with the
// PjRt interfaces.
//
// Arguments:
//   `arg`: The object to shard across `devices`. If it is a
//     `ShardedDeviceArray`, a fast-path is taken if it's already correctly
//     sharded.
//
// Returns a failure Status when an unrecoverable error occurs, so we don't
// need to fall back to Python.
//
// Both `devices` and `sharding_spec` have the same length.
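//
// Illustrative Python-side behavior (an assumed sketch, not an API contract):
//
//   sda = jax.pmap(f)(xs)  # produces a ShardedDeviceArray
//   jax.pmap(g)(sda)       # hits the fast path when the input's ShardingSpec
//                          # matches and each shard already lives on the
//                          # expected device; mismatched shards are copied
//                          # with PjRtBuffer::CopyToDevice below.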
xla::StatusOr<ShardArgResult> ShardArg(
    py::handle arg, absl::Span<xla::PjRtDevice* const> devices,
    const InputSpec& input_spec, py::handle py_devices,
    const py::function& python_fallback) {
  if (ShardedDeviceArray::IsShardedDeviceArray(arg)) {
    ShardedDeviceArray* sda =
        ShardedDeviceArray::AsShardedDeviceArrayUnchecked(arg);
    const ShardingSpec& sharding_spec = input_spec.sharding_spec;
    if (sharding_spec == sda->GetShardingSpec()) {
      const int num_devices = devices.size();
      TF_ASSIGN_OR_RETURN(absl::Span<xla::PjRtBuffer* const> sda_buffers,
                          sda->GetPjRtBuffers());
      CHECK_EQ(sda_buffers.size(), num_devices);

      ShardArgResult result;
      result.owning_sda = py::reinterpret_borrow<py::object>(arg);
      std::vector<xla::PjRtBuffer*>& per_device_buffers =
          result.per_device_buffers;
      per_device_buffers.reserve(num_devices);

      for (int i = 0; i < num_devices; ++i) {
        xla::PjRtBuffer* current_buffer = sda_buffers[i];
        if (devices[i] == current_buffer->device()) {  // Pointer equality.
          per_device_buffers.push_back(current_buffer);
        } else {
          TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtBuffer> out,
                              current_buffer->CopyToDevice(devices[i]));
          per_device_buffers.push_back(out.get());
          result.owned_buffers.push_back(std::move(out));
        }
      }
      return result;
    }
  }

  // This fallback is better than nothing, but ideally we should be able to
  // convert the argument in C++. At least, we can call the C++ DevicePut from
  // Python.
  auto per_device_pybuffers =
      py::cast<py::list>(python_fallback(arg, py_devices, input_spec.indices));
  ShardArgResult result;
  result.owning_sda = py::reinterpret_borrow<py::object>(per_device_pybuffers);
  std::vector<xla::PjRtBuffer*>& per_device_buffers = result.per_device_buffers;
  if (!per_device_pybuffers.empty()) {
    per_device_buffers.reserve(per_device_pybuffers.size());

    // The JAX Python shard_arg function is expected to return JAX PyBuffer
    // objects. If executing a JAX extension, it should have fallen back to
    // Python well before this point.
    TF_RET_CHECK(xla::PyBuffer::IsPyBuffer(per_device_pybuffers[0]));
    for (py::handle per_device_pybuffer : per_device_pybuffers) {
      xla::PjRtBuffer* buf =
          xla::PyBuffer::AsPyBuffer(per_device_pybuffer).ValueOrDie()->buffer();
      per_device_buffers.push_back(buf);
    }
  }
  return result;
}

struct PmapCacheEntry {
  std::shared_ptr<xla::PyExecutable> executable;
  // The value of `backend.local_devices()`.
  py::object py_devices;  // To pass back to Python.
  std::vector<xla::PjRtDevice*> devices;
  std::vector<InputSpec> input_specs;
  xla::PyTreeDef out_pytree_def;
  // Objects necessary to build the output ShardedDeviceArray objects.
  std::vector<ResultSpec> out_result_specs;

  // Ensures a single thread performs the compilation for a given executable.
  //
  // The first thread (holding the GIL) will create the CacheEntry associated
  // with a signature; if the entry has already been inserted, other threads
  // wait for the notification.
  absl::Notification compilation_complete;

  bool fall_back_to_python = false;
};

}  // namespace

// A `PmapFunction` is associated with a `jax.pmap(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
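//
// Rough Python-side correspondence (an assumed sketch for orientation):
//
//   f = jax.pmap(my_func)  # wired to a PmapFunction through jax's api.py
//   f(xs)  # The first call with a given argument signature triggers
//          # `cache_miss_` (tracing/compilation); later calls with the same
//          # signature dispatch directly to the cached PyExecutable.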
class PmapFunction {
 public:
  PmapFunction(py::function fun, py::function cache_miss,
               std::vector<int> static_argnums,
               py::function python_shard_arg_fallback)
      : fun_(std::move(fun)),
        cache_miss_(std::move(cache_miss)),
        static_argnums_(std::move(static_argnums)),
        python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) {
    std::sort(static_argnums_.begin(), static_argnums_.end());

    function_name_ = py::str(py::getattr(fun_, "__name__", fun));
  }
  PmapFunction(const PmapFunction&) = delete;
  PmapFunction& operator=(const PmapFunction& other) = delete;
  PmapFunction(PmapFunction&&) = default;
  PmapFunction& operator=(PmapFunction&&) = default;

  // This function will:
  // (a) flatten the inputs using pytree
  // (b) get buffer objects from the arguments
  // (c) call the executable
  // (d) construct `ShardedDeviceArray` objects from the outputs
  // (e) reconstruct the `PyTree`.
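  //
  // Dispatch sketch (matches the implementation of Call below): with N
  // addressable devices and K flattened arguments, step (b) builds an [N, K]
  // matrix of PjRtBuffer* that is passed to PjRtExecutable::Execute, and
  // step (d) transposes the resulting [N, num_outputs] buffers into
  // num_outputs ShardedDeviceArrays of N shards each.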
  xla::StatusOr<py::object> Call(py::args args, py::kwargs kwargs);

  py::object PythonSignature() {
    static const auto* inspect = new py::module(py::module::import("inspect"));
    return inspect->attr("signature")(fun_);
  }

  int cache_size() const { return executables_.size(); }
  const py::function& fun() const { return fun_; }
  const py::function& cache_miss() const { return cache_miss_; }
  const std::string& function_name() const { return function_name_; }
  const py::function& python_shard_arg_fallback() const {
    return python_shard_arg_fallback_;
  }
  const std::vector<int>& static_argnums() const { return static_argnums_; }

  // pybind11::object typed subclass for PmapFunction objects.
  class pyobject : public py::object {
   public:
    PYBIND11_OBJECT(pyobject,  // NOLINT
                    py::object, PmapFunction::IsPmapFunction);
    pyobject() = default;
    PmapFunction* func() const {
      return PmapFunction::AsPmapFunctionUnchecked(*this);
    }
  };
  // Alias as ::object; outside the scope above we won't confuse pybind11's
  // macros.
  using object = pyobject;

  py::handle AsPyHandle();
  // Returns true if `h` is a PmapFunction.
  static bool IsPmapFunction(py::handle handle);
  // Converts `handle` to a PmapFunction*. Does not do any checking.
  static PmapFunction* AsPmapFunctionUnchecked(py::handle handle);

  // Helper function used by the tp_clear GC method.
  void ClearPythonReferences() {
    py::function fun, cache_miss, python_shard_arg_fallback;
    // Swap values for nulls before they are destroyed. See the Python
    // Py_CLEAR() documentation for a discussion of this topic.
    std::swap(fun_, fun);
    std::swap(cache_miss_, cache_miss);
    std::swap(python_shard_arg_fallback_, python_shard_arg_fallback);
  }

  // Updates the signature of arguments for a pmapped function.
  //
  // It deals with the argument signatures as well as the global and
  // thread-local jit context.
  xla::Status UpdateArgsSignature(const py::args& args,
                                  const py::kwargs& kwargs,
                                  ParsedArgumentsAsBuffers& arguments) {
    arguments.signature.function_name = function_name_;

    // Get dynamic argument signatures.
    JitState& global_state = jax::GetGlobalState();
    JitState& tls = jax::GetLocalState();
    const bool jax_enable_x64 = GetEnableX64();
    arguments.signature.jax_enable_x64 = jax_enable_x64;
    for (py::handle arg : arguments.flat_dynamic_args) {
      auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64);
      if (!signature_or_error.ok()) {
        VLOG(2) << "PyArgSignatureOfValue failed: "
                << signature_or_error.status();
        return signature_or_error.status();
      }
      arguments.signature.dynamic_arg_signatures.push_back(
          std::move(signature_or_error).ValueOrDie());
    }
    try {
      py::object pxla_module = py::module::import("jax").attr("config");
      py::object sda = py::getattr(pxla_module, "_trace_context", py::none());
      if (!sda.is_none()) {
        arguments.signature.thread_local_extra_jit_context = sda();
      }
    } catch (const py::error_already_set& e) {
      // Ignore; jax may not be present.
    }
    if (!arguments.signature.thread_local_extra_jit_context.has_value()) {
      arguments.signature.thread_local_extra_jit_context =
          tls.extra_jit_context;
      arguments.signature.global_extra_jit_context =
          global_state.extra_jit_context;
    }
    return xla::Status();
  }

  // Returns, for debugging purposes (e.g. finding why some call misses the
  // cache and recompiles), the list of the string representations of the keys.
  //
  // The format can change at any time.
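  //
  // Illustrative output shape (the exact key format is unspecified):
  //
  //   The cache contains 2 elements:
  //
  //   <CallSignature::DebugString() of the first key>
  //
  //   <CallSignature::DebugString() of the second key>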
  std::string DebugCacheKeys() const {
    std::vector<std::string> key_strings = {
        absl::StrCat("The cache contains ", executables_.size(), " elements:")};
    // We will be able to use auto& [key, _] when TF uses C++17.
    for (auto& pair : executables_) {
      key_strings.push_back(pair.first.DebugString());
    }
    return absl::StrJoin(key_strings, "\n\n");
  }

 private:
  // Mutates `cache_entry` in place.
  void PopulateCacheEntry(PmapCacheEntry& cache_entry,
                          const CallSignature& signature,
                          const py::tuple& out_and_fastpath_data);

  bool always_fallback_to_python_ = false;

  py::function fun_;  // The Python function to pmap.
  std::string function_name_;
  // See JAX _cpp_pmap in api.py for documentation.
  py::function cache_miss_;

  // We need to know the static arguments to remove them from the arguments
  // passed to the underlying PyExecutable. In sorted order.
  std::vector<int> static_argnums_;
  // We need a `unique_ptr` here to ensure value pointer stability.
  absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>
      executables_;

  // The fallback function to use with `ShardArgs`.
  // TODO(jblespiau): Add support for more types from C++.
  py::function python_shard_arg_fallback_;
};

void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry,
                                      const CallSignature& signature,
                                      const py::tuple& out_and_fastpath_data) {
  CHECK_EQ(out_and_fastpath_data.size(), 2);
  if (out_and_fastpath_data[1].is_none()) {
    cache_entry.fall_back_to_python = true;
    return;
  }

  py::tuple pmap_data = py::cast<py::tuple>(out_and_fastpath_data[1]);
  if (py::cast<int>(pmap_data.attr("version")) != 1) {
    throw xla::XlaRuntimeError(absl::StrCat(
        "The versions of jaxlib and JAX are incompatible (pmap cpp version 1 "
        "expected, but got ",
        py::cast<int>(pmap_data.attr("version")),
        "). Upgrade jaxlib and jax. Provided data was: ",
        py::cast<std::string>(py::str(py::repr(pmap_data)))));
  }
  // See api.py::_PmapFastpathData in the JAX code base for the expected
  // namedtuple.
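  //
  // Fields consumed below from `pmap_data`: version, xla_executable,
  // input_sharding_specs, input_indices, input_devices, out_pytree_def,
  // out_avals, out_indices, out_sharding_specs.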
  std::shared_ptr<xla::PyExecutable> executable;
  try {
    executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
        pmap_data.attr("xla_executable"));
  } catch (const py::cast_error& e) {
    // Backends that don't implement the C++ PjRt APIs fall back to Python.
    always_fallback_to_python_ = true;
    return;
  }
  cache_entry.executable = std::move(executable);
  const std::vector<xla::ClientAndPtr<xla::PjRtDevice>>& client_and_devices =
      cache_entry.executable->AddressableDevices();
  cache_entry.devices.reserve(client_and_devices.size());
  for (auto& client_and_device : client_and_devices) {
    cache_entry.devices.push_back(client_and_device.get());
  }

  // Inputs shard args details.
  auto input_sharding_specs = py::cast<std::vector<ShardingSpec>>(
      pmap_data.attr("input_sharding_specs"));
  py::list input_indices = pmap_data.attr("input_indices");

  cache_entry.py_devices = pmap_data.attr("input_devices");
  auto input_devices =
      py::cast<std::vector<xla::PjRtDevice*>>(pmap_data.attr("input_devices"));
  CHECK_EQ(input_sharding_specs.size(), input_indices.size());
  cache_entry.input_specs.reserve(input_sharding_specs.size());
  for (int i = 0; i < input_sharding_specs.size(); ++i) {
    cache_entry.input_specs.emplace_back(input_sharding_specs[i],
                                         input_indices[i]);
  }

  // Outputs specs.
  auto out_tree = py::cast<xla::PyTreeDef>(pmap_data.attr("out_pytree_def"));
  cache_entry.out_pytree_def = std::move(out_tree);
  py::list out_avals = pmap_data.attr("out_avals");
  py::list out_indices = pmap_data.attr("out_indices");
  auto out_sharding_specs =
      py::cast<std::vector<ShardingSpec>>(pmap_data.attr("out_sharding_specs"));
  CHECK_EQ(out_avals.size(), out_indices.size());
  CHECK_EQ(out_indices.size(), out_sharding_specs.size());

  cache_entry.out_result_specs.reserve(out_avals.size());
  for (int i = 0; i < out_avals.size(); ++i) {
    cache_entry.out_result_specs.emplace_back(
        out_avals[i], std::move(out_sharding_specs[i]), out_indices[i]);
  }
}

xla::StatusOr<py::object> PmapFunction::Call(py::args args, py::kwargs kwargs) {
  if (always_fallback_to_python_) {
    return py::object(py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0]);
  }

  ParsedArgumentsAsBuffers arguments;
  xla::Status status = ParseArguments(args, kwargs, static_argnums_,
                                      /*static_argnames=*/{}, arguments);
  if (!status.ok()) {
    VLOG(2) << "ParseArguments failed: " << status;
    return py::object(py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0]);
  }

  status = UpdateArgsSignature(args, kwargs, arguments);
  if (!status.ok()) {
    return py::object(py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0]);
  }

  // Retrieve/Maybe add the executable to the cache.
  absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>::iterator
      it;
  bool inserted;
  std::tie(it, inserted) = executables_.try_emplace(
      arguments.signature, std::make_unique<PmapCacheEntry>());
  PmapCacheEntry& cache_entry = *(it->second);

  if (!cache_entry.compilation_complete.HasBeenNotified()) {
    // In case of several threads attempting to compile the executable, only
    // the one that inserted the item will perform the compilation.
    if (inserted) {
      py::object out_and_fastpath_data;
      py::tuple out_tuple;
      VLOG(2) << "Cache miss for " << arguments.signature.DebugString();
      try {
        // Calls Python and may release the GIL. May also throw if
        // compilation/tracing fails.
        out_and_fastpath_data = cache_miss_(*args, **kwargs);
        out_tuple = py::cast<py::tuple>(out_and_fastpath_data);
        PopulateCacheEntry(cache_entry, arguments.signature, out_tuple);
      } catch (const std::exception& e) {
        cache_entry.fall_back_to_python = true;
        cache_entry.compilation_complete.Notify();
        throw;
      }
      cache_entry.compilation_complete.Notify();

      // We have already computed the result in the miss path so we can return
      // it. We are even *required* to do so if there are donated arguments,
      // because any donated buffers will now be invalid.
      return py::object(out_tuple[0]);
    } else {
      // Release the GIL while we wait, making sure the compile thread can
      // lock it.
      py::gil_scoped_release release;
      cache_entry.compilation_complete.WaitForNotification();
    }
  }
  if (cache_entry.fall_back_to_python) {
    return py::object(py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0]);
  }

  // 1. Parse arguments.
  std::vector<xla::PjRtDevice*>& input_devices = cache_entry.devices;
  const int num_computations =
      cache_entry.executable->AddressableDevices().size();
  std::vector<InputSpec>& input_specs = cache_entry.input_specs;
  const int num_args = arguments.flat_dynamic_args.size();

  // We need [num_computations, num_args] for the `Execute` call below.
  std::vector<std::vector<xla::PjRtBuffer*>> num_computation_num_args_buffers(
      num_computations);
  for (int computation = 0; computation < num_computations; ++computation) {
    num_computation_num_args_buffers[computation].resize(num_args);
  }
  for (int i = 0; i < num_args; ++i) {
    TF_ASSIGN_OR_RETURN(
        ShardArgResult sharded_arg,
        ShardArg(arguments.flat_dynamic_args[i], input_devices, input_specs[i],
                 cache_entry.py_devices, python_shard_arg_fallback_));

    std::vector<xla::PjRtBuffer*>& per_device_buffers =
        sharded_arg.per_device_buffers;
    for (int computation = 0; computation < num_computations; ++computation) {
      num_computation_num_args_buffers[computation][i] =
          per_device_buffers[computation];
    }
    for (auto& owned_buffer : sharded_arg.owned_buffers) {
      arguments.keep_alive.push_back(std::move(owned_buffer));
    }
    if (sharded_arg.owning_sda) {
      arguments.keep_alive_objects.push_back(std::move(sharded_arg.owning_sda));
    }
  }

  // A vector of [num_devices, num_outputs].
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> output_buffers;
  {
    py::gil_scoped_release gil_release;
    auto pjrt_executable = cache_entry.executable->mutable_pjrt_executable();
    TF_ASSIGN_OR_RETURN(output_buffers, pjrt_executable->Execute(
                                            num_computation_num_args_buffers,
                                            cache_entry.executable->options()));
  }

  // TODO(jblespiau): We don't need to create the PyBuffer objects.
  // Having a C++ `ShardedDeviceArray`, keeping internally the PjRtBuffer
  // objects is sufficient, and we can lazily create the `PyBuffer` only if
  // we access them from Python.
  auto traceback = xla::Traceback::Get();
  // TODO(jblespiau): Change the `client` function to return a reference.
  std::shared_ptr<xla::PyClient> client = cache_entry.executable->client();

  // Convert the PjRtBuffer objects to PyBuffer, and invert the order from
  // [num_devices, num_outputs] to [num_outputs, num_devices].
  const int num_outputs = output_buffers[0].size();
  std::vector<std::vector<xla::PyBuffer::object>> outputs;
  outputs.resize(num_outputs);
  for (int output_id = 0; output_id < num_outputs; ++output_id) {
    outputs[output_id].reserve(num_computations);
    for (int computation = 0; computation < num_computations; ++computation) {
      outputs[output_id].push_back(xla::PyBuffer::Make(
          client, std::move(output_buffers[computation][output_id]),
          traceback));
    }
  }

  py::list outputs_as_python_objects;
  const auto& output_specs = cache_entry.out_result_specs;

  std::vector<py::object> flat_sharded_device_arrays;
  flat_sharded_device_arrays.reserve(num_outputs);
  for (int i = 0; i < num_outputs; ++i) {
    const ResultSpec& result_spec = output_specs[i];
    flat_sharded_device_arrays.push_back(ShardedDeviceArray::Make(
        /*aval=*/result_spec.out_aval,
        /*sharding_spec=*/result_spec.out_spec,
        /*device_buffers=*/py::cast(std::move(outputs[i])),
        /*indices=*/result_spec.out_indices,
        /*weak_type=*/result_spec.weak_type));
  }
  py::object out =
      cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays);

  // If there is a post-hook function, call it with the inputs and the outputs.
  std::optional<py::object> post_hook = GetPostHook();
  if (post_hook) {
    (*post_hook)(this->AsPyHandle(), args, kwargs, out);
  }

  return out;
}

struct JaxPmapFunctionObject {
  PyObject_HEAD;
  PyObject* dict;      // Dictionary for __dict__
  PyObject* weakrefs;  // Weak references; for use by the Python interpreter.
  PmapFunction fun;
};

PyObject* JaxPmapFunction_Type = nullptr;

bool PmapFunction::IsPmapFunction(py::handle handle) {
  return handle.get_type() == JaxPmapFunction_Type;
}

PmapFunction* PmapFunction::AsPmapFunctionUnchecked(py::handle handle) {
  return &(reinterpret_cast<JaxPmapFunctionObject*>(handle.ptr())->fun);
}

xla::StatusOr<PmapFunction*> AsPmapFunction(py::handle handle) {
  if (!PmapFunction::IsPmapFunction(handle)) {
    return xla::InvalidArgument("Expected a PmapFunction");
  }
  return PmapFunction::AsPmapFunctionUnchecked(handle);
}

py::handle PmapFunction::AsPyHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(JaxPmapFunctionObject, fun));
}

namespace {

extern "C" {

PyObject* JaxPmapFunction_tp_new(PyTypeObject* subtype, PyObject* args,
                                 PyObject* kwds) {
  JaxPmapFunctionObject* self =
      reinterpret_cast<JaxPmapFunctionObject*>(subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->dict = nullptr;
  self->weakrefs = nullptr;
  return reinterpret_cast<PyObject*>(self);
}

void JaxPmapFunction_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  Py_CLEAR(o->dict);
  o->fun.~PmapFunction();
  tp->tp_free(self);
  Py_DECREF(tp);
}

int JaxPmapFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
  Py_VISIT(o->dict);
  Py_VISIT(o->fun.fun().ptr());
  Py_VISIT(o->fun.cache_miss().ptr());
  return 0;
}

int JaxPmapFunction_tp_clear(PyObject* self) {
  JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
  Py_CLEAR(o->dict);
  o->fun.ClearPythonReferences();
  return 0;
}

// Implements the Python descriptor protocol so PMAP-compiled functions can be
// used as bound methods. See:
// https://docs.python.org/3/howto/descriptor.html#functions-and-methods
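//
// Illustrative example (an assumed Python-side sketch): a pmapped function
// stored as a class attribute binds like a plain Python function,
//
//   class A:
//     f = jax.pmap(my_func)
//   A().f(x)  # `self` is prepended via the PyMethod_New path below.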
PyObject* JaxPmapFunction_tp_descr_get(PyObject* self, PyObject* obj,
                                       PyObject* type) {
  if (obj == nullptr || obj == Py_None) {
    Py_INCREF(self);
    return self;
  }
  return PyMethod_New(self, obj);
}

// Support d = instance.__dict__.
PyObject* JaxPmapFunction_get_dict(PyObject* self, void*) {
  JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
  if (!o->dict) {
    o->dict = PyDict_New();
  }
  Py_XINCREF(o->dict);
  return o->dict;
}

int JaxPmapFunction_set_dict(PyObject* self, PyObject* new_dict, void*) {
  JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
  if (!PyDict_Check(new_dict)) {
    PyErr_Format(PyExc_TypeError,
                 "__dict__ must be set to a dictionary, not a '%s'",
                 Py_TYPE(new_dict)->tp_name);
    return -1;
  }
  Py_INCREF(new_dict);
  Py_CLEAR(o->dict);
  o->dict = new_dict;
  return 0;
}

static PyGetSetDef JaxPmapFunction_tp_getset[] = {
    // Having a __dict__ seems necessary to allow functools.wraps to override
    // __doc__.
    {const_cast<char*>("__dict__"), JaxPmapFunction_get_dict,
     JaxPmapFunction_set_dict, nullptr, nullptr},
    {nullptr, nullptr, nullptr, nullptr, nullptr}};

PyObject* JaxPmapFunction_tp_call(PyObject* self, PyObject* args,
                                  PyObject* kwargs) {
  JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
  tensorflow::profiler::TraceMe traceme([&] {
    return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")");
  });
  py::kwargs py_kwargs;
  if (kwargs) {
    py_kwargs = py::reinterpret_borrow<py::kwargs>(kwargs);
  }
  try {
    xla::StatusOr<py::object> out = o->fun.Call(
        py::reinterpret_borrow<py::args>(args), std::move(py_kwargs));
    if (!out.ok()) {
      PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str());
      return nullptr;
    }
    return out.ValueOrDie().release().ptr();
  } catch (py::error_already_set& e) {
    e.restore();
    return nullptr;
  } catch (py::cast_error& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  } catch (std::invalid_argument& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  }
}

void InitializePmapFunction(JaxPmapFunctionObject* cfun, py::function fun,
                            py::function cache_miss,
                            std::vector<int> static_argnums,
                            py::function python_shard_arg_fallback) {
  new (&cfun->fun) PmapFunction(std::move(fun), std::move(cache_miss),
                                std::move(static_argnums),
                                std::move(python_shard_arg_fallback));
}

}  // extern "C"

py::object MakePmapFunction(py::function fun, py::function cache_miss,
                            std::vector<int> static_argnums,
                            py::function python_shard_arg_fallback) {
  py::object obj = py::reinterpret_steal<py::object>(JaxPmapFunction_tp_new(
      reinterpret_cast<PyTypeObject*>(JaxPmapFunction_Type), nullptr, nullptr));
  JaxPmapFunctionObject* buf =
      reinterpret_cast<JaxPmapFunctionObject*>(obj.ptr());
  InitializePmapFunction(buf, std::move(fun), std::move(cache_miss),
                         std::move(static_argnums),
                         std::move(python_shard_arg_fallback));
  return obj;
}

// Version numbers for the pickled representations.
// Increment these if changing them.
const int kPmapFunctionPickleVersion = 1;

}  // namespace

void BuildPmapSubmodule(py::module& m) {
  py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library");

  py::class_<NoSharding> no_sharding(pmap_lib, "NoSharding");
  no_sharding.def(py::init<>())
      .def("__repr__",
           [](const NoSharding& self) { return "NoSharding()"; })
      .def("__eq__",
           [](const NoSharding& self, py::object obj) {
             return py::isinstance<NoSharding>(obj);
           })
      .def("__hash__", [](const NoSharding& self) {
        const size_t hash = absl::HashOf(self);
        return py::int_(hash);
      });

  py::class_<Chunked> chunked(pmap_lib, "Chunked");
  chunked.def(py::init<std::vector<int>>())
      .def_readonly("chunks", &Chunked::chunks)
      .def("__repr__",
           [](const Chunked& self) {
             return absl::StrCat("Chunked(",
                                 absl::StrJoin(self.chunks, ","), ")");
           })
      .def("__eq__", [](const Chunked& self, py::object other) {
        if (!py::isinstance<Chunked>(other)) {
          return false;
        }
        return self == py::cast<const Chunked&>(other);
      });

  py::class_<Unstacked> unstacked(pmap_lib, "Unstacked");
  unstacked.def(py::init<int>())
      .def_readonly("size", &Unstacked::size)
      .def("__repr__",
           [](const Unstacked& x) {
             return absl::StrCat("Unstacked(", x.size, ")");
           })
      .def("__eq__", [](const Unstacked& self, py::object other) {
        if (!py::isinstance<Unstacked>(other)) {
          return false;
        }
        return self == py::cast<const Unstacked&>(other);
      });

  py::class_<ShardedAxis> sharded_axis(pmap_lib, "ShardedAxis");
  sharded_axis.def(py::init<int>()).def_readonly("axis", &ShardedAxis::axis);
  sharded_axis
      .def("__repr__",
           [](const ShardedAxis& x) {
             return absl::StrCat("ShardedAxis(axis=", x.axis, ")");
           })
      .def("__eq__", [](const ShardedAxis& self, const ShardedAxis& other) {
        return self == other;
      });

  py::class_<Replicated> replicated(pmap_lib, "Replicated");
  replicated.def(py::init<int>())
      .def_readonly("replicas", &Replicated::replicas)
      .def("__repr__",
           [](const Replicated& x) {
             return absl::StrCat("Replicated(replicas=", x.replicas, ")");
           })
      .def("__eq__", [](const Replicated& self, const Replicated& other) {
        return self == other;
      });

  py::class_<ShardingSpec> sharding_spec(pmap_lib, "ShardingSpec");
  sharding_spec
      .def(py::init<py::iterable, py::iterable>(), py::arg("sharding"),
           py::arg("mesh_mapping"))
      .def_property_readonly(
          "sharding",
          [](const ShardingSpec& self) {
            return xla::SpanToTuple(absl::MakeConstSpan(self.GetSharding()));
          })
      .def_property_readonly(
          "mesh_mapping",
          [](const ShardingSpec& self) {
            return xla::SpanToTuple(absl::MakeConstSpan(self.GetMeshMapping()));
          })
      .def("__eq__", [](const ShardingSpec& self,
                        const ShardingSpec& other) { return self == other; })
      .def("__hash__", [](const ShardingSpec& self) {
        const size_t hash = absl::HashOf(self);
        return py::int_(hash);
      });
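  // Illustrative Python-side construction using the bindings above (an
  // assumed sketch):
  //
  //   spec = pmap_lib.ShardingSpec(
  //       sharding=(pmap_lib.Chunked([2]), pmap_lib.NoSharding()),
  //       mesh_mapping=(pmap_lib.ShardedAxis(0),))
  //   spec.sharding, spec.mesh_mapping  # tuples mirroring the inputs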

  TF_CHECK_OK(ShardedDeviceArray::RegisterTypes(pmap_lib));

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  py::object cfun;
  {
    py::str name = py::str("PmapFunction");
    py::str qualname = py::str("PmapFunction");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    CHECK(heap_type) << "Unable to create heap type object";
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "PmapFunction";
    type->tp_basicsize = sizeof(JaxPmapFunctionObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC;
    type->tp_new = JaxPmapFunction_tp_new;
    type->tp_dealloc = JaxPmapFunction_tp_dealloc;
    type->tp_dictoffset = offsetof(JaxPmapFunctionObject, dict);
    type->tp_traverse = JaxPmapFunction_tp_traverse;
    type->tp_clear = JaxPmapFunction_tp_clear;
    type->tp_weaklistoffset = offsetof(JaxPmapFunctionObject, weakrefs);
    type->tp_getset = JaxPmapFunction_tp_getset;
    type->tp_descr_get = JaxPmapFunction_tp_descr_get;
    type->tp_call = JaxPmapFunction_tp_call;
    CHECK_EQ(PyType_Ready(type), 0);
    JaxPmapFunction_Type = reinterpret_cast<PyObject*>(type);
    cfun = py::reinterpret_borrow<py::object>(JaxPmapFunction_Type);
  }
  py::object cfun_type =
      py::reinterpret_borrow<py::object>(JaxPmapFunction_Type);

  // Add PmapFunction to the xla_extension module so it can be pickled.
  m.attr("PmapFunction") = cfun_type;

  cfun.attr("__signature__") =
      property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        return fun->PythonSignature();
      });
  // Required by `post_hook`.
  cfun.attr("_cache_miss") =
      property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        return fun->cache_miss();
      });
  cfun.attr("__getstate__") = py::cpp_function(
      [](const PmapFunction::object& self) {
        PmapFunction* fn = self.func();
        py::dict pickle;
        pickle["version"] = kPmapFunctionPickleVersion;
        pickle["fun"] = fn->fun();
        pickle["cache_miss"] = fn->cache_miss();
        pickle["static_argnums"] = fn->static_argnums();
        pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback();
        return pickle;
      },
      py::is_method(cfun_type));
  cfun.attr("__setstate__") = py::cpp_function(
      [](PmapFunction::object& self, const py::dict& pickle) {
        int version = py::cast<int>(pickle["version"]);
        if (version != kPmapFunctionPickleVersion) {
          throw std::invalid_argument(absl::StrFormat(
              "Invalid PmapFunction pickle version, got %d, expected %d. "
              "Pickling/Unpickling jitted functions using different JAX "
              "versions is not supported.",
              version, kPmapFunctionPickleVersion));
        }
        py::function fun = py::cast<py::function>(pickle["fun"]);
        py::function cache_miss = py::cast<py::function>(pickle["cache_miss"]);
        std::vector<int> static_argnums =
            py::cast<std::vector<int>>(pickle["static_argnums"]);
        py::function python_shard_arg_fallback =
            py::cast<py::function>(pickle["python_shard_arg_fallback"]);

        InitializePmapFunction(
            reinterpret_cast<JaxPmapFunctionObject*>(self.ptr()),
            std::move(fun), std::move(cache_miss), std::move(static_argnums),
            std::move(python_shard_arg_fallback));
      },
      py::is_method(cfun_type));
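  // The two methods above enable standard pickling (an assumed Python-side
  // sketch; jaxlib/JAX versions must match, as enforced by __setstate__):
  //
  //   import pickle
  //   f2 = pickle.loads(pickle.dumps(f))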

  // This is only for testing/debugging purposes.
  cfun.attr("_cache_size") =
      property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        return py::cast<int>(fun->cache_size());
      });

  cfun.attr("_debug_cache_keys") = py::cpp_function(
      [](py::handle self) -> xla::StatusOr<std::string> {
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        return fun->DebugCacheKeys();
      },
      py::is_method(cfun_type));

  // Accepts _arbitrary_ arguments for a pmapped function and returns the
  // corresponding signatures that are used as cache keys. It does not compile
  // or execute anything.
  //
  // This function allows passing partial args, which is especially useful when
  // the full list of arguments is too long and results in enormous signatures.
  // For example, it can be called multiple times, as in
  // > fn._debug_compute_cache_key(arg[0])
  // > fn._debug_compute_cache_key(arg[1])
  // > fn._debug_compute_cache_key(arg[-3:-1])
  // ...
  cfun.attr("_debug_compute_cache_key") = py::cpp_function(
      [](const PmapFunction::object& self, const py::args& args,
         const py::kwargs& kwargs) -> xla::StatusOr<std::string> {
        ParsedArgumentsAsBuffers arguments;
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        TF_RETURN_IF_ERROR(ParseArguments(args, kwargs, fun->static_argnums(),
                                          /*static_argnames=*/{}, arguments));
        TF_RETURN_IF_ERROR(fun->UpdateArgsSignature(args, kwargs, arguments));
        return arguments.signature.DebugString();
      },
      py::is_method(cfun_type));

  pmap_lib.def("pmap",
               [](py::function fun, py::function cache_miss,
                  std::vector<int> static_argnums,
                  py::function python_shard_arg_fallback) -> py::object {
                 return MakePmapFunction(std::move(fun), std::move(cache_miss),
                                         std::move(static_argnums),
                                         std::move(python_shard_arg_fallback));
               });
}

}  // namespace jax