xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/pmap_lib.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/pmap_lib.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <stdexcept>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/hash/hash.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/synchronization/notification.h"
30 #include "absl/types/span.h"
31 #include "absl/types/variant.h"
32 #include "pybind11/cast.h"
33 #include "pybind11/pybind11.h"
34 #include "pybind11/pytypes.h"
35 #include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil
36 #include "tensorflow/compiler/xla/python/exceptions.h"
37 #include "tensorflow/compiler/xla/python/jax_jit.h"
38 #include "tensorflow/compiler/xla/python/py_buffer.h"
39 #include "tensorflow/compiler/xla/python/py_executable.h"
40 #include "tensorflow/compiler/xla/python/py_values.h"
41 #include "tensorflow/compiler/xla/python/python_utils.h"
42 #include "tensorflow/compiler/xla/python/sharded_device_array.h"
43 #include "tensorflow/compiler/xla/python/types.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/platform/logging.h"
46 #include "tensorflow/core/profiler/lib/traceme.h"
47 
48 namespace jax {
49 
50 namespace py = pybind11;
51 
52 namespace {
53 
54 // Specifies how to shard the inputs. Even though everything could be computed
55 // from `sharding_specs` and the argument shape, we cache derived computations
56 // for performance.
57 struct InputSpec {
InputSpecjax::__anond5510e320111::InputSpec58   InputSpec(ShardingSpec sharding_spec, py::object indices)
59       : sharding_spec(std::move(sharding_spec)), indices(std::move(indices)) {}
60   ShardingSpec sharding_spec;
61   py::object indices;
62 };
63 
64 // An object containing the arguments to create ShardedDeviceArray from the
65 // output buffers.
66 struct ResultSpec {
67  public:
ResultSpecjax::__anond5510e320111::ResultSpec68   ResultSpec(py::object aval, ShardingSpec out_spec, py::object out_indices)
69       : out_aval(std::move(aval)),
70         weak_type(py::cast<bool>(out_aval.attr("weak_type"))),
71         out_spec(std::move(out_spec)),
72         out_indices(std::move(out_indices)) {}
73   py::object out_aval;
74   bool weak_type;
75   ShardingSpec out_spec;
76   py::object out_indices;
77 };
78 
79 // The result of `ShardArg`.
80 struct ShardArgResult {
81   // Points to the on-device buffers. Not owned.
82   // Size `num_devices`.
83   std::vector<xla::PjRtBuffer*> per_device_buffers;
84 
85   // The Python argument will be always be copied to `owning_sda`.
86   // If we need to copy data to a device, the newly created buffers will be
87   // added to `owned_buffers`.
88   std::vector<std::unique_ptr<xla::PjRtBuffer>> owned_buffers;
89   py::object owning_sda;
90 };
91 
92 // Shars a single argument over devices.
93 //
94 // We currently only support fully in C++, C++ ShardedDeviceArray. For all
95 // other usages, we call a Python function returning C++ ShardedDeviceArray
96 // that will be casted back to the C++ objects.
97 //
98 // This function is not usable for JAX extensions that do not comply with the
99 // PjRt interfaces.
100 //
101 // Arguments:
102 // `arg`: The object to shard across `devices`. If a `ShardedDeviceArray`,
103 //   a fast-path will be executed if it's already correctly sharded.
104 //
105 // Returns a failure Status when an unrecoverable error occurred, so we don't
106 // need to fallback to Python.
107 //
108 // Both `devices` and `sharding_spec` has the same length.
ShardArg(py::handle arg,absl::Span<xla::PjRtDevice * const> devices,const InputSpec & input_spec,py::handle py_devices,const py::function & python_fallback)109 xla::StatusOr<ShardArgResult> ShardArg(
110     py::handle arg, absl::Span<xla::PjRtDevice* const> devices,
111     const InputSpec& input_spec, py::handle py_devices,
112     const py::function& python_fallback) {
113   if (ShardedDeviceArray::IsShardedDeviceArray(arg)) {
114     ShardedDeviceArray* sda =
115         ShardedDeviceArray::AsShardedDeviceArrayUnchecked(arg);
116     const ShardingSpec& sharding_spec = input_spec.sharding_spec;
117     if (sharding_spec == sda->GetShardingSpec()) {
118       const int num_devices = devices.size();
119       TF_ASSIGN_OR_RETURN(absl::Span<xla::PjRtBuffer* const> sda_buffers,
120                           sda->GetPjRtBuffers());
121       CHECK_EQ(sda_buffers.size(), num_devices);
122 
123       ShardArgResult result;
124       result.owning_sda = py::reinterpret_borrow<py::object>(arg);
125       std::vector<xla::PjRtBuffer*>& per_device_buffers =
126           result.per_device_buffers;
127       per_device_buffers.reserve(num_devices);
128 
129       for (int i = 0; i < num_devices; ++i) {
130         xla::PjRtBuffer* current_buffer = sda_buffers[i];
131         if (devices[i] == current_buffer->device()) {  // Pointer equality.
132           per_device_buffers.push_back(current_buffer);
133         } else {
134           TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtBuffer> out,
135                               current_buffer->CopyToDevice(devices[i]));
136           per_device_buffers.push_back(out.get());
137           result.owned_buffers.push_back(std::move(out));
138         }
139       }
140       return result;
141     }
142   }
143 
144   // This fallback is better than nothing, but ideally we should be able to
145   // convert the argument in C++. At least, we can call the C++ DevicePut from
146   // Python.
147   auto per_device_pybuffers =
148       py::cast<py::list>(python_fallback(arg, py_devices, input_spec.indices));
149   ShardArgResult result;
150   result.owning_sda = py::reinterpret_borrow<py::object>(per_device_pybuffers);
151   std::vector<xla::PjRtBuffer*>& per_device_buffers = result.per_device_buffers;
152   if (!per_device_pybuffers.empty()) {
153     per_device_buffers.reserve(per_device_pybuffers.size());
154 
155     // The JAX Python shard_arg function is expected to return JAX PyBuffer
156     // objects. If executing a JAX extension, it should have fallbacked to
157     // Python well before this point.
158     TF_RET_CHECK(xla::PyBuffer::IsPyBuffer(per_device_pybuffers[0]));
159     for (py::handle per_device_pybuffer : per_device_pybuffers) {
160       xla::PjRtBuffer* buf =
161           xla::PyBuffer::AsPyBuffer(per_device_pybuffer).ValueOrDie()->buffer();
162       per_device_buffers.push_back(buf);
163     }
164   }
165   return result;
166 }
167 
168 struct PmapCacheEntry {
169   std::shared_ptr<xla::PyExecutable> executable;
170   // The value `backend.local_devices()`.
171   py::object py_devices;  // To pass back to Python.
172   std::vector<xla::PjRtDevice*> devices;
173   std::vector<InputSpec> input_specs;
174   xla::PyTreeDef out_pytree_def;
175   // Objects necessary to build the out ShardedDeviceArray objects.
176   std::vector<ResultSpec> out_result_specs;
177 
178   // Ensures a single thread performs the compilation for a given executable.
179   //
180   // The first thread (holding the GIL) will create the CacheEntry associated to
181   // a signature and if the object has been inserted already, other threads
182   // will wait for the notification.
183   absl::Notification compilation_complete;
184 
185   bool fall_back_to_python = false;
186 };
187 
188 }  // namespace
189 
190 // A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the
191 // bookkeeping of the different signatures used and the dispatch of calls to
192 // the correct underlying `PyExecutable`. This class is thread-safe.
193 class PmapFunction {
194  public:
PmapFunction(py::function fun,py::function cache_miss,std::vector<int> static_argnums,py::function python_shard_arg_fallback)195   PmapFunction(py::function fun, py::function cache_miss,
196                std::vector<int> static_argnums,
197                py::function python_shard_arg_fallback)
198       : fun_(std::move(fun)),
199         cache_miss_(std::move(cache_miss)),
200         static_argnums_(std::move(static_argnums)),
201         python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) {
202     std::sort(static_argnums_.begin(), static_argnums_.end());
203 
204     function_name_ = py::str(py::getattr(fun_, "__name__", fun));
205   }
206   PmapFunction(const PmapFunction&) = delete;
207   PmapFunction& operator=(const PmapFunction& other) = delete;
208   PmapFunction(PmapFunction&&) = default;
209   PmapFunction& operator=(PmapFunction&&) = default;
210 
211   // This function will:
212   // (a) flatten the inputs using pytree
213   // (b) get buffer objects from the arguments
214   // (c) call the executable
215   // (d) construct `ShardedDeviceArray` objects from the outputs
216   // (e) reconstruct the `PyTree`.
217   xla::StatusOr<py::object> Call(py::args args, py::kwargs kwargs);
218 
PythonSignature()219   py::object PythonSignature() {
220     static const auto* inspect = new py::module(py::module::import("inspect"));
221     return inspect->attr("signature")(fun_);
222   }
223 
cache_size() const224   int cache_size() const { return executables_.size(); }
fun() const225   const py::function& fun() const { return fun_; }
cache_miss() const226   const py::function& cache_miss() const { return cache_miss_; }
function_name() const227   const std::string& function_name() const { return function_name_; }
python_shard_arg_fallback() const228   const py::function& python_shard_arg_fallback() const {
229     return python_shard_arg_fallback_;
230   }
static_argnums() const231   const std::vector<int>& static_argnums() const { return static_argnums_; }
232 
233   // pybind11::object typed subclass for PmapFunction objects.
234   class pyobject : public py::object {
235    public:
236     PYBIND11_OBJECT(pyobject,  // NOLINT
237                     py::object, PmapFunction::IsPmapFunction);
238     pyobject() = default;
func() const239     PmapFunction* func() const {
240       return PmapFunction::AsPmapFunctionUnchecked(*this);
241     }
242   };
243   // Alias as ::object; outside the scope above we won't confuse pybind11's
244   // macros.
245   using object = pyobject;
246 
247   py::handle AsPyHandle();
248   // Returns true if `h` is a PmapFunction.
249   static bool IsPmapFunction(py::handle handle);
250   // Converts `handle` to a PmapFunction*. Does not do any checking.
251   static PmapFunction* AsPmapFunctionUnchecked(py::handle handle);
252 
253   // Helper function used by the tp_clear GC method.
ClearPythonReferences()254   void ClearPythonReferences() {
255     py::function fun, cache_miss, python_shard_arg_fallback;
256     // Swap values for nulls before they are destroyed. See the Python
257     // Py_CLEAR() documentation for a discussion of this topic.
258     std::swap(fun_, fun);
259     std::swap(cache_miss_, cache_miss);
260     std::swap(python_shard_arg_fallback_, python_shard_arg_fallback);
261   }
262 
263   // Updates the signature of arguments for a pmapped function.
264   //
265   // It deals with the arguments signatures and also of the global and
266   // thread-local jit context.
UpdateArgsSignature(const py::args & args,const py::kwargs & kwargs,ParsedArgumentsAsBuffers & arguments)267   xla::Status UpdateArgsSignature(const py::args& args,
268                                   const py::kwargs& kwargs,
269                                   ParsedArgumentsAsBuffers& arguments) {
270     arguments.signature.function_name = function_name_;
271 
272     // Get dynamic argument signatures.
273     JitState& global_state = jax::GetGlobalState();
274     JitState& tls = jax::GetLocalState();
275     const bool jax_enable_x64 = GetEnableX64();
276     arguments.signature.jax_enable_x64 = jax_enable_x64;
277     for (py::handle arg : arguments.flat_dynamic_args) {
278       auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64);
279       if (!signature_or_error.ok()) {
280         VLOG(2) << "PyArgSignatureOfValue failed: "
281                 << signature_or_error.status();
282         return signature_or_error.status();
283       }
284       arguments.signature.dynamic_arg_signatures.push_back(
285           std::move(signature_or_error).ValueOrDie());
286     }
287     try {
288       py::object pxla_module = py::module::import("jax").attr("config");
289       py::object sda = py::getattr(pxla_module, "_trace_context", py::none());
290       if (!sda.is_none()) {
291         arguments.signature.thread_local_extra_jit_context = sda();
292       }
293     } catch (const py::error_already_set& e) {
294       // Ignore; jax may not be present.
295     }
296     if (!arguments.signature.thread_local_extra_jit_context.has_value()) {
297       arguments.signature.thread_local_extra_jit_context =
298           tls.extra_jit_context;
299       arguments.signature.global_extra_jit_context =
300           global_state.extra_jit_context;
301     }
302     return xla::Status();
303   }
304 
305   // Returns, for debugging purposes (e.g. finding why some call misses the
306   // cache and recompiles), the list of the string representations of the keys.
307   //
308   // The format can change at any time.
DebugCacheKeys() const309   std::string DebugCacheKeys() const {
310     std::vector<std::string> key_strings = {
311         absl::StrCat("The cache contains ", executables_.size(), " elements:")};
312     // We will be able to use auto& [key, _] when TF uses C++ 17.
313     for (auto& pair : executables_) {
314       key_strings.push_back(pair.first.DebugString());
315     }
316     return absl::StrJoin(key_strings, "\n\n");
317   }
318 
319  private:
320   // Mutates `cache_entry` in place.
321   void PopulateCacheEntry(PmapCacheEntry& cache_entry,
322                           const CallSignature& signature,
323                           const py::tuple& out_and_fastpath_data);
324 
325   bool always_fallback_to_python_ = false;
326 
327   py::function fun_;  // The Python function to pmap.
328   std::string function_name_;
329   // See JAX _cpp_pmap in api.py for documentation.
330   py::function cache_miss_;
331 
332   // We need to know the static arguments to remove them from the arguments
333   // passed to the underlying PyExecutable. In sorted order.
334   std::vector<int> static_argnums_;
335   // We need a `unique_ptr` here to ensure value pointer stability.
336   absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>
337       executables_;
338 
339   // The fallback function to use with `ShardArgs`.
340   // TODO(jblespiau): Add support for more types from C++.
341   py::function python_shard_arg_fallback_;
342 };
343 
PopulateCacheEntry(PmapCacheEntry & cache_entry,const CallSignature & signature,const py::tuple & out_and_fastpath_data)344 void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry,
345                                       const CallSignature& signature,
346                                       const py::tuple& out_and_fastpath_data) {
347   CHECK_EQ(out_and_fastpath_data.size(), 2);
348   if (out_and_fastpath_data[1].is_none()) {
349     cache_entry.fall_back_to_python = true;
350     return;
351   }
352 
353   py::tuple pmap_data = py::cast<py::tuple>(out_and_fastpath_data[1]);
354   if (py::cast<int>(pmap_data.attr("version")) != 1) {
355     throw xla::XlaRuntimeError(absl::StrCat(
356         "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 "
357         "expected, but got ",
358         py::cast<int>(pmap_data.attr("version")),
359         "Upgrade jaxlib and jax. Provided data was:",
360         py::cast<std::string>(py::str(py::repr(pmap_data)))));
361   }
362   // See api.py::_PmapFastpathData in the JAX code base for the expected
363   // namedtuple.
364   std::shared_ptr<xla::PyExecutable> executable;
365   try {
366     executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
367         pmap_data.attr("xla_executable"));
368   } catch (const py::cast_error& e) {
369     // Backends that don't implement the C++ PjRt APIs
370     always_fallback_to_python_ = true;
371     return;
372   }
373   cache_entry.executable = std::move(executable);
374   const std::vector<xla::ClientAndPtr<xla::PjRtDevice>>& client_and_devices =
375       cache_entry.executable->AddressableDevices();
376   cache_entry.devices.reserve(client_and_devices.size());
377   for (auto& client_and_device : client_and_devices) {
378     cache_entry.devices.push_back(client_and_device.get());
379   }
380 
381   // Inputs shard args details.
382   auto input_sharding_specs = py::cast<std::vector<ShardingSpec>>(
383       pmap_data.attr("input_sharding_specs"));
384   py::list input_indices = pmap_data.attr("input_indices");
385 
386   cache_entry.py_devices = pmap_data.attr("input_devices");
387   auto input_devices =
388       py::cast<std::vector<xla::PjRtDevice*>>(pmap_data.attr("input_devices"));
389   CHECK_EQ(input_sharding_specs.size(), input_indices.size());
390   cache_entry.input_specs.reserve(input_sharding_specs.size());
391   for (int i = 0; i < input_sharding_specs.size(); ++i) {
392     cache_entry.input_specs.emplace_back(input_sharding_specs[i],
393                                          input_indices[i]);
394   }
395 
396   // Outputs specs.
397   auto out_tree = py::cast<xla::PyTreeDef>(pmap_data.attr("out_pytree_def"));
398   cache_entry.out_pytree_def = std::move(out_tree);
399   py::list out_avals = pmap_data.attr("out_avals");
400   py::list out_indices = pmap_data.attr("out_indices");
401   auto out_sharding_specs =
402       py::cast<std::vector<ShardingSpec>>(pmap_data.attr("out_sharding_specs"));
403   CHECK_EQ(out_avals.size(), out_indices.size());
404   CHECK_EQ(out_indices.size(), out_sharding_specs.size());
405 
406   cache_entry.out_result_specs.reserve(out_avals.size());
407   for (int i = 0; i < out_avals.size(); ++i) {
408     cache_entry.out_result_specs.emplace_back(
409         out_avals[i], std::move(out_sharding_specs[i]), out_indices[i]);
410   }
411 }
412 
Call(py::args args,py::kwargs kwargs)413 xla::StatusOr<py::object> PmapFunction::Call(py::args args, py::kwargs kwargs) {
414   if (always_fallback_to_python_) {
415     return py::object(py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0]);
416   }
417 
418   ParsedArgumentsAsBuffers arguments;
419   xla::Status status = ParseArguments(args, kwargs, static_argnums_,
420                                       /*static_argnames=*/{}, arguments);
421   if (!status.ok()) {
422     VLOG(2) << "ParseArguments failed: " << status;
423     return py::object(py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0]);
424   }
425 
426   status = UpdateArgsSignature(args, kwargs, arguments);
427   if (!status.ok()) {
428     return py::object(py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0]);
429   }
430 
431   // Retrieve/Maybe add the executable to the cache.
432   absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>::iterator
433       it;
434   bool inserted;
435   std::tie(it, inserted) = executables_.try_emplace(
436       arguments.signature, std::make_unique<PmapCacheEntry>());
437   PmapCacheEntry& cache_entry = *(it->second);
438 
439   if (!cache_entry.compilation_complete.HasBeenNotified()) {
440     // In case of several threads attempting to compile the executable, only
441     // the one that inserted the item will perform the compilation.
442     if (inserted) {
443       py::object out_and_fastpath_data;
444       py::tuple out_tuple;
445       VLOG(2) << "Cache miss for " << arguments.signature.DebugString();
446       try {
447         // Calls Python and may release the GIL. May also throw if
448         // compilation/tracing fails.
449         out_and_fastpath_data = cache_miss_(*args, **kwargs);
450         out_tuple = py::cast<py::tuple>(out_and_fastpath_data);
451         PopulateCacheEntry(cache_entry, arguments.signature, out_tuple);
452       } catch (const std::exception& e) {
453         cache_entry.fall_back_to_python = true;
454         cache_entry.compilation_complete.Notify();
455         throw;
456       }
457       cache_entry.compilation_complete.Notify();
458 
459       // We have already computed the result in the miss path so we can return
460       // it. We are even *required* to do so if there are donated arguments,
461       // because any donated buffers will now be invalid.
462       return py::object(out_tuple[0]);
463     } else {
464       // Release the GIL while we wait, making sure the compile thread can
465       // lock it.
466       py::gil_scoped_release release;
467       cache_entry.compilation_complete.WaitForNotification();
468     }
469   }
470   if (cache_entry.fall_back_to_python) {
471     return py::object(py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0]);
472   }
473 
474   // 1. Parse arguments.
475   std::vector<xla::PjRtDevice*>& input_devices = cache_entry.devices;
476   const int num_computations =
477       cache_entry.executable->AddressableDevices().size();
478   std::vector<InputSpec>& input_specs = cache_entry.input_specs;
479   const int num_args = arguments.flat_dynamic_args.size();
480 
481   // We need [num_computation, num_args] for the `Execute` call bellow,
482   std::vector<std::vector<xla::PjRtBuffer*>> num_computation_num_args_buffers(
483       num_computations);
484   for (int computation = 0; computation < num_computations; ++computation) {
485     num_computation_num_args_buffers[computation].resize(num_args);
486   }
487   for (int i = 0; i < num_args; ++i) {
488     TF_ASSIGN_OR_RETURN(
489         ShardArgResult sharded_arg,
490         ShardArg(arguments.flat_dynamic_args[i], input_devices, input_specs[i],
491                  cache_entry.py_devices, python_shard_arg_fallback_));
492 
493     std::vector<xla::PjRtBuffer*>& per_device_buffers =
494         sharded_arg.per_device_buffers;
495     for (int computation = 0; computation < num_computations; ++computation) {
496       num_computation_num_args_buffers[computation][i] =
497           per_device_buffers[computation];
498     }
499     for (auto& owned_buffer : sharded_arg.owned_buffers) {
500       arguments.keep_alive.push_back(std::move(owned_buffer));
501     }
502     if (sharded_arg.owning_sda) {
503       arguments.keep_alive_objects.push_back(std::move(sharded_arg.owning_sda));
504     }
505   }
506 
507   // A vector of [num_devices, num_outputs].
508   std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> output_buffers;
509   {
510     py::gil_scoped_release gil_release;
511     auto pjrt_executable = cache_entry.executable->mutable_pjrt_executable();
512     TF_ASSIGN_OR_RETURN(output_buffers, pjrt_executable->Execute(
513                                             num_computation_num_args_buffers,
514                                             cache_entry.executable->options()));
515   }
516 
517   // TODO(jblespiau): We don't need to create the PyBuffer objects.
518   // Having a C++ `ShardedDeviceArray`, keeping internally the PjRtBuffer
519   // objects is sufficient, and we can lazily create the `PyBuffer` only if
520   // we access them from Python.
521   auto traceback = xla::Traceback::Get();
522   // TODO(jblespiau): Change the `client` function to return a reference.
523   std::shared_ptr<xla::PyClient> client = cache_entry.executable->client();
524 
525   // Convert the PjRtBuffer objects to PyBuffer, and invert the order from
526   // [num_devices, num_args] to [num_args, num_devices].
527   const int num_outputs = output_buffers[0].size();
528   std::vector<std::vector<xla::PyBuffer::object>> outputs;
529   outputs.resize(num_outputs);
530   for (int output_id = 0; output_id < num_outputs; ++output_id) {
531     outputs[output_id].reserve(num_computations);
532     for (int computation = 0; computation < num_computations; ++computation) {
533       outputs[output_id].push_back(xla::PyBuffer::Make(
534           client, std::move(output_buffers[computation][output_id]),
535           traceback));
536     }
537   }
538 
539   py::list outputs_as_python_objects;
540   const auto& output_specs = cache_entry.out_result_specs;
541 
542   std::vector<py::object> flat_sharded_device_arrays;
543   flat_sharded_device_arrays.reserve(num_outputs);
544   for (int i = 0; i < num_outputs; ++i) {
545     const ResultSpec& result_spec = output_specs[i];
546     flat_sharded_device_arrays.push_back(ShardedDeviceArray::Make(
547         /*aval=*/result_spec.out_aval,
548         /*sharding_spec=*/result_spec.out_spec,
549         /*device_buffers=*/py::cast(std::move(outputs[i])),
550         /*indices=*/result_spec.out_indices,
551         /*weak_type=*/result_spec.weak_type));
552   }
553   py::object out =
554       cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays);
555 
556   // If there is a post-hook function, call it with the inputs and the outputs.
557   std::optional<py::object> post_hook = GetPostHook();
558   if (post_hook) {
559     (*post_hook)(this->AsPyHandle(), args, kwargs, out);
560   }
561 
562   return out;
563 }
564 
565 struct JaxPmapFunctionObject {
566   PyObject_HEAD;
567   PyObject* dict;      // Dictionary for __dict__
568   PyObject* weakrefs;  // Weak references; for use by the Python interpreter.
569   PmapFunction fun;
570 };
571 
572 PyObject* JaxPmapFunction_Type = nullptr;
573 
IsPmapFunction(py::handle handle)574 bool PmapFunction::IsPmapFunction(py::handle handle) {
575   return handle.get_type() == JaxPmapFunction_Type;
576 }
577 
AsPmapFunctionUnchecked(py::handle handle)578 PmapFunction* PmapFunction::AsPmapFunctionUnchecked(py::handle handle) {
579   return &(reinterpret_cast<JaxPmapFunctionObject*>(handle.ptr())->fun);
580 }
581 
AsPmapFunction(py::handle handle)582 xla::StatusOr<PmapFunction*> AsPmapFunction(py::handle handle) {
583   if (!PmapFunction::IsPmapFunction(handle)) {
584     return xla::InvalidArgument("Expected a PmapFunction");
585   }
586   return PmapFunction::AsPmapFunctionUnchecked(handle);
587 }
588 
AsPyHandle()589 py::handle PmapFunction::AsPyHandle() {
590   return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
591                                      offsetof(JaxPmapFunctionObject, fun));
592 }
593 
594 namespace {
595 
596 extern "C" {
597 
JaxPmapFunction_tp_new(PyTypeObject * subtype,PyObject * args,PyObject * kwds)598 PyObject* JaxPmapFunction_tp_new(PyTypeObject* subtype, PyObject* args,
599                                  PyObject* kwds) {
600   JaxPmapFunctionObject* self =
601       reinterpret_cast<JaxPmapFunctionObject*>(subtype->tp_alloc(subtype, 0));
602   if (!self) return nullptr;
603   self->dict = nullptr;
604   self->weakrefs = nullptr;
605   return reinterpret_cast<PyObject*>(self);
606 }
607 
JaxPmapFunction_tp_dealloc(PyObject * self)608 void JaxPmapFunction_tp_dealloc(PyObject* self) {
609   PyTypeObject* tp = Py_TYPE(self);
610   JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
611   if (o->weakrefs) {
612     PyObject_ClearWeakRefs(self);
613   }
614   Py_CLEAR(o->dict);
615   o->fun.~PmapFunction();
616   tp->tp_free(self);
617   Py_DECREF(tp);
618 }
619 
JaxPmapFunction_tp_traverse(PyObject * self,visitproc visit,void * arg)620 int JaxPmapFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) {
621   JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
622   Py_VISIT(o->dict);
623   Py_VISIT(o->fun.fun().ptr());
624   Py_VISIT(o->fun.cache_miss().ptr());
625   return 0;
626 }
627 
JaxPmapFunction_tp_clear(PyObject * self)628 int JaxPmapFunction_tp_clear(PyObject* self) {
629   JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
630   Py_CLEAR(o->dict);
631   o->fun.ClearPythonReferences();
632   return 0;
633 }
634 
635 // Implements the Python descriptor protocol so PMAP-compiled functions can be
636 // used as bound methods. See:
637 // https://docs.python.org/3/howto/descriptor.html#functions-and-methods
JaxPmapFunction_tp_descr_get(PyObject * self,PyObject * obj,PyObject * type)638 PyObject* JaxPmapFunction_tp_descr_get(PyObject* self, PyObject* obj,
639                                        PyObject* type) {
640   if (obj == nullptr || obj == Py_None) {
641     Py_INCREF(self);
642     return self;
643   }
644   return PyMethod_New(self, obj);
645 }
646 
647 // Support d = instance.__dict__.
JaxPmapFunction_get_dict(PyObject * self,void *)648 PyObject* JaxPmapFunction_get_dict(PyObject* self, void*) {
649   JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
650   if (!o->dict) {
651     o->dict = PyDict_New();
652   }
653   Py_XINCREF(o->dict);
654   return o->dict;
655 }
656 
JaxPmapFunction_set_dict(PyObject * self,PyObject * new_dict,void *)657 int JaxPmapFunction_set_dict(PyObject* self, PyObject* new_dict, void*) {
658   JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
659   if (!PyDict_Check(new_dict)) {
660     PyErr_Format(PyExc_TypeError,
661                  "__dict__ must be set to a dictionary, not a '%s'",
662                  Py_TYPE(new_dict)->tp_name);
663     return -1;
664   }
665   Py_INCREF(new_dict);
666   Py_CLEAR(o->dict);
667   o->dict = new_dict;
668   return 0;
669 }
670 
671 static PyGetSetDef JaxPmapFunction_tp_getset[] = {
672     // Having a __dict__ seems necessary to allow !functool.wraps to override
673     // __doc__.
674     {const_cast<char*>("__dict__"), JaxPmapFunction_get_dict,
675      JaxPmapFunction_set_dict, nullptr, nullptr},
676     {nullptr, nullptr, nullptr, nullptr, nullptr}};
677 
JaxPmapFunction_tp_call(PyObject * self,PyObject * args,PyObject * kwargs)678 PyObject* JaxPmapFunction_tp_call(PyObject* self, PyObject* args,
679                                   PyObject* kwargs) {
680   JaxPmapFunctionObject* o = reinterpret_cast<JaxPmapFunctionObject*>(self);
681   tensorflow::profiler::TraceMe traceme([&] {
682     return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")");
683   });
684   py::kwargs py_kwargs;
685   if (kwargs) {
686     py_kwargs = py::reinterpret_borrow<py::kwargs>(kwargs);
687   }
688   try {
689     xla::StatusOr<py::object> out = o->fun.Call(
690         py::reinterpret_borrow<py::args>(args), std::move(py_kwargs));
691     if (!out.ok()) {
692       PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str());
693       return nullptr;
694     }
695     return out.ValueOrDie().release().ptr();
696   } catch (py::error_already_set& e) {
697     e.restore();
698     return nullptr;
699   } catch (py::cast_error& e) {
700     PyErr_SetString(PyExc_ValueError, e.what());
701     return nullptr;
702   } catch (std::invalid_argument& e) {
703     PyErr_SetString(PyExc_ValueError, e.what());
704     return nullptr;
705   }
706 }
707 
InitializePmapFunction(JaxPmapFunctionObject * cfun,py::function fun,py::function cache_miss,std::vector<int> static_argnums,py::function python_shard_arg_fallback)708 void InitializePmapFunction(JaxPmapFunctionObject* cfun, py::function fun,
709                             py::function cache_miss,
710                             std::vector<int> static_argnums,
711                             py::function python_shard_arg_fallback) {
712   new (&cfun->fun) PmapFunction(std::move(fun), std::move(cache_miss),
713                                 std::move(static_argnums),
714                                 std::move(python_shard_arg_fallback));
715 }
716 
717 }  // extern "C"
718 
MakePmapFunction(py::function fun,py::function cache_miss,std::vector<int> static_argnums,py::function python_shard_arg_fallback)719 py::object MakePmapFunction(py::function fun, py::function cache_miss,
720                             std::vector<int> static_argnums,
721                             py::function python_shard_arg_fallback) {
722   py::object obj = py::reinterpret_steal<py::object>(JaxPmapFunction_tp_new(
723       reinterpret_cast<PyTypeObject*>(JaxPmapFunction_Type), nullptr, nullptr));
724   JaxPmapFunctionObject* buf =
725       reinterpret_cast<JaxPmapFunctionObject*>(obj.ptr());
726   InitializePmapFunction(buf, std::move(fun), std::move(cache_miss),
727                          std::move(static_argnums),
728                          std::move(python_shard_arg_fallback));
729   return obj;
730 }
731 
// Version number for the pickled representation produced by
// PmapFunction.__getstate__ and consumed by __setstate__.
// Increment this if the pickle layout changes.
// constexpr: the value is a compile-time constant used in both the pickle
// dict and the version check.
constexpr int kPmapFunctionPickleVersion = 1;
735 
736 }  // namespace
737 
// Registers the "pmap_lib" submodule on `m` and defines the Python types
// backing the C++ pmap fast path:
//   - the sharding description types (NoSharding, Chunked, Unstacked,
//     ShardedAxis, Replicated, ShardingSpec);
//   - ShardedDeviceArray (delegated to ShardedDeviceArray::RegisterTypes);
//   - the PmapFunction heap type, its properties and pickle support;
//   - the `pmap` factory entry point.
void BuildPmapSubmodule(py::module& m) {
  py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library");

  // NoSharding is stateless: every instance compares equal to every other
  // NoSharding, and all instances share the same hash.
  py::class_<NoSharding> no_sharding(pmap_lib, "NoSharding");
  no_sharding.def(py::init<>())
      .def("__repr__",
           [](const NoSharding& chuncked) { return "NoSharding()"; })
      .def("__eq__",
           [](const NoSharding& self, py::object obj) {
             return py::isinstance<NoSharding>(obj);
           })
      .def("__hash__", [](const NoSharding& self) {
        const size_t hash = absl::HashOf(self);
        return py::int_(hash);
      });

  // NOTE(review): Chunked, Unstacked, ShardedAxis and Replicated define
  // __eq__ but no __hash__ — confirm Python-side hashing of these types is
  // never needed (only ShardingSpec and NoSharding are hashable here).
  py::class_<Chunked> chunked(pmap_lib, "Chunked");
  chunked.def(py::init<std::vector<int>>())
      .def_readonly("chunks", &Chunked::chunks)
      .def("__repr__",
           [](const Chunked& chuncked) {
             return absl::StrCat("Chunked(",
                                 absl::StrJoin(chuncked.chunks, ","), ")");
           })
      .def("__eq__", [](const Chunked& self, py::object other) {
        // Comparison with a non-Chunked object is False, not a TypeError.
        if (!py::isinstance<Chunked>(other)) {
          return false;
        }
        return self == py::cast<const Chunked&>(other);
      });

  py::class_<Unstacked> unstacked(pmap_lib, "Unstacked");
  unstacked.def(py::init<int>())
      .def_readonly("size", &Unstacked::size)
      .def("__repr__",
           [](const Unstacked& x) {
             return absl::StrCat("Unstacked(", x.size, ")");
           })
      .def("__eq__", [](const Unstacked& self, py::object other) {
        if (!py::isinstance<Unstacked>(other)) {
          return false;
        }
        return self == py::cast<const Unstacked&>(other);
      });

  py::class_<ShardedAxis> sharded_axis(pmap_lib, "ShardedAxis");
  sharded_axis.def(py::init<int>()).def_readonly("axis", &ShardedAxis::axis);
  sharded_axis
      .def("__repr__",
           [](const ShardedAxis& x) {
             return absl::StrCat("ShardedAxis(axis=", x.axis, ")");
           })
      .def("__eq__", [](const ShardedAxis& self, const ShardedAxis& other) {
        return self == other;
      });

  py::class_<Replicated> replicated(pmap_lib, "Replicated");
  replicated.def(py::init<int>())
      .def_readonly("replicas", &Replicated::replicas)
      .def("__repr__",
           [](const Replicated& x) {
             return absl::StrCat("Replicated(replicas=", x.replicas, ")");
           })
      .def("__eq__", [](const Replicated& self, const Replicated& other) {
        return self == other;
      });

  // ShardingSpec exposes its sharding/mesh_mapping sequences as Python
  // tuples (via SpanToTuple) so they are immutable from Python.
  py::class_<ShardingSpec> sharding_spec(pmap_lib, "ShardingSpec");
  sharding_spec
      .def(py::init<py::iterable, py::iterable>(), py::arg("sharding"),
           py::arg("mesh_mapping"))
      .def_property_readonly(
          "sharding",
          [](const ShardingSpec& self) {
            return xla::SpanToTuple(absl::MakeConstSpan(self.GetSharding()));
          })
      .def_property_readonly(
          "mesh_mapping",
          [](const ShardingSpec& self) {
            return xla::SpanToTuple(absl::MakeConstSpan(self.GetMeshMapping()));
          })
      .def("__eq__", [](const ShardingSpec& self,
                        const ShardingSpec& other) { return self == other; })
      .def("__hash__", [](const ShardingSpec& self) {
        const size_t hash = absl::HashOf(self);
        return py::int_(hash);
      });

  TF_CHECK_OK(ShardedDeviceArray::RegisterTypes(pmap_lib));

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  py::object cfun;
  {
    py::str name = py::str("PmapFunction");
    py::str qualname = py::str("PmapFunction");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    CHECK(heap_type) << "Unable to create heap type object";
    // The heap type takes ownership of the name/qualname references.
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "PmapFunction";
    type->tp_basicsize = sizeof(JaxPmapFunctionObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC;
    type->tp_new = JaxPmapFunction_tp_new;
    type->tp_dealloc = JaxPmapFunction_tp_dealloc;
    // Instances carry a __dict__ and support weak references; the offsets
    // point into JaxPmapFunctionObject.
    type->tp_dictoffset = offsetof(JaxPmapFunctionObject, dict);
    type->tp_traverse = JaxPmapFunction_tp_traverse;
    type->tp_clear = JaxPmapFunction_tp_clear;
    type->tp_weaklistoffset = offsetof(JaxPmapFunctionObject, weakrefs);
    type->tp_getset = JaxPmapFunction_tp_getset;
    // tp_descr_get lets instances bind as methods; tp_call makes them
    // callable.
    type->tp_descr_get = JaxPmapFunction_tp_descr_get;
    type->tp_call = JaxPmapFunction_tp_call;
    CHECK_EQ(PyType_Ready(type), 0);
    JaxPmapFunction_Type = reinterpret_cast<PyObject*>(type);
    cfun = py::reinterpret_borrow<py::object>(JaxPmapFunction_Type);
  }
  py::object cfun_type =
      py::reinterpret_borrow<py::object>(JaxPmapFunction_Type);

  // Add PmapFunction to the xla_extension module so it can be pickled.
  m.attr("PmapFunction") = cfun_type;

  cfun.attr("__signature__") =
      property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        return fun->PythonSignature();
      });
  // Required by `post_hook`.
  cfun.attr("_cache_miss") =
      property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        return fun->cache_miss();
      });
  // Pickle support: __getstate__ captures the wrapped Python callables plus
  // static_argnums; __setstate__ re-initializes the C++ state in place after
  // checking the pickle version.
  cfun.attr("__getstate__") = py::cpp_function(
      [](const PmapFunction::object& self) {
        PmapFunction* fn = self.func();
        py::dict pickle;
        pickle["version"] = kPmapFunctionPickleVersion;
        pickle["fun"] = fn->fun();
        pickle["cache_miss"] = fn->cache_miss();
        pickle["static_argnums"] = fn->static_argnums();
        pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback();
        return pickle;
      },
      py::is_method(cfun_type));
  cfun.attr("__setstate__") = py::cpp_function(
      [](PmapFunction::object& self, const py::dict& pickle) {
        int version = py::cast<int>(pickle["version"]);
        if (version != kPmapFunctionPickleVersion) {
          throw std::invalid_argument(absl::StrFormat(
              "Invalid PmapFunction pickle version, got %d, expected %d. "
              "Pickling/Unpickling jitted functions using different JAX "
              "versions is not supported.",
              version, kPmapFunctionPickleVersion));
        }
        py::function fun = py::cast<py::function>(pickle["fun"]);
        py::function cache_miss = py::cast<py::function>(pickle["cache_miss"]);
        std::vector<int> static_argnums =
            py::cast<std::vector<int>>(pickle["static_argnums"]);
        py::function python_shard_arg_fallback =
            py::cast<py::function>(pickle["python_shard_arg_fallback"]);

        InitializePmapFunction(
            reinterpret_cast<JaxPmapFunctionObject*>(self.ptr()),
            std::move(fun), std::move(cache_miss), std::move(static_argnums),
            std::move(python_shard_arg_fallback));
      },
      py::is_method(cfun_type));

  // This is only for testing/debugging purposes.
  cfun.attr("_cache_size") =
      property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        return py::cast<int>(fun->cache_size());
      });

  cfun.attr("_debug_cache_keys") = py::cpp_function(
      [](py::handle self) -> xla::StatusOr<std::string> {
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        return fun->DebugCacheKeys();
      },
      py::is_method(cfun_type));

  // Accepts _arbitrary_ arguments for a pmapped function and returns the
  // corresponding signatures that are used as cache keys. No-op.
  //
  // This function allows to pass partial args, which is especially useful when
  // the full list of arguments is too long and results in enormous signatures.
  // For example, this function can be called multiple times as
  // > fn._debug_compute_cache_key(arg[0])
  // > fn._debug_compute_cache_key(arg[1])
  // > fn._debug_compute_cache_key(arg[-3:-1])
  // ...
  cfun.attr("_debug_compute_cache_key") = py::cpp_function(
      [](const PmapFunction::object& self, const py::args& args,
         const py::kwargs& kwargs) -> xla::StatusOr<std::string> {
        ParsedArgumentsAsBuffers arguments;
        TF_ASSIGN_OR_RETURN(PmapFunction * fun, AsPmapFunction(self));
        TF_RETURN_IF_ERROR(ParseArguments(args, kwargs, fun->static_argnums(),
                                          /*static_argnames=*/{}, arguments));
        TF_RETURN_IF_ERROR(fun->UpdateArgsSignature(args, kwargs, arguments));
        return arguments.signature.DebugString();
      },
      py::is_method(cfun_type));

  // Factory used by jax to build a PmapFunction wrapper around the Python
  // callables.
  pmap_lib.def("pmap",
               [](py::function fun, py::function cache_miss,
                  std::vector<int> static_argnums,
                  py::function python_shard_arg_fallback) -> py::object {
                 return MakePmapFunction(std::move(fun), std::move(cache_miss),
                                         std::move(static_argnums),
                                         std::move(python_shard_arg_fallback));
               });
}
958 
959 }  // namespace jax
960