xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/jax_jit.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This files implements the `jax.jit` dispatch and just-in-time feature.
17 //
18 // In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward
19 // based on passed arguments dtypes/shapes/identity) the execution to a
20 // just-in-time compiled XLA Executable. All of that is done in C++ for
21 // performance reasons.
22 //
23 // This file contains the utilities to:
24 // (a) inspect arguments and describe their structure, dtype/shapes, etc.
25 // (b) keep a mapping from function signatures to compiled XLA Executables.
26 
27 #include "tensorflow/compiler/xla/python/jax_jit.h"
28 
29 #include <Python.h>
30 
31 #include <algorithm>
32 #include <exception>
33 #include <memory>
34 #include <optional>
35 #include <stdexcept>
36 #include <string>
37 #include <utility>
38 
39 #include "absl/container/flat_hash_map.h"
40 #include "absl/container/inlined_vector.h"
41 #include "absl/strings/str_cat.h"
42 #include "absl/strings/str_format.h"
43 #include "absl/synchronization/notification.h"
44 #include "absl/types/span.h"
45 #include "pybind11/cast.h"
46 #include "pybind11/numpy.h"
47 #include "pybind11/pybind11.h"
48 #include "pybind11/pytypes.h"
49 #include "tensorflow/compiler/xla/pjrt/lru_cache.h"
50 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
51 #include "tensorflow/compiler/xla/python/exceptions.h"
52 #include "tensorflow/compiler/xla/python/py_buffer.h"
53 #include "tensorflow/compiler/xla/python/py_executable.h"
54 #include "tensorflow/compiler/xla/python/py_values.h"
55 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
56 #include "tensorflow/compiler/xla/python/python_utils.h"
57 #include "tensorflow/compiler/xla/python/pytree.h"
58 #include "tensorflow/compiler/xla/python/types.h"
59 #include "tensorflow/compiler/xla/shape_util.h"
60 #include "tensorflow/compiler/xla/statusor.h"
61 #include "tensorflow/compiler/xla/types.h"
62 #include "tensorflow/compiler/xla/util.h"
63 #include "tensorflow/compiler/xla/xla_data.pb.h"
64 #include "tensorflow/core/platform/status.h"
65 #include "tensorflow/core/profiler/lib/traceme.h"
66 
67 namespace jax {
68 
69 namespace py = pybind11;
70 
71 // TODO(phawkins): Add support for Tracers.
72 // TODO(jblespiau): Use absl Status.
73 
namespace {

// Process-wide jit configuration. Protected by the GIL.
JitState& global_state = *new JitState();

// Per-thread jit configuration; where a field is set it overrides
// `global_state` (see the Get* accessors below).
// TODO(phawkins): Google style guide forbids thread-local values with
// non-trivial destructors.
ABSL_CONST_INIT thread_local JitState thread_local_state;  // NOLINT

}  // namespace

// `thread_local_state.extra_jit_context` is set from Python. It's done when
// loading the Python jax modules on the main-thread. For other threads, we
// need to initialize the field the first time we access `thread_local_state`.
// Deliberately leaked so it is never destroyed after Python shuts down.
py::object& initialize_local_state = *new py::object();
89 
// Returns the process-wide jit configuration. Caller must hold the GIL.
JitState& GetGlobalState() { return global_state; }
GetLocalState()91 JitState& GetLocalState() {
92   if (thread_local_state.extra_jit_context == std::nullopt) {
93     CHECK(initialize_local_state.ptr() != nullptr);
94     initialize_local_state();
95   }
96   return thread_local_state;
97 }
98 
GetDisableJit()99 bool GetDisableJit() {
100   CHECK(global_state.disable_jit.has_value());
101   return thread_local_state.disable_jit.value_or(*global_state.disable_jit);
102 }
103 
GetEnableX64()104 bool GetEnableX64() {
105   CHECK(global_state.enable_x64.has_value());
106   return thread_local_state.enable_x64.value_or(*global_state.enable_x64);
107 }
108 
GetDefaultDevice()109 std::optional<py::object> GetDefaultDevice() {
110   return thread_local_state.default_device.has_value()
111              ? thread_local_state.default_device
112              : global_state.default_device;
113 }
114 
GetPostHook()115 std::optional<pybind11::function> GetPostHook() {
116   return thread_local_state.post_hook.has_value() ? thread_local_state.post_hook
117                                                   : global_state.post_hook;
118 }
119 
OptionalDebugString(const std::optional<py::object> optional)120 static std::string OptionalDebugString(
121     const std::optional<py::object> optional) {
122   if (optional.has_value()) {
123     return py::cast<std::string>(py::str(optional.value()));
124   } else {
125     return "None";
126   }
127 }
128 
// Human-readable dump of every field that participates in the cache key; used
// for debugging and cache-miss explanations.
std::string CallSignature::DebugString() const {
  // Element formatters handed to absl::StrJoin below.
  auto py_object_formatter = [](std::string* out, const py::object& o) {
    out->append(py::cast<std::string>(py::str(o)));
  };
  auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) {
    out->append(d.ToString());
  };
  auto signature_formatter = [](std::string* out,
                                const xla::PyArgSignature& s) {
    out->append(s.DebugString());
  };
  return absl::StrFormat(
      "static args (positional + keyword): %s\nstatic arg keyword names: %s\n"
      "dynamic arg signatures (positional + keyword): %s\n"
      "dynamic arg keyword names: %s\ndynamic arg treedefs: %s\n"
      "device: %s\n"
      "jax_enable_x64: %d\n"
      "global_extra_jit_context: %s\n"
      "thread_local_extra_jit_context: %s\n",
      absl::StrJoin(static_args, ",", py_object_formatter),
      absl::StrJoin(static_arg_names, ",", py_object_formatter),
      absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter),
      absl::StrJoin(dynamic_arg_names, ",", py_object_formatter),
      absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter),  // new line
      device != nullptr ? device->DebugString() : "nullptr", jax_enable_x64,
      OptionalDebugString(global_extra_jit_context),
      OptionalDebugString(thread_local_extra_jit_context));
}
157 
// Equality for the cache key. Must be kept consistent with AbslHashValue
// below: two signatures that compare equal must hash equal.
bool CallSignature::operator==(const CallSignature& other) const {
  return std::tie(dynamic_arg_treedefs, dynamic_arg_names,
                  dynamic_arg_signatures, device, jax_enable_x64,
                  static_arg_names) ==
             std::tie(other.dynamic_arg_treedefs, other.dynamic_arg_names,
                      other.dynamic_arg_signatures, other.device,
                      other.jax_enable_x64, other.static_arg_names) &&
         // `==` on py:objects is the Python `is`. We need equal.
         std::equal(
             static_args.begin(), static_args.end(), other.static_args.begin(),
             other.static_args.end(),
             [this](const py::object& a, const py::object& b) {
               try {
                 // Same Python type AND Python `__eq__` equality. A Python
                 // error raised during comparison is surfaced as
                 // std::invalid_argument with a descriptive message.
                 return py::type::handle_of(a) == py::type::handle_of(b) &&
                        a.equal(b);
               } catch (const py::error_already_set& e) {
                 throw std::invalid_argument(absl::StrCat(
                     "static arguments should be comparable using __eq__."
                     "The following error was raised during a call to '",
                     function_name, "' when comparing two objects of types ",
                     py::cast<std::string>(py::str(py::type::of(a))), " and ",
                     py::cast<std::string>(py::str(py::type::of(b))),
                     ". The error was:\n", e.what()));
               }
             }) &&
         // extra_jit_context fields: equal when both are absent, or both
         // present and Python-equal.
         (global_extra_jit_context.has_value() ==
          other.global_extra_jit_context.has_value()) &&
         (!global_extra_jit_context.has_value() ||
          global_extra_jit_context->equal(*other.global_extra_jit_context)) &&
         (thread_local_extra_jit_context.has_value() ==
          other.thread_local_extra_jit_context.has_value()) &&
         (!thread_local_extra_jit_context.has_value() ||
          thread_local_extra_jit_context->equal(
              *other.thread_local_extra_jit_context));
}
193 
// Hash for the cache key; must agree with CallSignature::operator== above.
template <typename H>
H AbslHashValue(H h, const CallSignature& s) {
  h = H::combine(std::move(h), s.dynamic_arg_treedefs,
                 s.dynamic_arg_signatures);
  // Names are hashed by pointer identity; this is cheap and valid because
  // kwarg names are interned (see ParseArguments / the CompiledFunction
  // constructor).
  for (const auto& name : s.dynamic_arg_names) {
    h = H::combine(std::move(h), name.ptr());
  }
  h = H::combine(std::move(h), s.dynamic_arg_names.size());
  for (const auto& static_arg : s.static_args) {
    ssize_t hash;
    try {
      // Python-level hash(); matches the Python __eq__ used in operator==.
      hash = py::hash(static_arg);
    } catch (const py::error_already_set& e) {
      // Only TypeError means "unhashable"; any other Python error is
      // propagated unchanged.
      if (!e.matches(PyExc_TypeError)) throw;
      throw std::invalid_argument(absl::StrCat(
          "Non-hashable static arguments are not supported. An error occurred "
          "during a call to '",
          s.function_name, "' while trying to hash an object of type ",
          py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
          py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
          e.what(), "\n"));
    }
    h = H::combine(std::move(h), hash);
  }
  h = H::combine(std::move(h), s.static_args.size());
  for (const auto& name : s.static_arg_names) {
    h = H::combine(std::move(h), name.ptr());
  }
  h = H::combine(std::move(h), s.static_arg_names.size());
  h = H::combine(std::move(h), s.device, s.jax_enable_x64);

  // We do not hash the extra_jit_context fields since calling Python hash
  // functions is expensive (~300ns) and we don't expect a large number of
  // different contexts.
  return h;
}
230 
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
// Returns OK on success; the treedefs/names/static_args of
// `arguments.signature` and `arguments.flat_dynamic_args` are populated.
xla::Status ParseArguments(py::handle args,
                           const std::optional<py::kwargs>& py_kwargs,
                           absl::Span<int const> static_argnums,
                           absl::Span<py::str const> static_argnames,
                           ParsedArgumentsAsBuffers& arguments) {
  tensorflow::profiler::TraceMe traceme("ParseArguments");
  int num_args = PyTuple_GET_SIZE(args.ptr());
  int num_kwargs = py_kwargs ? py_kwargs->size() : 0;

  arguments.flat_dynamic_args.reserve(num_args + num_kwargs);
  if (static_argnums.empty()) {
    // Fast path: every positional argument is dynamic.
    arguments.signature.dynamic_arg_treedefs.resize(num_args);

    // Positional arguments.
    for (int i = 0; i < num_args; ++i) {
      xla::PyTreeDef& pytree_def = arguments.signature.dynamic_arg_treedefs[i];
      pytree_def.FlattenInto(PyTuple_GET_ITEM(args.ptr(), i),
                             arguments.flat_dynamic_args);
    }
  } else {
    arguments.signature.dynamic_arg_treedefs.reserve(num_args);

    // Positional arguments.
    for (int i = 0; i < num_args; ++i) {
      if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
          static_argnums.end()) {
        arguments.signature.dynamic_arg_treedefs.emplace_back();
        xla::PyTreeDef& pytree_def =
            arguments.signature.dynamic_arg_treedefs.back();
        pytree_def.FlattenInto(PyTuple_GET_ITEM(args.ptr(), i),
                               arguments.flat_dynamic_args);
      } else {
        // Static argument: kept whole (borrowed reference) as part of the
        // cache key; it is not flattened.
        arguments.signature.static_args.emplace_back(
            py::reinterpret_borrow<py::object>(
                PyTuple_GET_ITEM(args.ptr(), i)));
      }
    }
  }

  // Keyword arguments.
  if (py_kwargs) {
    std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs->begin(),
                                                          py_kwargs->end());
    // We first intern the keys, then sort them (by name, as in the Python path)
    // (see also xla::PyTreeDef::Flatten) and then create the signatures.
    // TODO(jblespiau): We should be able to sort the keys by interned-key
    // pointers, but this requires the Python compilation to do the same.
    for (int i = 0; i < num_kwargs; ++i) {
      // Intern the key if not already interned.
      // The inc_ref() balances the reinterpret_steal<> below, which takes
      // ownership of the (possibly re-interned) key handle.
      kwargs[i].first.inc_ref();
      if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
        PyUnicode_InternInPlace(&kwargs[i].first.ptr());
      }
    }

    std::sort(kwargs.begin(), kwargs.end(),
              [](const std::pair<py::handle, py::handle>& a,
                 const std::pair<py::handle, py::handle>& b) {
                return a.first < b.first;
              });
    // Pointer comparison suffices here: the keys were interned above, and the
    // entries of `static_argnames` are expected to be interned by the caller.
    auto kwarg_is_static = [&](py::handle name) {
      for (const auto& kw : static_argnames) {
        if (kw.ptr() == name.ptr()) return true;
      }
      return false;
    };

    arguments.signature.dynamic_arg_names.reserve(num_kwargs);
    for (int i = 0; i < num_kwargs; ++i) {
      if (kwarg_is_static(kwargs[i].first)) {
        arguments.signature.static_arg_names.push_back(
            py::reinterpret_steal<py::object>(kwargs[i].first));
        arguments.signature.static_args.push_back(
            py::reinterpret_borrow<py::object>(kwargs[i].second));
      } else {
        arguments.signature.dynamic_arg_names.push_back(
            py::reinterpret_steal<py::object>(kwargs[i].first));
        arguments.signature.dynamic_arg_treedefs.emplace_back();
        xla::PyTreeDef& pytree_def =
            arguments.signature.dynamic_arg_treedefs.back();
        pytree_def.FlattenInto(kwargs[i].second, arguments.flat_dynamic_args);
      }
    }
  }
  return ::tensorflow::OkStatus();
}
319 
namespace {

// One compiled specialization of a jitted function, keyed by CallSignature.
// Elements of CacheEntry are protected by the GIL.
struct CacheEntry {
  // Ensures a single thread performs the compilation for a given executable.
  //
  // The first thread (holding the GIL) will create the CacheEntry associated to
  // a signature and fill it. Other threads will wait for the notification.
  // If an error occurred during the compilation, `fall_back_to_python` is set
  // to `true`, and other threads will fail with the same error.
  absl::Notification compilation_complete;

  std::shared_ptr<xla::PyExecutable> executable;
  // Tree structure used to re-assemble flat outputs into the caller's pytree.
  xla::PyTreeDef out_pytree_def;
  // We use Python types within the vector because this is what we will be
  // returning to Python. No need to convert back and forth.
  // We need py::object to maintain the objects alive.
  std::vector<py::object> out_avals;
  std::vector<bool> out_weak_types;

  // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args`
  // in CompiledFunction::Call before calling into compiled computation.
  std::optional<std::vector<bool>> kept_var_bitvec;
  std::optional<xla::ClientAndPtr<xla::PjRtDevice>> sticky_device;

  // Fallback to Python happens:
  // - for trivial computations
  // - when running a jax(pmap)
  // - after a compilation error, for threads that did not compile it the first
  //   time
  bool fall_back_to_python = false;

  // Python objects (notably in the cache key) that must remain alive as long
  // as the cache entry does. Currently this is the `key` values in the kwarg
  // entries in the cache key.
  std::vector<py::object> keepalive;
};
357 
// A CompiledFunctionCache represents a cache of compiled functions that can be
// shared between one or more CompiledFunction objects. It serves two goals:
// - reduce the number of lru caches (hash map) across multiple JITs.
// - make the cache global to increase cache hits (e.g. calling jit(f)(3) twice)
//   keeping entries alive as long as the underlying function f is alive.
// Assume the cache is protected by the GIL.
class CompiledFunctionCache {
 public:
  static constexpr int kDefaultCapacity = 4096;
  explicit CompiledFunctionCache(int capacity);

  // Cache entries are shared_ptr<>s because it's possible the cache entry
  // might be evicted before we finish tracing/compiling.
  typedef xla::LRUCache<CallSignature, std::shared_ptr<CacheEntry>> Cache;

  // We include as part of the cache key `donate_argnums` (and any other fields
  // that aren't subsumed by the CallSignature we compute for each call).
  std::shared_ptr<Cache> Lookup(py::handle function,
                                absl::Span<const int> donate_argnums);

  int Size() const { return lru_list_.Size(); }
  int Capacity() const { return lru_list_.Capacity(); }
  void Clear() { lru_list_.Clear(); }

 private:
  // Map key: the jitted Python function plus the jit options not captured by
  // CallSignature.
  struct Key {
    py::handle function;  // Does not hold a reference.

    // Other fields that are part of the arguments to `jit`, but are not
    // otherwise part of CallSignature.
    std::vector<int> donate_argnums;

    bool operator==(const Key& other) const {
      return std::tie(function, donate_argnums) ==
             std::tie(other.function, other.donate_argnums);
    }
  };
  template <typename H>
  friend H AbslHashValue(H h, const Key& key) {
    // The function is hashed by its PyObject pointer.
    h = H::combine(std::move(h), key.function.ptr());
    h = H::combine_contiguous(std::move(h), key.donate_argnums.data(),
                              key.donate_argnums.size());
    return h;
  }

  struct Value {
    explicit Value(std::shared_ptr<Cache> cache) : cache(std::move(cache)) {}
    std::shared_ptr<Cache> cache;

    // A weak reference to the key function. We use the weak reference to
    // register a callback that is triggered when the key function is destroyed.
    // We use a weak pointer because we want to allow caching across multiple
    // calls to `jax.jit(f)` if `f` remains alive, but we do not want the cache
    // to keep `f` alive if all other references are dropped.
    py::weakref weakref;
  };

  // All per-function caches share a single LRU list, bounding the total
  // number of live entries to `capacity`.
  Cache::LRUList lru_list_;
  absl::flat_hash_map<Key, std::unique_ptr<Value>> functions_;
};
418 
// Constructs a cache whose shared LRU list holds at most `capacity` entries.
CompiledFunctionCache::CompiledFunctionCache(int capacity)
    : lru_list_(capacity) {}
421 
Lookup(py::handle function,absl::Span<const int> donate_argnums)422 std::shared_ptr<CompiledFunctionCache::Cache> CompiledFunctionCache::Lookup(
423     py::handle function, absl::Span<const int> donate_argnums) {
424   Key key;
425   key.function = function;
426   key.donate_argnums =
427       std::vector<int>(donate_argnums.begin(), donate_argnums.end());
428   auto insert = functions_.emplace(key, nullptr);
429   std::shared_ptr<Cache> cache = std::make_shared<Cache>(&lru_list_);
430   if (insert.second) {
431     py::cpp_function callback([this, key{std::move(key)}](py::handle weakref) {
432       functions_.erase(key);
433     });
434     PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr());
435     if (weakref) {
436       std::unique_ptr<Value>& entry = insert.first->second;
437       entry = std::make_unique<Value>(cache);
438       entry->weakref = py::reinterpret_steal<py::weakref>(weakref);
439     } else {
440       PyErr_Clear();
441       // `function` is not weak-referenceable. Don't bother adding it to the
442       // shared cache in that case; the `jit` object will hold the only shared
443       // reference to the cache entry.
444       functions_.erase(insert.first);
445     }
446   }
447   return cache;
448 }
449 
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
class CompiledFunction {
 public:
  CompiledFunction(py::function fun, py::function cache_miss,
                   py::function get_device, bool has_explicit_device,
                   std::vector<int> static_argnums,
                   std::vector<py::str> static_argnames,
                   std::vector<int> donate_argnums,
                   std::shared_ptr<CompiledFunctionCache> cache);
  ~CompiledFunction();

  // pybind11::object typed subclass for CompiledFunction objects.
  class pyobject : public py::object {
   public:
    PYBIND11_OBJECT(pyobject,  // NOLINT
                    py::object, CompiledFunction::IsCompiledFunction);
    pyobject() = default;
    // Returns the C++ object backing this Python handle (unchecked).
    CompiledFunction* func() const {
      return CompiledFunction::AsCompiledFunctionUnchecked(*this);
    }
  };
  // Alias as ::object; outside the scope above we won't confuse pybind11's
  // macros.
  using object = pyobject;

  // Returns true if `h` is a CompiledFunction.
  static bool IsCompiledFunction(py::handle handle);
  // Converts `handle` to a CompiledFunction*. Does not do any checking.
  static CompiledFunction* AsCompiledFunctionUnchecked(py::handle handle);

  // This function will:
  // (a) flatten the inputs using pytree
  // (b) get buffer objects from the arguments
  // (c) call the executable
  // (d) construct `DeviceArray` objects from the outputs
  // (e) reconstruct the `PyTree`.
  xla::StatusOr<py::object> Call(py::handle args,
                                 std::optional<py::kwargs> kwargs);

  // This allows `inspect.signature(cpp_jitted_f)` from Python.
  py::object PythonSignature() {
    // The inspect module is imported once and leaked (process lifetime).
    static const auto* inspect = new py::module(py::module::import("inspect"));
    return inspect->attr("signature")(fun_);
  }

  int cache_size() const { return executables_->Size(); }
  void ClearCache() {
    // Setting `default_device_` to nullptr forces Call() to retrieve the
    // device.
    default_device_ = nullptr;
    executables_->Clear();
  }

  const py::function& fun() const { return fun_; }
  const py::function& cache_miss() const { return cache_miss_; }
  const py::function& get_device() const { return get_device_; }
  bool has_explicit_device() const { return has_explicit_device_; }
  const std::vector<int>& static_argnums() const { return static_argnums_; }
  const std::vector<py::str>& static_argnames() const {
    return static_argnames_;
  }
  const std::vector<int>& donate_argnums() const { return donate_argnums_; }
  const std::shared_ptr<CompiledFunctionCache>& cache() const { return cache_; }

  // Helper function used by the tp_clear GC method.
  void ClearPythonReferences() {
    py::function fun, cache_miss, get_device;
    // Swap values for nulls before they are destroyed. See the Python
    // Py_CLEAR() documentation for a discussion of this topic.
    std::swap(fun_, fun);
    std::swap(cache_miss_, cache_miss);
    std::swap(get_device_, get_device);
  }

  py::handle AsPyHandle();
  const std::string& function_name() const { return function_name_; }

 private:
  // Attempts to populate default_device_. May release the GIL; is
  // reentrant-safe.
  void TryToPopulateDefaultDevice();

  void PopulateCacheEntry(CacheEntry* entry, const CallSignature& signature,
                          const py::tuple& out_and_fastpath_data);
  bool always_fallback_to_python_ = false;

  py::function fun_;  // The Python function to jit.
  std::string function_name_;

  // See JAX _cpp_jit in api.py for documentation.
  py::function cache_miss_;

  // We need to know the static arguments to remove them from the arguments
  // passed to the underlying PyExecutable. In sorted order.
  std::vector<int> static_argnums_;
  // Keyword arguments, interned.
  std::vector<py::str> static_argnames_;
  std::vector<int> donate_argnums_;

  // Whether this function has an explicit device set by either the `device` or
  // `backend` arguments to jit.
  bool has_explicit_device_;

  // A function taking no arguments and returning the default device and whether
  // jax.jit has been committed to it.
  py::function get_device_;

  // Keeps the shared LRU cache alive as long as the CompiledFunction is alive.
  std::shared_ptr<CompiledFunctionCache> cache_;

  // The part of cache_ specific to this CompiledFunction.
  std::shared_ptr<CompiledFunctionCache::Cache> executables_;

  // The logic is the following:
  // - if `device` or `backend` are not specified to `jax.jit`, we will use
  //   the input sticky buffer device, or `default_device_` if there is no
  //   such sticky buffer.
  // - When one of `device` or `backend` is specified, this will determine
  //   the `default_device_` which will be used as the targeted device. In
  //   which case, we will always copy input buffers to this device.
  // These fields are protected by the GIL.
  xla::PjRtDevice* default_device_ = nullptr;
  bool is_committed_;
};
576 
577 // This class keeps references to all CompiledFunctions. This class is
578 // thread-compatible.
579 class CompiledFunctionStore {
580  public:
Insert(CompiledFunction * function)581   void Insert(CompiledFunction* function) {
582     compiled_functions_.insert(function);
583   }
584 
Erase(CompiledFunction * function)585   void Erase(CompiledFunction* function) {
586     compiled_functions_.erase(function);
587   }
588 
ClearFunctionCache()589   void ClearFunctionCache() {
590     for (auto* function : compiled_functions_) {
591       function->ClearCache();
592     }
593   }
594 
595  private:
596   absl::flat_hash_set<CompiledFunction*> compiled_functions_;
597 };
598 
599 // Protected by GIL.
GetGlobalCompiledFunctionStore()600 CompiledFunctionStore& GetGlobalCompiledFunctionStore() {
601   static auto* const store = new CompiledFunctionStore();
602   return *store;
603 }
604 
CompiledFunction(py::function fun,py::function cache_miss,py::function get_device,bool has_explicit_device,std::vector<int> static_argnums,std::vector<py::str> static_argnames,std::vector<int> donate_argnums,std::shared_ptr<CompiledFunctionCache> cache)605 CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss,
606                                    py::function get_device,
607                                    bool has_explicit_device,
608                                    std::vector<int> static_argnums,
609                                    std::vector<py::str> static_argnames,
610                                    std::vector<int> donate_argnums,
611                                    std::shared_ptr<CompiledFunctionCache> cache)
612     : fun_(std::move(fun)),
613       cache_miss_(std::move(cache_miss)),
614       static_argnums_(std::move(static_argnums)),
615       static_argnames_(std::move(static_argnames)),
616       donate_argnums_(donate_argnums),
617       has_explicit_device_(std::move(has_explicit_device)),
618       get_device_(std::move(get_device)),
619       cache_(std::move(cache)) {
620   std::sort(static_argnums_.begin(), static_argnums_.end());
621   for (py::str& s : static_argnames) {
622     PyUnicode_InternInPlace(&s.ptr());
623   }
624   executables_ = cache_->Lookup(fun_, donate_argnums);
625   function_name_ = py::str(py::getattr(fun_, "__name__", fun));
626 
627   GetGlobalCompiledFunctionStore().Insert(this);
628 }
629 
// Unregisters this function from the global store on destruction.
CompiledFunction::~CompiledFunction() {
  GetGlobalCompiledFunctionStore().Erase(this);
}
633 
// Returns the "sticky" device of `arg` if it has one, nullptr otherwise.
// Only PyBuffer and jax's Python `_DeviceArray` are recognized; any other
// type yields nullptr.
static xla::StatusOr<xla::PjRtDevice*> GetJitArgumentStickyDevice(
    py::handle arg) {
  struct PythonTypes {
    py::object device_array;
  };
  // Looked up once from the jax Python module and leaked (process lifetime).
  static const auto& types = *[]() -> PythonTypes* {
    py::module xla_module(py::module::import("jax.interpreters.xla"));
    py::object device_array;
    if (py::hasattr(xla_module, "_DeviceArray")) {
      device_array = xla_module.attr("_DeviceArray");
    }
    return new PythonTypes{device_array};
  }();

  // We specifically only deal with DeviceArray (not ShardedDeviceArray).
  // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
  if (arg.get_type().ptr() == xla::PyBuffer::type()) {
    xla::PyBuffer* buffer = xla::PyBuffer::AsPyBufferUnchecked(arg);
    if (!buffer->sticky_device()) {
      return nullptr;
    }
    return buffer->sticky_device();
  }

  if (arg.get_type().ptr() == types.device_array.ptr()) {
    // A `None` `_device` means the array is not committed to a device.
    if (arg.attr("_device").is_none()) {
      return nullptr;
    }
    try {
      // This can fail, e.g. for cloud TPU 2VM buffers.
      TF_ASSIGN_OR_RETURN(xla::PyBuffer * buffer,
                          xla::PyBuffer::AsPyBuffer(arg.attr("device_buffer")));
      return buffer->buffer()->device();
    } catch (const py::cast_error& e) {
      return xla::InvalidArgument(
          "%s", absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
                             "`device_buffer` field is of type ",
                             py::cast<std::string>(
                                 arg.attr("device_buffer").get_type().str()),
                             " while a `PyBuffer` was expected."));
    }
  }

  return nullptr;
}
680 
// Compute signature for arguments.
//
// Fills `arguments.signature.device` and
// `arguments.signature.dynamic_arg_signatures`.
// Returns `Status::OK()` on success. Returning an error should lead to
// calling the Python fallback.
xla::Status ComputeSignature(bool jax_enable_x64,
                             xla::PjRtDevice* default_device, bool is_committed,
                             ParsedArgumentsAsBuffers& arguments) {
  tensorflow::profiler::TraceMe traceme("ComputeSignature");

  int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
  // When the jitted function is not committed, we first check whether any
  // sticky `DeviceArray` is present and on which device they live. See also:
  // https://github.com/google/jax/pull/1884
  // https://github.com/google/jax/pull/1916 for the rationale why the
  // computation follows the data locality.
  // It's also similar to PyTorch's behavior.
  xla::PjRtDevice* data_device = nullptr;
  if (is_committed) {
    data_device = default_device;
  } else {
    for (int i = 0; i < num_flat_dynamic_args; ++i) {
      TF_ASSIGN_OR_RETURN(
          xla::PjRtDevice * device,
          GetJitArgumentStickyDevice(arguments.flat_dynamic_args[i]));
      if (device) {
        // All sticky arguments must agree on a single device.
        if (data_device && (device != data_device)) {
          throw std::invalid_argument(absl::StrCat(
              "primitive arguments must be colocated on the same device ("
              "C++ jax.jit). Arguments are on devices: ",
              device->DebugString(), " and ", data_device->DebugString()));
        } else {
          data_device = device;
        }
      }
    }
  }
  if (!data_device) {
    // No sticky `DeviceArray` was found; default to `default_device`.
    data_device = default_device;
  }
  CHECK(data_device);
  arguments.signature.device = data_device;

  // Compute a PyArgSignature (dtype/shape/weak-type) for each dynamic arg.
  arguments.signature.dynamic_arg_signatures.reserve(num_flat_dynamic_args);
  for (int i = 0; i < num_flat_dynamic_args; ++i) {
    py::handle arg = arguments.flat_dynamic_args[i];
    TF_ASSIGN_OR_RETURN(auto sig,
                        xla::PyArgSignatureOfValue(arg, jax_enable_x64));
    arguments.signature.dynamic_arg_signatures.push_back(std::move(sig));
  }
  return ::tensorflow::OkStatus();
}
733 
734 // Copy buffers to device, skipping pruned arguments.
735 // Returns `Status::OK()` on success. Returning an error should lead to
736 // calling the Python fallback.
CopyBuffersToDevice(bool jax_enable_x64,const std::optional<std::vector<bool>> & kept_args,ParsedArgumentsAsBuffers & arguments)737 xla::Status CopyBuffersToDevice(
738     bool jax_enable_x64, const std::optional<std::vector<bool>>& kept_args,
739     ParsedArgumentsAsBuffers& arguments) {
740   std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
741   xla::PjRtDevice* data_device = arguments.signature.device;
742 
743   int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
744   xla::DevicePutOptions options;
745   options.squash_64bit_types = !jax_enable_x64;
746   options.allow_zero_copy = true;
747   arg_buffers.reserve(num_flat_dynamic_args);
748   bool input_pruning_enabled = kept_args.has_value();
749   for (int i = 0; i < num_flat_dynamic_args; ++i) {
750     if (input_pruning_enabled && !kept_args.value()[i]) {
751       continue;
752     }
753 
754     py::handle arg = arguments.flat_dynamic_args[i];
755     TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device,
756                         DevicePut(arg, data_device, options));
757 
758     xla::PjRtBuffer* buffer = on_device.buffer;
759     arg_buffers.push_back(buffer);
760     if (on_device.owned_buffer) {
761       arguments.keep_alive.push_back(std::move(on_device.owned_buffer));
762     } else if (on_device.owning_pybuffer) {
763       arguments.keep_alive_objects.push_back(
764           std::move(on_device.owning_pybuffer));
765     }
766   }
767   return ::tensorflow::OkStatus();
768 }
769 
PopulateCacheEntry(CacheEntry * cache_entry,const CallSignature & signature,const py::tuple & out_and_fastpath_data)770 void CompiledFunction::PopulateCacheEntry(
771     CacheEntry* cache_entry, const CallSignature& signature,
772     const py::tuple& out_and_fastpath_data) {
773   CHECK_EQ(out_and_fastpath_data.size(), 2);
774   if (out_and_fastpath_data[1].is_none()) {
775     cache_entry->fall_back_to_python = true;
776     return;
777   }
778 
779   py::tuple executable_handlers_out_tree =
780       py::cast<py::tuple>(out_and_fastpath_data[1]);
781   auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
782       executable_handlers_out_tree.attr("xla_executable"));
783   cache_entry->executable = std::move(executable);
784   int num_devices =
785       cache_entry->executable->pjrt_executable().addressable_devices().size();
786   // The presence of jit(pmap) is detected from Python.
787   CHECK_EQ(num_devices, 1);
788 
789   auto out_tree = py::cast<xla::PyTreeDef>(
790       executable_handlers_out_tree.attr("out_pytree_def"));
791   cache_entry->out_pytree_def = std::move(out_tree);
792 
793   cache_entry->sticky_device =
794       py::cast<std::optional<xla::ClientAndPtr<xla::PjRtDevice>>>(
795           executable_handlers_out_tree.attr("sticky_device"));
796   auto avals = py::cast<py::list>(executable_handlers_out_tree.attr("avals"));
797 
798   cache_entry->out_avals.reserve(avals.size());
799   cache_entry->out_weak_types.reserve(avals.size());
800   for (int i = 0; i < avals.size(); ++i) {
801     py::object shaped_array = py::reinterpret_borrow<py::object>(avals[i]);
802 
803     cache_entry->out_avals.push_back(shaped_array);
804     cache_entry->out_weak_types.push_back(
805         py::cast<bool>(shaped_array.attr("weak_type")));
806   }
807   auto kept_var_bitvec_attr =
808       py::getattr(executable_handlers_out_tree, "kept_var_bitvec", py::none());
809   if (!kept_var_bitvec_attr.is_none()) {
810     auto kept_var_bitvec = py::cast<py::list>(kept_var_bitvec_attr);
811     cache_entry->kept_var_bitvec =
812         std::make_optional<std::vector<bool>>(kept_var_bitvec.size(), false);
813     for (int i = 0; i < kept_var_bitvec.size(); ++i) {
814       cache_entry->kept_var_bitvec.value()[i] =
815           py::cast<bool>(kept_var_bitvec[i]);
816     }
817   }
818 }
819 
// Lazily determines `default_device_` and `is_committed_` by calling back
// into Python (`get_device_`). On failure (backend init error, or a device
// representation we cannot cast), sets `always_fallback_to_python_` so all
// subsequent calls take the Python path.
void CompiledFunction::TryToPopulateDefaultDevice() {
  // The following line calls Python and may release the GIL.
  py::object device_and_is_committed;
  try {
    device_and_is_committed = get_device_();
  } catch (py::error_already_set& e) {
    // Backend or device initialization failed. Handle this in Python.
    always_fallback_to_python_ = true;
    return;
  }
  // If the GIL was released by the call to get_device_, another thread may
  // have filled in default_device_.
  if (!default_device_) {
    try {
      auto default_pydevice = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
          device_and_is_committed.attr("default_device"));
      // Set is_committed_ before default_device_: default_device_ doubles as
      // the "initialized" flag checked above.
      is_committed_ =
          py::cast<bool>(device_and_is_committed.attr("committed_to_device"));
      default_device_ = default_pydevice.contents;
    } catch (const py::cast_error& e) {
      // Pathways, Cloud TPU 2VM, and UPTC runtime.
      always_fallback_to_python_ = true;
    }
  }
}
845 
// Fast-path dispatch for a jitted call.
//
// Pipeline: pick a device -> parse/flatten arguments -> compute the call
// signature -> look up (or compile via the Python `cache_miss_`) the cached
// executable -> transfer argument buffers -> execute -> rebuild the output
// pytree. Any stage the C++ path cannot handle returns the first element of
// the `cache_miss_` tuple, i.e. the result computed entirely in Python.
xla::StatusOr<py::object> CompiledFunction::Call(
    py::handle args, std::optional<py::kwargs> kwargs) {
  VLOG(3) << "Calling CompiledFunction " << function_name_;

  // Make sure we trigger a garbage collection on JIT function calls. Otherwise
  // code like
  // f = jit(...)
  // while True:
  //   f(x)
  // may never free temporary buffers for copies of arguments.
  xla::GlobalPyRefManager()->MaybeCollectGarbage();

  auto& tls = thread_local_state;
  // With jit disabled, call the wrapped Python function directly.
  if (GetDisableJit()) {
    return fun_(*py::reinterpret_borrow<py::args>(args),
                **kwargs.value_or(py::kwargs()));
  }
  // Permanent fallback (e.g. device initialization failed earlier): the
  // cache-miss tuple's element 0 is the Python-computed result.
  if (always_fallback_to_python_) {
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }

  xla::PjRtDevice* device = nullptr;
  // Whether `device` should override an input with a sticky device.
  bool is_committed;
  // First preference: a config-level default device (only when the jit had no
  // explicit `device=`/`backend=`).
  if (!has_explicit_device_ && GetDefaultDevice().has_value()) {
    xla::ClientAndPtr<xla::PjRtDevice> pjrt_device_ptr;
    bool cast_success = true;
    try {
      pjrt_device_ptr =
          GetDefaultDevice()->cast<xla::ClientAndPtr<xla::PjRtDevice>>();
    } catch (py::cast_error& e) {
      // We assume GetDefaultDevice() returned a non-PJRT device object. Leave
      // `device` unset so we fallback to Python path and handle default device
      // there.
      cast_success = false;
    }
    if (cast_success) {
      device = pjrt_device_ptr.get();
      is_committed = false;
      VLOG(3) << "Using config.default_device (uncommitted): "
              << device->DebugString();
    }
  }
  if (device == nullptr) {
    // Call back into Python to find system default device, which will be stored
    // in default_device_.
    if (!default_device_) {
      // On the first call to `Call`, compute a default device. We need to wait
      // until after platform initialization is complete before doing so, but
      // @jit may be used as a decorator.
      TryToPopulateDefaultDevice();
      if (!default_device_) {
        return py::object(py::cast<py::tuple>(
            cache_miss_(*py::reinterpret_borrow<py::args>(args),
                        **kwargs.value_or(py::kwargs())))[0]);
      }
    }
    device = default_device_;
    is_committed = is_committed_;
    VLOG(3) << "Using device from Python): " << device->DebugString()
            << ", committed: " << is_committed;
  }
  CHECK(device != nullptr);

  // Split positional/keyword arguments into static (part of the cache key)
  // and dynamic (traced) arguments.
  ParsedArgumentsAsBuffers arguments;
  arguments.signature.function_name = function_name_;
  xla::Status status = ParseArguments(args, kwargs, static_argnums_,
                                      static_argnames_, arguments);
  if (!status.ok()) {
    VLOG(2) << "ParseArguments failed: " << status;
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }

  bool jax_enable_x64 = GetEnableX64();
  arguments.signature.jax_enable_x64 = jax_enable_x64;
  // The C++ jit do not support Tracers arguments inputs yet. The Python-based
  // jit function will be called if any of the dynamic arguments is unsupported.
  status = ComputeSignature(jax_enable_x64, device, is_committed, arguments);
  if (!status.ok()) {
    VLOG(2) << "ComputeSignature failed: " << status;
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }
  // Extra jit contexts (global and thread-local) are part of the cache key so
  // context changes invalidate cached executables.
  arguments.signature.global_extra_jit_context = global_state.extra_jit_context;
  arguments.signature.thread_local_extra_jit_context = tls.extra_jit_context;

  VLOG(3) << "CallSignature:\n" << arguments.signature.DebugString();
  bool inserted = false;
  std::shared_ptr<CacheEntry> cache_entry = executables_->GetOrCreateIfAbsent(
      arguments.signature, [&inserted](const CallSignature& key) {
        inserted = true;
        return std::make_shared<CacheEntry>();
      });

  if (!cache_entry->compilation_complete.HasBeenNotified()) {
    // In case of several threads attempting to compile the executable, only
    // the one that inserted the item will perform the compilation.
    if (inserted) {
      py::object out_and_fastpath_data;
      py::tuple out_tuple;
      VLOG(2) << "Cache miss for\n" << arguments.signature.DebugString();
      try {
        // Calls Python and may release the GIL. May also throw if
        // compilation/tracing fails.
        out_and_fastpath_data =
            cache_miss_(*py::reinterpret_borrow<py::args>(args),
                        **kwargs.value_or(py::kwargs()));
        out_tuple = py::cast<py::tuple>(out_and_fastpath_data);
        PopulateCacheEntry(cache_entry.get(), arguments.signature, out_tuple);
      } catch (const std::exception& e) {
        // Mark the entry as Python-only and unblock waiters before
        // propagating the error.
        cache_entry->fall_back_to_python = true;
        cache_entry->compilation_complete.Notify();
        throw;
      }
      cache_entry->compilation_complete.Notify();

      // We have already computed the result in the miss path so we can return
      // it. We are even *required* to do so if there are donated arguments,
      // because any donated buffers will now be invalid.
      return py::object(out_tuple[0]);
    } else {
      // Release the GIL while we wait, making sure the compile thread can
      // lock it.
      py::gil_scoped_release release;
      cache_entry->compilation_complete.WaitForNotification();
    }
  }
  // It's hard to reraise the exact same kind of errors when a compilation error
  // occurred. If the first compilation failed, other threads will also execute
  // the Python path.
  if (cache_entry->fall_back_to_python) {
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }

  status = CopyBuffersToDevice(jax_enable_x64, cache_entry->kept_var_bitvec,
                               arguments);
  if (!status.ok()) {
    VLOG(2) << "CopyBuffersToDevice failed: " << status;
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }

  // Executes the computation.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> output_buffers;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(
        output_buffers,
        cache_entry->executable->mutable_pjrt_executable()->Execute(
            {arguments.arg_buffers}, cache_entry->executable->options()));
  }
  auto traceback = xla::Traceback::Get();

  // Wrap each output buffer as a PyBuffer, attaching aval/weak-type/sticky
  // device metadata from the cache entry. num_devices == 1, so only
  // output_buffers[0] is used.
  int num_outputs = output_buffers[0].size();
  absl::InlinedVector<py::object, 1> flat_device_arrays;
  flat_device_arrays.reserve(num_outputs);
  for (int i = 0; i < output_buffers[0].size(); ++i) {
    // The traceback is moved into the last buffer, copied into the others.
    bool last = (i == (num_outputs - 1));
    xla::PyBuffer::object buffer = xla::PyBuffer::Make(
        cache_entry->executable->client(), std::move(output_buffers[0][i]),
        last ? std::move(traceback) : traceback);
    buffer.buf()->SetAval(cache_entry->out_avals[i]);
    buffer.buf()->set_weak_type(cache_entry->out_weak_types[i]);
    if (cache_entry->sticky_device.has_value()) {
      TF_RETURN_IF_ERROR(
          buffer.buf()->set_sticky_device((*cache_entry->sticky_device).get()));
    }
    flat_device_arrays.push_back(std::move(buffer));
  }
  py::object out = cache_entry->out_pytree_def.Unflatten(flat_device_arrays);

  // If there is a post-hook function, call it with the inputs and the outputs.
  std::optional<py::object> post_hook = GetPostHook();
  if (post_hook) {
    (*post_hook)(AsPyHandle(), args,
                 py::cast<py::dict>(kwargs.value_or(py::kwargs())), out);
  }
  return std::move(out);
}
1033 
// Memory layout of the CompiledFunction Python object. Note that `fun` is
// NOT constructed by tp_new; it is placement-new'ed by
// InitializeCompiledFunction and destroyed explicitly in tp_dealloc.
struct JaxCompiledFunctionObject {
  PyObject_HEAD;
  PyObject* dict;      // Dictionary for __dict__
  PyObject* weakrefs;  // Weak references; for use by the Python interpreter.
  CompiledFunction fun;  // The C++ implementation; see layout note above.
};

// The heap-allocated Python type object for CompiledFunction; created and
// assigned in BuildJaxjitSubmodule.
PyObject* JaxCompiledFunction_Type = nullptr;
1042 
// Returns true iff `handle`'s type is the CompiledFunction type object
// (JaxCompiledFunction_Type).
bool CompiledFunction::IsCompiledFunction(py::handle handle) {
  return handle.get_type() == JaxCompiledFunction_Type;
}
1046 
// Extracts the embedded CompiledFunction from a Python handle without any
// type check; callers must have verified IsCompiledFunction(handle) first.
CompiledFunction* CompiledFunction::AsCompiledFunctionUnchecked(
    py::handle handle) {
  return &(reinterpret_cast<JaxCompiledFunctionObject*>(handle.ptr())->fun);
}
1051 
AsCompiledFunction(py::handle handle)1052 xla::StatusOr<CompiledFunction*> AsCompiledFunction(py::handle handle) {
1053   if (!CompiledFunction::IsCompiledFunction(handle)) {
1054     return xla::InvalidArgument("Expected a CompiledFunction");
1055   }
1056   return CompiledFunction::AsCompiledFunctionUnchecked(handle);
1057 }
1058 
// Recovers the owning PyObject* from this CompiledFunction by subtracting the
// offset of the `fun` member inside JaxCompiledFunctionObject. Only valid for
// CompiledFunctions that are embedded in a JaxCompiledFunctionObject.
py::handle CompiledFunction::AsPyHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(JaxCompiledFunctionObject, fun));
}
1063 
1064 extern "C" {
1065 
JaxCompiledFunction_tp_new(PyTypeObject * subtype,PyObject * args,PyObject * kwds)1066 PyObject* JaxCompiledFunction_tp_new(PyTypeObject* subtype, PyObject* args,
1067                                      PyObject* kwds) {
1068   JaxCompiledFunctionObject* self =
1069       reinterpret_cast<JaxCompiledFunctionObject*>(
1070           subtype->tp_alloc(subtype, 0));
1071   if (!self) return nullptr;
1072   self->dict = nullptr;
1073   self->weakrefs = nullptr;
1074   return reinterpret_cast<PyObject*>(self);
1075 }
1076 
// tp_dealloc slot. Order matters: clear weak references while the object is
// still intact, drop __dict__, run the C++ destructor on the
// placement-new'ed `fun`, then free the storage. The trailing Py_DECREF(tp)
// releases the reference each instance of a heap type holds on its type.
void JaxCompiledFunction_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  Py_CLEAR(o->dict);
  o->fun.~CompiledFunction();
  tp->tp_free(self);
  Py_DECREF(tp);
}
1089 
// tp_traverse slot: reports the Python references this object holds
// (__dict__ plus the callables stored in the embedded CompiledFunction) so
// the cyclic garbage collector can detect reference cycles.
// NOTE(review): CPython 3.9+ recommends heap types also Py_VISIT(Py_TYPE(self))
// here — confirm against the supported Python versions.
int JaxCompiledFunction_tp_traverse(PyObject* self, visitproc visit,
                                    void* arg) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  Py_VISIT(o->dict);
  Py_VISIT(o->fun.fun().ptr());
  Py_VISIT(o->fun.cache_miss().ptr());
  Py_VISIT(o->fun.get_device().ptr());
  return 0;
}
1100 
// tp_clear slot: drops the references reported by tp_traverse so the garbage
// collector can break reference cycles.
int JaxCompiledFunction_tp_clear(PyObject* self) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  Py_CLEAR(o->dict);
  // Presumably releases the fun/cache_miss/get_device callables visited in
  // tp_traverse — see CompiledFunction::ClearPythonReferences.
  o->fun.ClearPythonReferences();
  return 0;
}
1108 
1109 // Implements the Python descriptor protocol so JIT-compiled functions can be
1110 // used as bound methods. See:
1111 // https://docs.python.org/3/howto/descriptor.html#functions-and-methods
JaxCompiledFunction_tp_descr_get(PyObject * self,PyObject * obj,PyObject * type)1112 PyObject* JaxCompiledFunction_tp_descr_get(PyObject* self, PyObject* obj,
1113                                            PyObject* type) {
1114   if (obj == nullptr || obj == Py_None) {
1115     Py_INCREF(self);
1116     return self;
1117   }
1118   return PyMethod_New(self, obj);
1119 }
1120 
1121 // Support d = instance.__dict__.
JaxCompiledFunction_get_dict(PyObject * self,void *)1122 PyObject* JaxCompiledFunction_get_dict(PyObject* self, void*) {
1123   JaxCompiledFunctionObject* o =
1124       reinterpret_cast<JaxCompiledFunctionObject*>(self);
1125   if (!o->dict) {
1126     o->dict = PyDict_New();
1127   }
1128   Py_XINCREF(o->dict);
1129   return o->dict;
1130 }
1131 
JaxCompiledFunction_set_dict(PyObject * self,PyObject * new_dict,void *)1132 int JaxCompiledFunction_set_dict(PyObject* self, PyObject* new_dict, void*) {
1133   JaxCompiledFunctionObject* o =
1134       reinterpret_cast<JaxCompiledFunctionObject*>(self);
1135   if (!PyDict_Check(new_dict)) {
1136     PyErr_Format(PyExc_TypeError,
1137                  "__dict__ must be set to a dictionary, not a '%s'",
1138                  Py_TYPE(new_dict)->tp_name);
1139     return -1;
1140   }
1141   Py_INCREF(new_dict);
1142   Py_CLEAR(o->dict);
1143   o->dict = new_dict;
1144   return 0;
1145 }
1146 
// Getset table for the CompiledFunction type (installed as tp_getset).
static PyGetSetDef JaxCompiledFunction_tp_getset[] = {
    // Having a __dict__ seems necessary to allow functools.wraps to override
    // __doc__.
    {const_cast<char*>("__dict__"), JaxCompiledFunction_get_dict,
     JaxCompiledFunction_set_dict, nullptr, nullptr},
    {nullptr, nullptr, nullptr, nullptr, nullptr}};
1153 
// tp_call slot: forwards the Python call to CompiledFunction::Call and
// translates every failure mode into a Python exception (returning nullptr).
PyObject* JaxCompiledFunction_tp_call(PyObject* self, PyObject* args,
                                      PyObject* kwargs) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  tensorflow::profiler::TraceMe traceme([&] {
    return absl::StrCat("JaxCompiledFunction(", o->fun.function_name(), ")");
  });
  // kwargs may be null at the C API level; map that to an absent optional.
  std::optional<py::kwargs> py_kwargs;
  if (kwargs) {
    py_kwargs = py::reinterpret_borrow<py::kwargs>(kwargs);
  }
  try {
    xla::StatusOr<py::object> out = o->fun.Call(args, std::move(py_kwargs));
    if (!out.ok()) {
      // Non-OK status (no C++ exception thrown): surface as ValueError.
      PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str());
      return nullptr;
    }
    return out.ValueOrDie().release().ptr();
  } catch (py::error_already_set& e) {
    // A Python exception is already pending; re-arm it for the interpreter.
    e.restore();
    return nullptr;
  } catch (py::cast_error& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  } catch (std::invalid_argument& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  } catch (std::runtime_error& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  }
}
1186 
JaxCompiledFunction_tp_repr(PyObject * self)1187 PyObject* JaxCompiledFunction_tp_repr(PyObject* self) {
1188   try {
1189     const std::string& repr = absl::StrFormat(
1190         "<CompiledFunction of %s>",
1191         static_cast<std::string>(py::repr(py::getattr(self, "__wrapped__"))));
1192     return PyUnicode_FromString(repr.c_str());
1193   } catch (...) {
1194     // Ignore all errors when accessing a repr.
1195     return PyUnicode_FromString("<CompiledFunction>");
1196   }
1197 }
1198 
// Constructs the CompiledFunction member in place. The storage was left raw
// by JaxCompiledFunction_tp_new, so placement-new is required here (and the
// matching explicit destructor call lives in tp_dealloc).
void InitializeCompiledFunction(JaxCompiledFunctionObject* cfun,
                                py::function fun, py::function cache_miss,
                                py::function get_device,
                                bool has_explicit_device,
                                std::vector<int> static_argnums,
                                std::vector<py::str> static_argnames,
                                std::vector<int> donate_argnums,
                                std::shared_ptr<CompiledFunctionCache> cache) {
  new (&cfun->fun) CompiledFunction(
      std::move(fun), std::move(cache_miss), std::move(get_device),
      has_explicit_device, std::move(static_argnums),
      std::move(static_argnames), std::move(donate_argnums), std::move(cache));
}
1212 
1213 }  // extern "C"
1214 
MakeCompiledFunction(py::function fun,py::function cache_miss,py::function get_device,bool has_explicit_device,std::vector<int> static_argnums,std::vector<py::str> static_argnames,std::vector<int> donate_argnums,std::shared_ptr<CompiledFunctionCache> cache)1215 py::object MakeCompiledFunction(py::function fun, py::function cache_miss,
1216                                 py::function get_device,
1217                                 bool has_explicit_device,
1218                                 std::vector<int> static_argnums,
1219                                 std::vector<py::str> static_argnames,
1220                                 std::vector<int> donate_argnums,
1221                                 std::shared_ptr<CompiledFunctionCache> cache) {
1222   py::object obj = py::reinterpret_steal<py::object>(JaxCompiledFunction_tp_new(
1223       reinterpret_cast<PyTypeObject*>(JaxCompiledFunction_Type), nullptr,
1224       nullptr));
1225   JaxCompiledFunctionObject* buf =
1226       reinterpret_cast<JaxCompiledFunctionObject*>(obj.ptr());
1227   if (!cache) {
1228     cache = std::make_shared<CompiledFunctionCache>(
1229         CompiledFunctionCache::kDefaultCapacity);
1230   }
1231   InitializeCompiledFunction(
1232       buf, std::move(fun), std::move(cache_miss), std::move(get_device),
1233       has_explicit_device, std::move(static_argnums),
1234       std::move(static_argnames), std::move(donate_argnums), std::move(cache));
1235   return obj;
1236 }
1237 
// Version numbers for the pickled representations of
// CompiledFunction/CompiledFunctionCache. Increment these if changing them;
// the corresponding __setstate__ implementations reject mismatched versions.
const int kCompiledFunctionCachePickleVersion = 1;
const int kCompiledFunctionPickleVersion = 1;
1242 
1243 }  // namespace
1244 
BuildJaxjitSubmodule(py::module & m)1245 void BuildJaxjitSubmodule(py::module& m) {
1246   py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library");
1247 
1248   py::class_<CompiledFunctionCache, std::shared_ptr<CompiledFunctionCache>>
1249       cache(jitlib, "CompiledFunctionCache");
1250   cache.def(py::init<int>(),
1251             py::arg("capacity") = CompiledFunctionCache::kDefaultCapacity);
1252   cache.def("size", &CompiledFunctionCache::Size);
1253   cache.def("capacity", &CompiledFunctionCache::Capacity);
1254   cache.def("clear", &CompiledFunctionCache::Clear);
1255   cache.def_static("clear_all", []() {
1256     GetGlobalCompiledFunctionStore().ClearFunctionCache();
1257   });
1258   cache.def(py::pickle(
1259       // __getstate__
1260       // Pickles as an empty cache; the client can repopulate as needed.
1261       [](const CompiledFunctionCache& cache) {
1262         py::dict pickle;
1263         pickle["version"] = kCompiledFunctionCachePickleVersion;
1264         pickle["capacity"] = cache.Capacity();
1265         return pickle;
1266       },
1267       // __setstate__
1268       [](const py::dict& pickle) {
1269         int version = py::cast<int>(pickle["version"]);
1270         if (version != kCompiledFunctionCachePickleVersion) {
1271           throw std::invalid_argument(absl::StrFormat(
1272               "Invalid CompiledFunction pickle version, got %d, expected %d",
1273               version, kCompiledFunctionCachePickleVersion));
1274         }
1275         int capacity = py::cast<int>(pickle["capacity"]);
1276         return std::make_shared<CompiledFunctionCache>(capacity);
1277       }));
1278 
1279   // We need to use heap-allocated type objects because we want to add
1280   // additional methods dynamically.
1281   py::object cfun;
1282   {
1283     py::str name = py::str("CompiledFunction");
1284     py::str qualname = py::str("CompiledFunction");
1285     PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
1286         PyType_Type.tp_alloc(&PyType_Type, 0));
1287     // Caution: we must not call any functions that might invoke the GC until
1288     // PyType_Ready() is called. Otherwise the GC might see a half-constructed
1289     // type object.
1290     CHECK(heap_type) << "Unable to create heap type object";
1291     heap_type->ht_name = name.release().ptr();
1292     heap_type->ht_qualname = qualname.release().ptr();
1293     PyTypeObject* type = &heap_type->ht_type;
1294     type->tp_name = "CompiledFunction";
1295     type->tp_basicsize = sizeof(JaxCompiledFunctionObject);
1296     type->tp_flags =
1297         Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC;
1298     type->tp_new = JaxCompiledFunction_tp_new;
1299     type->tp_dealloc = JaxCompiledFunction_tp_dealloc;
1300     type->tp_dictoffset = offsetof(JaxCompiledFunctionObject, dict);
1301     type->tp_traverse = JaxCompiledFunction_tp_traverse;
1302     type->tp_clear = JaxCompiledFunction_tp_clear;
1303     type->tp_weaklistoffset = offsetof(JaxCompiledFunctionObject, weakrefs);
1304     type->tp_getset = JaxCompiledFunction_tp_getset;
1305     type->tp_descr_get = JaxCompiledFunction_tp_descr_get;
1306     type->tp_call = JaxCompiledFunction_tp_call;
1307     type->tp_repr = JaxCompiledFunction_tp_repr;
1308     CHECK_EQ(PyType_Ready(type), 0);
1309     JaxCompiledFunction_Type = reinterpret_cast<PyObject*>(type);
1310     cfun = py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
1311   }
1312   py::object cfun_type =
1313       py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
1314 
1315   // Add CompiledFunction to the xla_extension module so it can be pickled.
1316   m.attr("CompiledFunction") = cfun_type;
1317   cfun.attr("__module__") = m.attr("__name__");
1318 
1319   cfun.attr("__signature__") =
1320       property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
1321         TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1322         return fun->PythonSignature();
1323       });
1324   cfun.attr("_cache_miss") =
1325       property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
1326         TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1327         return fun->cache_miss();
1328       });
1329   cfun.attr("__getstate__") = py::cpp_function(
1330       [](const CompiledFunction::object& self) {
1331         CompiledFunction* fn = self.func();
1332         py::dict pickle;
1333         pickle["version"] = kCompiledFunctionPickleVersion;
1334         pickle["fun"] = fn->fun();
1335         pickle["cache_miss"] = fn->cache_miss();
1336         pickle["get_device"] = fn->get_device();
1337         pickle["has_explicit_device"] = fn->has_explicit_device();
1338         pickle["static_argnums"] = fn->static_argnums();
1339         pickle["static_argnames"] = fn->static_argnames();
1340         pickle["donate_argnums"] = fn->donate_argnums();
1341         pickle["cache"] = fn->cache();
1342         return pickle;
1343       },
1344       py::is_method(cfun_type));
1345   cfun.attr("__setstate__") = py::cpp_function(
1346       [](CompiledFunction::object& self, const py::dict& pickle) {
1347         int version = py::cast<int>(pickle["version"]);
1348         if (version != kCompiledFunctionPickleVersion) {
1349           throw std::invalid_argument(absl::StrFormat(
1350               "Invalid CompiledFunction pickle version, got %d, expected %d. "
1351               "Pickling/Unpickling jitted functions using different JAX "
1352               "versions is not supported.",
1353               version, kCompiledFunctionPickleVersion));
1354         }
1355         py::function fun = py::cast<py::function>(pickle["fun"]);
1356         py::function cache_miss = py::cast<py::function>(pickle["cache_miss"]);
1357         py::function get_device = py::cast<py::function>(pickle["get_device"]);
1358         bool has_explicit_device =
1359             py::cast<bool>(pickle["has_explicit_device"]);
1360         std::vector<int> static_argnums =
1361             py::cast<std::vector<int>>(pickle["static_argnums"]);
1362         std::vector<py::str> static_argnames =
1363             py::cast<std::vector<py::str>>(pickle["static_argnames"]);
1364         std::vector<int> donate_argnums =
1365             py::cast<std::vector<int>>(pickle["donate_argnums"]);
1366         std::shared_ptr<CompiledFunctionCache> cache =
1367             py::cast<std::shared_ptr<CompiledFunctionCache>>(pickle["cache"]);
1368         InitializeCompiledFunction(
1369             reinterpret_cast<JaxCompiledFunctionObject*>(self.ptr()),
1370             std::move(fun), std::move(cache_miss), std::move(get_device),
1371             has_explicit_device, std::move(static_argnums),
1372             std::move(static_argnames), std::move(donate_argnums),
1373             std::move(cache));
1374       },
1375       py::is_method(cfun_type));
1376 
1377   py::class_<JitState> jit_state_(jitlib, "JitState");
1378   jit_state_.def_readwrite("disable_jit", &JitState::disable_jit);
1379   jit_state_.def_readwrite("enable_x64", &JitState::enable_x64);
1380   jit_state_.def_readwrite("default_device", &JitState::default_device);
1381   jit_state_.def_readwrite("extra_jit_context", &JitState::extra_jit_context);
1382   jit_state_.def_readwrite("post_hook", &JitState::post_hook);
1383 
  // Hand out raw pointers to the global and thread-local JitState with
  // `reference` policy (Python does not take ownership).
  // NOTE(review): `global_state`/`thread_local_state` are declared outside
  // this view — presumably at namespace scope, which would make the
  // pointers safe to outlive this function; the `[&]` captures appear
  // unused in that case. Confirm nothing function-local is captured.
  jitlib.def(
      "global_state", [&]() { return &global_state; },
      py::return_value_policy::reference);
  jitlib.def(
      "thread_local_state", [&]() { return &thread_local_state; },
      py::return_value_policy::reference);
1390 
1391   jitlib.def("jit_is_disabled", &GetDisableJit);
1392   jitlib.def("get_enable_x64", &GetEnableX64);
1393   jitlib.def("set_thread_local_state_initialization_callback",
1394              [](py::object f) { initialize_local_state = f; });
1395 
  // Main entry point: wraps a Python function (plus its cache-miss handler
  // and device getter) in a C++ CompiledFunction object.
  // NOTE: the argument order passed to MakeCompiledFunction differs from the
  // lambda's parameter order — `has_explicit_device` moves ahead of
  // `static_argnums` — keep the call site in sync with the callee's
  // declaration (defined outside this file section).
  jitlib.def(
      "jit",
      [](py::function fun, py::function cache_miss, py::function get_device,
         std::vector<int> static_argnums, std::vector<py::str> static_argnames,
         std::vector<int> donate_argnums, bool has_explicit_device,
         std::shared_ptr<CompiledFunctionCache> cache) -> py::object {
        return MakeCompiledFunction(
            std::move(fun), std::move(cache_miss), std::move(get_device),
            has_explicit_device, std::move(static_argnums),
            std::move(static_argnames), std::move(donate_argnums),
            std::move(cache));
      },
      // Trailing parameters are defaulted so older callers that omit them
      // keep working.
      py::arg("fun"), py::arg("cache_miss"), py::arg("get_device"),
      py::arg("static_argnums"),
      py::arg("static_argnames") = std::vector<py::str>(),
      py::arg("donate_argnums") = std::vector<int>(),
      py::arg("has_explicit_device") = false, py::arg("cache") = nullptr);
1413 
  // This function is not yet a full replacement for the Python one, because:
  // (a) it does not support abstract types,
  // (b) it does not set the device stickiness yet.
  // TODO(jblespiau): Finish the replacement of the Python feature.
  jitlib.def("device_put",
             [](py::handle obj, bool jax_enable_x64,
                xla::ClientAndPtr<xla::PjRtDevice> to_device)
                 -> xla::StatusOr<py::object> {
               std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
               xla::DevicePutOptions options;
               // When x64 mode is off, 64-bit Python values are squashed to
               // their 32-bit counterparts before transfer.
               options.squash_64bit_types = !jax_enable_x64;
               options.allow_zero_copy = true;
               xla::StatusOr<xla::DevicePutResult> results =
                   DevicePut(obj, to_device.contents, options);
               if (!results.ok()) {
                 // Surface the failure as a Python-visible runtime error.
                 throw xla::XlaRuntimeError(results.status().error_message());
               }
               if (results->owned_buffer) {
                 // DevicePut produced a new device buffer: wrap it in a
                 // PyBuffer and attach the JAX aval metadata.
                 auto buffer = xla::PyBuffer::Make(
                     pyclient, std::move(results->owned_buffer),
                     xla::Traceback::Get());

                 // Leaked on purpose (never destroyed) — presumably to avoid
                 // destructor-order problems at interpreter shutdown; cached
                 // so the import happens only once.
                 static const auto* jax_core =
                     new py::module(py::module::import("jax.core"));
                 static const auto* shaped_array =
                     new py::handle(jax_core->attr("ShapedArray"));
                 buffer.buf()->SetAval((*shaped_array)(
                     buffer.buf()->python_shape(), buffer.buf()->python_dtype(),
                     results->weak_type));
                 // See (b) above: stickiness is explicitly not set.
                 TF_RETURN_IF_ERROR(buffer.buf()->set_sticky_device(nullptr));

                 return std::move(buffer);
               } else {
                 // No copy was needed; hand the original object back.
                 return py::cast<py::object>(obj);
               }
             });
1450 
1451   py::class_<xla::PyArgSignature> arg_signature(jitlib, "PyArgSignature");
1452   arg_signature
1453       .def_property_readonly("dtype",
1454                              [](const xla::PyArgSignature& sig) {
1455                                return PrimitiveTypeToDtype(sig.dtype);
1456                              })
1457       .def_property_readonly(
1458           "shape",
1459           [](const xla::PyArgSignature& sig) {
1460             return xla::SpanToTuple(absl::MakeConstSpan(sig.shape));
1461           })
1462       .def_readonly("weak_type", &xla::PyArgSignature::weak_type);
1463   jitlib.def("_ArgSignatureOfValue", &xla::PyArgSignatureOfValue);
1464 
1465   // All private members are only for testing/debugging purposes
1466   cfun.attr("_cache_size") = py::cpp_function(
1467       [](py::handle self) -> xla::StatusOr<int> {
1468         TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1469         return fun->cache_size();
1470       },
1471       py::is_method(cfun));
1472   cfun.attr("_clear_cache") = py::cpp_function(
1473       [](py::handle self) -> xla::Status {
1474         TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1475         fun->ClearCache();
1476         return ::tensorflow::OkStatus();
1477       },
1478       py::is_method(cfun));
1479   jitlib.def("_is_float0", &xla::IsFloat0);
1480 }
1481 
1482 }  // namespace jax
1483