1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements the `jax.jit` dispatch and just-in-time feature.
17 //
18 // In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward
19 // based on passed arguments dtypes/shapes/identity) the execution to a
20 // just-in-time compiled XLA Executable. All of that is done in C++ for
21 // performance reasons.
22 //
23 // This file contains the utilities to:
24 // (a) inspect arguments and describe their structure, dtype/shapes, etc.
25 // (b) keep a mapping from function signatures to compiled XLA Executables.
26
27 #include "tensorflow/compiler/xla/python/jax_jit.h"
28
29 #include <Python.h>
30
31 #include <algorithm>
32 #include <exception>
33 #include <memory>
34 #include <optional>
35 #include <stdexcept>
36 #include <string>
37 #include <utility>
38
39 #include "absl/container/flat_hash_map.h"
40 #include "absl/container/inlined_vector.h"
41 #include "absl/strings/str_cat.h"
42 #include "absl/strings/str_format.h"
43 #include "absl/synchronization/notification.h"
44 #include "absl/types/span.h"
45 #include "pybind11/cast.h"
46 #include "pybind11/numpy.h"
47 #include "pybind11/pybind11.h"
48 #include "pybind11/pytypes.h"
49 #include "tensorflow/compiler/xla/pjrt/lru_cache.h"
50 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
51 #include "tensorflow/compiler/xla/python/exceptions.h"
52 #include "tensorflow/compiler/xla/python/py_buffer.h"
53 #include "tensorflow/compiler/xla/python/py_executable.h"
54 #include "tensorflow/compiler/xla/python/py_values.h"
55 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
56 #include "tensorflow/compiler/xla/python/python_utils.h"
57 #include "tensorflow/compiler/xla/python/pytree.h"
58 #include "tensorflow/compiler/xla/python/types.h"
59 #include "tensorflow/compiler/xla/shape_util.h"
60 #include "tensorflow/compiler/xla/statusor.h"
61 #include "tensorflow/compiler/xla/types.h"
62 #include "tensorflow/compiler/xla/util.h"
63 #include "tensorflow/compiler/xla/xla_data.pb.h"
64 #include "tensorflow/core/platform/status.h"
65 #include "tensorflow/core/profiler/lib/traceme.h"
66
67 namespace jax {
68
69 namespace py = pybind11;
70
71 // TODO(phawkins): Add support for Tracers.
72 // TODO(jblespiau): Use absl Status.
73
namespace {

// Process-wide JIT configuration. Allocated with `new` and bound to a
// reference; no matching `delete` appears in this file, so the object is
// never destroyed.
// Protected by the GIL.
JitState& global_state = *new JitState();

// Per-thread JIT configuration. The accessors below (GetDisableJit,
// GetEnableX64, GetDefaultDevice, GetPostHook) consult this first and fall
// back to `global_state` when the thread-local field is unset.
// TODO(phawkins): Google style guide forbids thread-local values with
// non-trivial destructors.
ABSL_CONST_INIT thread_local JitState thread_local_state;  // NOLINT

}  // namespace
84
// `thread_local_state.extra_jit_context` is set from Python. It's done when
// loading the Python jax modules on the main-thread. For other threads, we
// need to initialize the field the first time we access `thread_local_state`.
//
// GetLocalState() invokes this object with no arguments whenever the calling
// thread's `extra_jit_context` is unset, and CHECK-fails if it still wraps a
// null pointer at that time. The object is heap-allocated and never deleted
// in this file.
py::object& initialize_local_state = *new py::object();
89
// Returns a mutable reference to the single global JitState instance.
JitState& GetGlobalState() { return global_state; }
GetLocalState()91 JitState& GetLocalState() {
92 if (thread_local_state.extra_jit_context == std::nullopt) {
93 CHECK(initialize_local_state.ptr() != nullptr);
94 initialize_local_state();
95 }
96 return thread_local_state;
97 }
98
GetDisableJit()99 bool GetDisableJit() {
100 CHECK(global_state.disable_jit.has_value());
101 return thread_local_state.disable_jit.value_or(*global_state.disable_jit);
102 }
103
GetEnableX64()104 bool GetEnableX64() {
105 CHECK(global_state.enable_x64.has_value());
106 return thread_local_state.enable_x64.value_or(*global_state.enable_x64);
107 }
108
GetDefaultDevice()109 std::optional<py::object> GetDefaultDevice() {
110 return thread_local_state.default_device.has_value()
111 ? thread_local_state.default_device
112 : global_state.default_device;
113 }
114
GetPostHook()115 std::optional<pybind11::function> GetPostHook() {
116 return thread_local_state.post_hook.has_value() ? thread_local_state.post_hook
117 : global_state.post_hook;
118 }
119
OptionalDebugString(const std::optional<py::object> optional)120 static std::string OptionalDebugString(
121 const std::optional<py::object> optional) {
122 if (optional.has_value()) {
123 return py::cast<std::string>(py::str(optional.value()));
124 } else {
125 return "None";
126 }
127 }
128
DebugString() const129 std::string CallSignature::DebugString() const {
130 auto py_object_formatter = [](std::string* out, const py::object& o) {
131 out->append(py::cast<std::string>(py::str(o)));
132 };
133 auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) {
134 out->append(d.ToString());
135 };
136 auto signature_formatter = [](std::string* out,
137 const xla::PyArgSignature& s) {
138 out->append(s.DebugString());
139 };
140 return absl::StrFormat(
141 "static args (positional + keyword): %s\nstatic arg keyword names: %s\n"
142 "dynamic arg signatures (positional + keyword): %s\n"
143 "dynamic arg keyword names: %s\ndynamic arg treedefs: %s\n"
144 "device: %s\n"
145 "jax_enable_x64: %d\n"
146 "global_extra_jit_context: %s\n"
147 "thread_local_extra_jit_context: %s\n",
148 absl::StrJoin(static_args, ",", py_object_formatter),
149 absl::StrJoin(static_arg_names, ",", py_object_formatter),
150 absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter),
151 absl::StrJoin(dynamic_arg_names, ",", py_object_formatter),
152 absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter), // new line
153 device != nullptr ? device->DebugString() : "nullptr", jax_enable_x64,
154 OptionalDebugString(global_extra_jit_context),
155 OptionalDebugString(thread_local_extra_jit_context));
156 }
157
operator ==(const CallSignature & other) const158 bool CallSignature::operator==(const CallSignature& other) const {
159 return std::tie(dynamic_arg_treedefs, dynamic_arg_names,
160 dynamic_arg_signatures, device, jax_enable_x64,
161 static_arg_names) ==
162 std::tie(other.dynamic_arg_treedefs, other.dynamic_arg_names,
163 other.dynamic_arg_signatures, other.device,
164 other.jax_enable_x64, other.static_arg_names) &&
165 // `==` on py:objects is the Python `is`. We need equal.
166 std::equal(
167 static_args.begin(), static_args.end(), other.static_args.begin(),
168 other.static_args.end(),
169 [this](const py::object& a, const py::object& b) {
170 try {
171 return py::type::handle_of(a) == py::type::handle_of(b) &&
172 a.equal(b);
173 } catch (const py::error_already_set& e) {
174 throw std::invalid_argument(absl::StrCat(
175 "static arguments should be comparable using __eq__."
176 "The following error was raised during a call to '",
177 function_name, "' when comparing two objects of types ",
178 py::cast<std::string>(py::str(py::type::of(a))), " and ",
179 py::cast<std::string>(py::str(py::type::of(b))),
180 ". The error was:\n", e.what()));
181 }
182 }) &&
183 (global_extra_jit_context.has_value() ==
184 other.global_extra_jit_context.has_value()) &&
185 (!global_extra_jit_context.has_value() ||
186 global_extra_jit_context->equal(*other.global_extra_jit_context)) &&
187 (thread_local_extra_jit_context.has_value() ==
188 other.thread_local_extra_jit_context.has_value()) &&
189 (!thread_local_extra_jit_context.has_value() ||
190 thread_local_extra_jit_context->equal(
191 *other.thread_local_extra_jit_context));
192 }
193
194 template <typename H>
AbslHashValue(H h,const CallSignature & s)195 H AbslHashValue(H h, const CallSignature& s) {
196 h = H::combine(std::move(h), s.dynamic_arg_treedefs,
197 s.dynamic_arg_signatures);
198 for (const auto& name : s.dynamic_arg_names) {
199 h = H::combine(std::move(h), name.ptr());
200 }
201 h = H::combine(std::move(h), s.dynamic_arg_names.size());
202 for (const auto& static_arg : s.static_args) {
203 ssize_t hash;
204 try {
205 hash = py::hash(static_arg);
206 } catch (const py::error_already_set& e) {
207 if (!e.matches(PyExc_TypeError)) throw;
208 throw std::invalid_argument(absl::StrCat(
209 "Non-hashable static arguments are not supported. An error occurred "
210 "during a call to '",
211 s.function_name, "' while trying to hash an object of type ",
212 py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
213 py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
214 e.what(), "\n"));
215 }
216 h = H::combine(std::move(h), hash);
217 }
218 h = H::combine(std::move(h), s.static_args.size());
219 for (const auto& name : s.static_arg_names) {
220 h = H::combine(std::move(h), name.ptr());
221 }
222 h = H::combine(std::move(h), s.static_arg_names.size());
223 h = H::combine(std::move(h), s.device, s.jax_enable_x64);
224
225 // We do not hash the extra_jit_context fields since calling Python hash
226 // functions is expensive (~300ns) and we don't expect a large number of
227 // different contexts.
228 return h;
229 }
230
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
//
// Args:
//   args: the positional arguments. Read with the PyTuple_* macros, so it
//     must be a tuple.
//   py_kwargs: the keyword arguments, if any.
//   static_argnums: positional indices whose values are recorded in
//     `arguments.signature.static_args` instead of being flattened.
//   static_argnames: keyword names treated as static. Matched against the
//     (interned) kwarg keys by pointer identity, so these are presumably
//     expected to be interned as well - confirm against the caller.
//   arguments: output. Receives the flattened dynamic leaves
//     (`flat_dynamic_args`) and the signature fields `dynamic_arg_treedefs`,
//     `dynamic_arg_names`, `static_args` and `static_arg_names`.
//
// Returns OkStatus on every path visible in this function.
xla::Status ParseArguments(py::handle args,
                           const std::optional<py::kwargs>& py_kwargs,
                           absl::Span<int const> static_argnums,
                           absl::Span<py::str const> static_argnames,
                           ParsedArgumentsAsBuffers& arguments) {
  tensorflow::profiler::TraceMe traceme("ParseArguments");
  int num_args = PyTuple_GET_SIZE(args.ptr());
  int num_kwargs = py_kwargs ? py_kwargs->size() : 0;

  arguments.flat_dynamic_args.reserve(num_args + num_kwargs);
  if (static_argnums.empty()) {
    // Fast path: every positional argument is dynamic, so the treedef vector
    // can be sized up front and indexed directly.
    arguments.signature.dynamic_arg_treedefs.resize(num_args);

    // Positional arguments.
    for (int i = 0; i < num_args; ++i) {
      xla::PyTreeDef& pytree_def = arguments.signature.dynamic_arg_treedefs[i];
      pytree_def.FlattenInto(PyTuple_GET_ITEM(args.ptr(), i),
                             arguments.flat_dynamic_args);
    }
  } else {
    // Some positions are static: only reserve, and append one treedef per
    // dynamic argument.
    arguments.signature.dynamic_arg_treedefs.reserve(num_args);

    // Positional arguments.
    for (int i = 0; i < num_args; ++i) {
      if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
          static_argnums.end()) {
        arguments.signature.dynamic_arg_treedefs.emplace_back();
        xla::PyTreeDef& pytree_def =
            arguments.signature.dynamic_arg_treedefs.back();
        pytree_def.FlattenInto(PyTuple_GET_ITEM(args.ptr(), i),
                               arguments.flat_dynamic_args);
      } else {
        // PyTuple_GET_ITEM yields a borrowed reference, hence the borrow.
        arguments.signature.static_args.emplace_back(
            py::reinterpret_borrow<py::object>(
                PyTuple_GET_ITEM(args.ptr(), i)));
      }
    }
  }

  // Keyword arguments.
  if (py_kwargs) {
    std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs->begin(),
                                                          py_kwargs->end());
    // We first intern the keys, then sort them (by name, as in the Python path)
    // (see also xla::PyTreeDef::Flatten) and then create the signatures.
    // TODO(jblespiau): We should be able to sort the keys by interned-key
    // pointers, but this requires the Python compilation to do the same.
    for (int i = 0; i < num_kwargs; ++i) {
      // Intern the key if not already interned.
      // The extra reference taken here is the one that the
      // `reinterpret_steal` calls below hand over to the signature.
      kwargs[i].first.inc_ref();
      if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
        PyUnicode_InternInPlace(&kwargs[i].first.ptr());
      }
    }

    std::sort(kwargs.begin(), kwargs.end(),
              [](const std::pair<py::handle, py::handle>& a,
                 const std::pair<py::handle, py::handle>& b) {
                return a.first < b.first;
              });
    // Pointer comparison: relies on both sides being interned strings.
    auto kwarg_is_static = [&](py::handle name) {
      for (const auto& kw : static_argnames) {
        if (kw.ptr() == name.ptr()) return true;
      }
      return false;
    };

    arguments.signature.dynamic_arg_names.reserve(num_kwargs);
    for (int i = 0; i < num_kwargs; ++i) {
      if (kwarg_is_static(kwargs[i].first)) {
        // Key: steal the reference added above. Value: borrowed from the
        // kwargs dict.
        arguments.signature.static_arg_names.push_back(
            py::reinterpret_steal<py::object>(kwargs[i].first));
        arguments.signature.static_args.push_back(
            py::reinterpret_borrow<py::object>(kwargs[i].second));
      } else {
        arguments.signature.dynamic_arg_names.push_back(
            py::reinterpret_steal<py::object>(kwargs[i].first));
        arguments.signature.dynamic_arg_treedefs.emplace_back();
        xla::PyTreeDef& pytree_def =
            arguments.signature.dynamic_arg_treedefs.back();
        pytree_def.FlattenInto(kwargs[i].second, arguments.flat_dynamic_args);
      }
    }
  }
  return ::tensorflow::OkStatus();
}
319
320 namespace {
321
// Per-signature record stored in the compiled-function cache. Apart from
// `compilation_complete` and `keepalive`, the fields are filled in by
// CompiledFunction::PopulateCacheEntry from the data returned by the Python
// cache-miss path.
// Elements of CacheEntry are protected by the GIL.
struct CacheEntry {
  // Ensures a single thread performs the compilation for a given executable.
  //
  // The first thread (holding the GIL) will create the CacheEntry associated to
  // a signature and fill it. Other threads will wait for the notification.
  // If an error occurred during the compilation, `fall_back_to_python` is set
  // to `true`, and other threads will fail with the same error.
  absl::Notification compilation_complete;

  // The compiled executable to dispatch to.
  std::shared_ptr<xla::PyExecutable> executable;
  // Tree structure of the function's outputs.
  xla::PyTreeDef out_pytree_def;
  // We use Python types within the vector because this is what we will be
  // returning to Python. No need to convert back and forth.
  // We need py::object to maintain the objects alive.
  std::vector<py::object> out_avals;
  // `weak_type` attribute of each entry of `out_avals`, in the same order.
  std::vector<bool> out_weak_types;

  // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args`
  // in CompiledFunction::Call before calling into compiled computation.
  std::optional<std::vector<bool>> kept_var_bitvec;
  // Read from the `sticky_device` attribute of the fast-path data; may be
  // empty.
  std::optional<xla::ClientAndPtr<xla::PjRtDevice>> sticky_device;

  // Fallback to Python happens:
  // - for trivial computations
  // - when running a jax(pmap)
  // - after a compilation error, for threads that did not compile it the first
  //   time
  bool fall_back_to_python = false;

  // Python objects (notably in the cache key) that must remain alive as long
  // as the cache entry does. Currently this is the `key` values in the kwarg
  // entries in the cache key.
  std::vector<py::object> keepalive;
};
357
358 // A CompiledFunctionCache represents a cache of compiled functions that can be
359 // shared between one or more CompiledFunction objects. It serves two goals:
360 // - reduce the number of lru caches (hash map) across multiple JITs.
361 // - make the cache global to increase cache hits (e.g. calling jit(f)(3) twice)
362 // keeping entries alive as long as the underlying function f is alive.
363 // Assume the cache is protected by the GIL.
364 class CompiledFunctionCache {
365 public:
366 static constexpr int kDefaultCapacity = 4096;
367 explicit CompiledFunctionCache(int capacity);
368
369 // Cache entries are shared_ptr<>s because it's possible the cache entry
370 // might be evicted before we finish tracing/compiling.
371 typedef xla::LRUCache<CallSignature, std::shared_ptr<CacheEntry>> Cache;
372
373 // We include as part of the cache key `donate_argnums` (and any other fields
374 // that aren't subsumed by the CallSignature we compute for each call).
375 std::shared_ptr<Cache> Lookup(py::handle function,
376 absl::Span<const int> donate_argnums);
377
Size() const378 int Size() const { return lru_list_.Size(); }
Capacity() const379 int Capacity() const { return lru_list_.Capacity(); }
Clear()380 void Clear() { lru_list_.Clear(); }
381
382 private:
383 struct Key {
384 py::handle function; // Does not hold a reference.
385
386 // Other fields that are part of the arguments to `jit`, but are not
387 // otherwise part of CallSignature.
388 std::vector<int> donate_argnums;
389
operator ==jax::__anon901e1bb70811::CompiledFunctionCache::Key390 bool operator==(const Key& other) const {
391 return std::tie(function, donate_argnums) ==
392 std::tie(other.function, other.donate_argnums);
393 }
394 };
395 template <typename H>
AbslHashValue(H h,const Key & key)396 friend H AbslHashValue(H h, const Key& key) {
397 h = H::combine(std::move(h), key.function.ptr());
398 h = H::combine_contiguous(std::move(h), key.donate_argnums.data(),
399 key.donate_argnums.size());
400 return h;
401 }
402
403 struct Value {
Valuejax::__anon901e1bb70811::CompiledFunctionCache::Value404 explicit Value(std::shared_ptr<Cache> cache) : cache(std::move(cache)) {}
405 std::shared_ptr<Cache> cache;
406
407 // A weak reference to the key function. We use the weak reference to
408 // register a callback that is triggered when the key function is destroyed.
409 // We use a weak pointer because we want to allow caching across multiple
410 // calls to `jax.jit(f)` if `f` remains alive, but we do not want the cache
411 // to keep `f` alive if all other references are dropped.
412 py::weakref weakref;
413 };
414
415 Cache::LRUList lru_list_;
416 absl::flat_hash_map<Key, std::unique_ptr<Value>> functions_;
417 };
418
// Creates a cache whose shared LRU list is constructed with `capacity`.
CompiledFunctionCache::CompiledFunctionCache(int capacity)
    : lru_list_(capacity) {}
421
Lookup(py::handle function,absl::Span<const int> donate_argnums)422 std::shared_ptr<CompiledFunctionCache::Cache> CompiledFunctionCache::Lookup(
423 py::handle function, absl::Span<const int> donate_argnums) {
424 Key key;
425 key.function = function;
426 key.donate_argnums =
427 std::vector<int>(donate_argnums.begin(), donate_argnums.end());
428 auto insert = functions_.emplace(key, nullptr);
429 std::shared_ptr<Cache> cache = std::make_shared<Cache>(&lru_list_);
430 if (insert.second) {
431 py::cpp_function callback([this, key{std::move(key)}](py::handle weakref) {
432 functions_.erase(key);
433 });
434 PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr());
435 if (weakref) {
436 std::unique_ptr<Value>& entry = insert.first->second;
437 entry = std::make_unique<Value>(cache);
438 entry->weakref = py::reinterpret_steal<py::weakref>(weakref);
439 } else {
440 PyErr_Clear();
441 // `function` is not weak-referenceable. Don't bother adding it to the
442 // shared cache in that case; the `jit` object will hold the only shared
443 // reference to the cache entry.
444 functions_.erase(insert.first);
445 }
446 }
447 return cache;
448 }
449
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
class CompiledFunction {
 public:
  // Stores the jit configuration, interns `static_argnames`, obtains this
  // function's executable cache from `cache` and registers the new object in
  // the global CompiledFunctionStore.
  CompiledFunction(py::function fun, py::function cache_miss,
                   py::function get_device, bool has_explicit_device,
                   std::vector<int> static_argnums,
                   std::vector<py::str> static_argnames,
                   std::vector<int> donate_argnums,
                   std::shared_ptr<CompiledFunctionCache> cache);
  // Unregisters this object from the global CompiledFunctionStore.
  ~CompiledFunction();

  // pybind11::object typed subclass for CompiledFunction objects.
  class pyobject : public py::object {
   public:
    PYBIND11_OBJECT(pyobject,  // NOLINT
                    py::object, CompiledFunction::IsCompiledFunction);
    pyobject() = default;
    // Returns the wrapped CompiledFunction without any type check.
    CompiledFunction* func() const {
      return CompiledFunction::AsCompiledFunctionUnchecked(*this);
    }
  };
  // Alias as ::object; outside the scope above we won't confuse pybind11's
  // macros.
  using object = pyobject;

  // Returns true if `handle` is a CompiledFunction.
  static bool IsCompiledFunction(py::handle handle);
  // Converts `handle` to a CompiledFunction*. Does not do any checking.
  static CompiledFunction* AsCompiledFunctionUnchecked(py::handle handle);

  // This function will:
  // (a) flatten the inputs using pytree
  // (b) get buffer objects from the arguments
  // (c) call the executable
  // (d) construct `DeviceArray` objects from the outputs
  // (e) reconstruct the `PyTree`.
  xla::StatusOr<py::object> Call(py::handle args,
                                 std::optional<py::kwargs> kwargs);

  // This allows `inspect.signature(cpp_jitted_f)` from Python.
  py::object PythonSignature() {
    // The `inspect` module is imported once and intentionally never freed.
    static const auto* inspect = new py::module(py::module::import("inspect"));
    return inspect->attr("signature")(fun_);
  }

  // Number of entries currently held in this function's executable cache.
  int cache_size() const { return executables_->Size(); }
  // Drops all cached executables for this function.
  void ClearCache() {
    // Setting `default_device_` to nullptr forces Call() to retrieve the
    // device.
    default_device_ = nullptr;
    executables_->Clear();
  }

  // Read-only accessors for the configuration captured at construction.
  const py::function& fun() const { return fun_; }
  const py::function& cache_miss() const { return cache_miss_; }
  const py::function& get_device() const { return get_device_; }
  bool has_explicit_device() const { return has_explicit_device_; }
  const std::vector<int>& static_argnums() const { return static_argnums_; }
  const std::vector<py::str>& static_argnames() const {
    return static_argnames_;
  }
  const std::vector<int>& donate_argnums() const { return donate_argnums_; }
  const std::shared_ptr<CompiledFunctionCache>& cache() const { return cache_; }

  // Helper function used by the tp_clear GC method.
  void ClearPythonReferences() {
    py::function fun, cache_miss, get_device;
    // Swap values for nulls before they are destroyed. See the Python
    // Py_CLEAR() documentation for a discussion of this topic.
    std::swap(fun_, fun);
    std::swap(cache_miss_, cache_miss);
    std::swap(get_device_, get_device);
  }

  py::handle AsPyHandle();
  // Name used in error messages; taken from the function's `__name__` in the
  // constructor.
  const std::string& function_name() const { return function_name_; }

 private:
  // Attempts to populate default_device_. May release the GIL; is
  // reentrant-safe.
  void TryToPopulateDefaultDevice();

  // Fills `entry` from the (output, fast-path data) tuple returned by the
  // Python cache-miss function.
  void PopulateCacheEntry(CacheEntry* entry, const CallSignature& signature,
                          const py::tuple& out_and_fastpath_data);
  bool always_fallback_to_python_ = false;

  py::function fun_;  // The Python function to jit.
  std::string function_name_;

  // See JAX _cpp_jit in api.py for documentation.
  py::function cache_miss_;

  // We need to know the static arguments to remove them from the arguments
  // passed to the underlying PyExecutable. In sorted order.
  std::vector<int> static_argnums_;
  // Keyword arguments, interned.
  std::vector<py::str> static_argnames_;
  std::vector<int> donate_argnums_;

  // Whether this function has an explicit device set by either the `device` or
  // `backend` arguments to jit.
  bool has_explicit_device_;

  // A function taking no arguments and returning the default device and whether
  // jax.jit has been committed to it.
  py::function get_device_;

  // Keeps the shared LRU cache alive as long as the CompiledFunction is alive.
  std::shared_ptr<CompiledFunctionCache> cache_;

  // The part of cache_ specific to this CompiledFunction.
  std::shared_ptr<CompiledFunctionCache::Cache> executables_;

  // The logic is the following:
  // - if `device` or `backend` are not specified to `jax.jit`, we will use
  //   the input sticky buffer device, or `default_device_` if there is no
  //   such sticky buffer.
  // - When one of `device` or `backend` is specified, this will determine
  //   the `default_device_` which will be used as the targeted device. In
  //   which case, we will always copy input buffers to this device.
  // These fields are protected by the GIL.
  xla::PjRtDevice* default_device_ = nullptr;
  // NOTE(review): no initializer here; presumably assigned together with
  // `default_device_` before it is first read - confirm.
  bool is_committed_;
};
576
577 // This class keeps references to all CompiledFunctions. This class is
578 // thread-compatible.
579 class CompiledFunctionStore {
580 public:
Insert(CompiledFunction * function)581 void Insert(CompiledFunction* function) {
582 compiled_functions_.insert(function);
583 }
584
Erase(CompiledFunction * function)585 void Erase(CompiledFunction* function) {
586 compiled_functions_.erase(function);
587 }
588
ClearFunctionCache()589 void ClearFunctionCache() {
590 for (auto* function : compiled_functions_) {
591 function->ClearCache();
592 }
593 }
594
595 private:
596 absl::flat_hash_set<CompiledFunction*> compiled_functions_;
597 };
598
599 // Protected by GIL.
GetGlobalCompiledFunctionStore()600 CompiledFunctionStore& GetGlobalCompiledFunctionStore() {
601 static auto* const store = new CompiledFunctionStore();
602 return *store;
603 }
604
CompiledFunction(py::function fun,py::function cache_miss,py::function get_device,bool has_explicit_device,std::vector<int> static_argnums,std::vector<py::str> static_argnames,std::vector<int> donate_argnums,std::shared_ptr<CompiledFunctionCache> cache)605 CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss,
606 py::function get_device,
607 bool has_explicit_device,
608 std::vector<int> static_argnums,
609 std::vector<py::str> static_argnames,
610 std::vector<int> donate_argnums,
611 std::shared_ptr<CompiledFunctionCache> cache)
612 : fun_(std::move(fun)),
613 cache_miss_(std::move(cache_miss)),
614 static_argnums_(std::move(static_argnums)),
615 static_argnames_(std::move(static_argnames)),
616 donate_argnums_(donate_argnums),
617 has_explicit_device_(std::move(has_explicit_device)),
618 get_device_(std::move(get_device)),
619 cache_(std::move(cache)) {
620 std::sort(static_argnums_.begin(), static_argnums_.end());
621 for (py::str& s : static_argnames) {
622 PyUnicode_InternInPlace(&s.ptr());
623 }
624 executables_ = cache_->Lookup(fun_, donate_argnums);
625 function_name_ = py::str(py::getattr(fun_, "__name__", fun));
626
627 GetGlobalCompiledFunctionStore().Insert(this);
628 }
629
// Removes this object from the global store so that
// CompiledFunctionStore::ClearFunctionCache() no longer visits it.
CompiledFunction::~CompiledFunction() {
  GetGlobalCompiledFunctionStore().Erase(this);
}
633
634 // Returns nullptr if arg has no sticky device
GetJitArgumentStickyDevice(py::handle arg)635 static xla::StatusOr<xla::PjRtDevice*> GetJitArgumentStickyDevice(
636 py::handle arg) {
637 struct PythonTypes {
638 py::object device_array;
639 };
640 static const auto& types = *[]() -> PythonTypes* {
641 py::module xla_module(py::module::import("jax.interpreters.xla"));
642 py::object device_array;
643 if (py::hasattr(xla_module, "_DeviceArray")) {
644 device_array = xla_module.attr("_DeviceArray");
645 }
646 return new PythonTypes{device_array};
647 }();
648
649 // We specically only deal with DeviceArray (not ShardedDeviceArray).
650 // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
651 if (arg.get_type().ptr() == xla::PyBuffer::type()) {
652 xla::PyBuffer* buffer = xla::PyBuffer::AsPyBufferUnchecked(arg);
653 if (!buffer->sticky_device()) {
654 return nullptr;
655 }
656 return buffer->sticky_device();
657 }
658
659 if (arg.get_type().ptr() == types.device_array.ptr()) {
660 if (arg.attr("_device").is_none()) {
661 return nullptr;
662 }
663 try {
664 // This can fail, e.g. for cloud TPU 2VM buffers.
665 TF_ASSIGN_OR_RETURN(xla::PyBuffer * buffer,
666 xla::PyBuffer::AsPyBuffer(arg.attr("device_buffer")));
667 return buffer->buffer()->device();
668 } catch (const py::cast_error& e) {
669 return xla::InvalidArgument(
670 "%s", absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
671 "`device_buffer` field is of type ",
672 py::cast<std::string>(
673 arg.attr("device_buffer").get_type().str()),
674 " while a `PyBuffer` was expected."));
675 }
676 }
677
678 return nullptr;
679 }
680
681 // Compute signature for arguments.
682 //
683 // Returns `Status::OK()` on success. Returning an error should lead to
684 // calling the Python fallback.
ComputeSignature(bool jax_enable_x64,xla::PjRtDevice * default_device,bool is_committed,ParsedArgumentsAsBuffers & arguments)685 xla::Status ComputeSignature(bool jax_enable_x64,
686 xla::PjRtDevice* default_device, bool is_committed,
687 ParsedArgumentsAsBuffers& arguments) {
688 tensorflow::profiler::TraceMe traceme("ComputeSignature");
689
690 int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
691 // When the jitted function is not committed, we first check whether any
692 // sticky `DeviceArray` is present and on which device they live. See also:
693 // https://github.com/google/jax/pull/1884
694 // https://github.com/google/jax/pull/1916 for the rationale why the
695 // computation follows the data locality.
696 // It's also similar to PyTorch's behavior.
697 xla::PjRtDevice* data_device = nullptr;
698 if (is_committed) {
699 data_device = default_device;
700 } else {
701 for (int i = 0; i < num_flat_dynamic_args; ++i) {
702 TF_ASSIGN_OR_RETURN(
703 xla::PjRtDevice * device,
704 GetJitArgumentStickyDevice(arguments.flat_dynamic_args[i]));
705 if (device) {
706 if (data_device && (device != data_device)) {
707 throw std::invalid_argument(absl::StrCat(
708 "primitive arguments must be colocated on the same device ("
709 "C++ jax.jit). Arguments are on devices: ",
710 device->DebugString(), " and ", data_device->DebugString()));
711 } else {
712 data_device = device;
713 }
714 }
715 }
716 }
717 if (!data_device) {
718 // No `DeviceArray` were found default to `default_device`.
719 data_device = default_device;
720 }
721 CHECK(data_device);
722 arguments.signature.device = data_device;
723
724 arguments.signature.dynamic_arg_signatures.reserve(num_flat_dynamic_args);
725 for (int i = 0; i < num_flat_dynamic_args; ++i) {
726 py::handle arg = arguments.flat_dynamic_args[i];
727 TF_ASSIGN_OR_RETURN(auto sig,
728 xla::PyArgSignatureOfValue(arg, jax_enable_x64));
729 arguments.signature.dynamic_arg_signatures.push_back(std::move(sig));
730 }
731 return ::tensorflow::OkStatus();
732 }
733
734 // Copy buffers to device, skipping pruned arguments.
735 // Returns `Status::OK()` on success. Returning an error should lead to
736 // calling the Python fallback.
CopyBuffersToDevice(bool jax_enable_x64,const std::optional<std::vector<bool>> & kept_args,ParsedArgumentsAsBuffers & arguments)737 xla::Status CopyBuffersToDevice(
738 bool jax_enable_x64, const std::optional<std::vector<bool>>& kept_args,
739 ParsedArgumentsAsBuffers& arguments) {
740 std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
741 xla::PjRtDevice* data_device = arguments.signature.device;
742
743 int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
744 xla::DevicePutOptions options;
745 options.squash_64bit_types = !jax_enable_x64;
746 options.allow_zero_copy = true;
747 arg_buffers.reserve(num_flat_dynamic_args);
748 bool input_pruning_enabled = kept_args.has_value();
749 for (int i = 0; i < num_flat_dynamic_args; ++i) {
750 if (input_pruning_enabled && !kept_args.value()[i]) {
751 continue;
752 }
753
754 py::handle arg = arguments.flat_dynamic_args[i];
755 TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device,
756 DevicePut(arg, data_device, options));
757
758 xla::PjRtBuffer* buffer = on_device.buffer;
759 arg_buffers.push_back(buffer);
760 if (on_device.owned_buffer) {
761 arguments.keep_alive.push_back(std::move(on_device.owned_buffer));
762 } else if (on_device.owning_pybuffer) {
763 arguments.keep_alive_objects.push_back(
764 std::move(on_device.owning_pybuffer));
765 }
766 }
767 return ::tensorflow::OkStatus();
768 }
769
PopulateCacheEntry(CacheEntry * cache_entry,const CallSignature & signature,const py::tuple & out_and_fastpath_data)770 void CompiledFunction::PopulateCacheEntry(
771 CacheEntry* cache_entry, const CallSignature& signature,
772 const py::tuple& out_and_fastpath_data) {
773 CHECK_EQ(out_and_fastpath_data.size(), 2);
774 if (out_and_fastpath_data[1].is_none()) {
775 cache_entry->fall_back_to_python = true;
776 return;
777 }
778
779 py::tuple executable_handlers_out_tree =
780 py::cast<py::tuple>(out_and_fastpath_data[1]);
781 auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
782 executable_handlers_out_tree.attr("xla_executable"));
783 cache_entry->executable = std::move(executable);
784 int num_devices =
785 cache_entry->executable->pjrt_executable().addressable_devices().size();
786 // The presence of jit(pmap) is detected from Python.
787 CHECK_EQ(num_devices, 1);
788
789 auto out_tree = py::cast<xla::PyTreeDef>(
790 executable_handlers_out_tree.attr("out_pytree_def"));
791 cache_entry->out_pytree_def = std::move(out_tree);
792
793 cache_entry->sticky_device =
794 py::cast<std::optional<xla::ClientAndPtr<xla::PjRtDevice>>>(
795 executable_handlers_out_tree.attr("sticky_device"));
796 auto avals = py::cast<py::list>(executable_handlers_out_tree.attr("avals"));
797
798 cache_entry->out_avals.reserve(avals.size());
799 cache_entry->out_weak_types.reserve(avals.size());
800 for (int i = 0; i < avals.size(); ++i) {
801 py::object shaped_array = py::reinterpret_borrow<py::object>(avals[i]);
802
803 cache_entry->out_avals.push_back(shaped_array);
804 cache_entry->out_weak_types.push_back(
805 py::cast<bool>(shaped_array.attr("weak_type")));
806 }
807 auto kept_var_bitvec_attr =
808 py::getattr(executable_handlers_out_tree, "kept_var_bitvec", py::none());
809 if (!kept_var_bitvec_attr.is_none()) {
810 auto kept_var_bitvec = py::cast<py::list>(kept_var_bitvec_attr);
811 cache_entry->kept_var_bitvec =
812 std::make_optional<std::vector<bool>>(kept_var_bitvec.size(), false);
813 for (int i = 0; i < kept_var_bitvec.size(); ++i) {
814 cache_entry->kept_var_bitvec.value()[i] =
815 py::cast<bool>(kept_var_bitvec[i]);
816 }
817 }
818 }
819
TryToPopulateDefaultDevice()820 void CompiledFunction::TryToPopulateDefaultDevice() {
821 // The following line calls Python and may release the GIL.
822 py::object device_and_is_committed;
823 try {
824 device_and_is_committed = get_device_();
825 } catch (py::error_already_set& e) {
826 // Backend or device initialization failed. Handle this in Python.
827 always_fallback_to_python_ = true;
828 return;
829 }
830 // If the GIL was released by the call to get_device_, another thread may
831 // have filled in default_device_.
832 if (!default_device_) {
833 try {
834 auto default_pydevice = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
835 device_and_is_committed.attr("default_device"));
836 is_committed_ =
837 py::cast<bool>(device_and_is_committed.attr("committed_to_device"));
838 default_device_ = default_pydevice.contents;
839 } catch (const py::cast_error& e) {
840 // Pathways, Cloud TPU 2VM, and UPTC runtime.
841 always_fallback_to_python_ = true;
842 }
843 }
844 }
845
Call(py::handle args,std::optional<py::kwargs> kwargs)846 xla::StatusOr<py::object> CompiledFunction::Call(
847 py::handle args, std::optional<py::kwargs> kwargs) {
848 VLOG(3) << "Calling CompiledFunction " << function_name_;
849
850 // Make sure we trigger a garbage collection on JIT function calls. Otherwise
851 // code like
852 // f = jit(...)
853 // while True:
854 // f(x)
855 // may never free temporary buffers for copies of arguments.
856 xla::GlobalPyRefManager()->MaybeCollectGarbage();
857
858 auto& tls = thread_local_state;
859 if (GetDisableJit()) {
860 return fun_(*py::reinterpret_borrow<py::args>(args),
861 **kwargs.value_or(py::kwargs()));
862 }
863 if (always_fallback_to_python_) {
864 return py::object(
865 py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
866 **kwargs.value_or(py::kwargs())))[0]);
867 }
868
869 xla::PjRtDevice* device = nullptr;
870 // Whether `device` should override an input with a sticky device.
871 bool is_committed;
872 if (!has_explicit_device_ && GetDefaultDevice().has_value()) {
873 xla::ClientAndPtr<xla::PjRtDevice> pjrt_device_ptr;
874 bool cast_success = true;
875 try {
876 pjrt_device_ptr =
877 GetDefaultDevice()->cast<xla::ClientAndPtr<xla::PjRtDevice>>();
878 } catch (py::cast_error& e) {
879 // We assume GetDefaultDevice() returned a non-PJRT device object. Leave
880 // `device` unset so we fallback to Python path and handle default device
881 // there.
882 cast_success = false;
883 }
884 if (cast_success) {
885 device = pjrt_device_ptr.get();
886 is_committed = false;
887 VLOG(3) << "Using config.default_device (uncommitted): "
888 << device->DebugString();
889 }
890 }
891 if (device == nullptr) {
892 // Call back into Python to find system default device, which will be stored
893 // in default_device_.
894 if (!default_device_) {
895 // On the first call to `Call`, compute a default device. We need to wait
896 // until after platform initialization is complete before doing so, but
897 // @jit may be used as a decorator.
898 TryToPopulateDefaultDevice();
899 if (!default_device_) {
900 return py::object(py::cast<py::tuple>(
901 cache_miss_(*py::reinterpret_borrow<py::args>(args),
902 **kwargs.value_or(py::kwargs())))[0]);
903 }
904 }
905 device = default_device_;
906 is_committed = is_committed_;
907 VLOG(3) << "Using device from Python): " << device->DebugString()
908 << ", committed: " << is_committed;
909 }
910 CHECK(device != nullptr);
911
912 ParsedArgumentsAsBuffers arguments;
913 arguments.signature.function_name = function_name_;
914 xla::Status status = ParseArguments(args, kwargs, static_argnums_,
915 static_argnames_, arguments);
916 if (!status.ok()) {
917 VLOG(2) << "ParseArguments failed: " << status;
918 return py::object(
919 py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
920 **kwargs.value_or(py::kwargs())))[0]);
921 }
922
923 bool jax_enable_x64 = GetEnableX64();
924 arguments.signature.jax_enable_x64 = jax_enable_x64;
925 // The C++ jit do not support Tracers arguments inputs yet. The Python-based
926 // jit function will be called if any of the dynamic arguments is unsupported.
927 status = ComputeSignature(jax_enable_x64, device, is_committed, arguments);
928 if (!status.ok()) {
929 VLOG(2) << "ComputeSignature failed: " << status;
930 return py::object(
931 py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
932 **kwargs.value_or(py::kwargs())))[0]);
933 }
934 arguments.signature.global_extra_jit_context = global_state.extra_jit_context;
935 arguments.signature.thread_local_extra_jit_context = tls.extra_jit_context;
936
937 VLOG(3) << "CallSignature:\n" << arguments.signature.DebugString();
938 bool inserted = false;
939 std::shared_ptr<CacheEntry> cache_entry = executables_->GetOrCreateIfAbsent(
940 arguments.signature, [&inserted](const CallSignature& key) {
941 inserted = true;
942 return std::make_shared<CacheEntry>();
943 });
944
945 if (!cache_entry->compilation_complete.HasBeenNotified()) {
946 // In case of several threads attempting to compile the executable, only
947 // the one that inserted the item will perform the compilation.
948 if (inserted) {
949 py::object out_and_fastpath_data;
950 py::tuple out_tuple;
951 VLOG(2) << "Cache miss for\n" << arguments.signature.DebugString();
952 try {
953 // Calls Python and may release the GIL. May also throw if
954 // compilation/tracing fails.
955 out_and_fastpath_data =
956 cache_miss_(*py::reinterpret_borrow<py::args>(args),
957 **kwargs.value_or(py::kwargs()));
958 out_tuple = py::cast<py::tuple>(out_and_fastpath_data);
959 PopulateCacheEntry(cache_entry.get(), arguments.signature, out_tuple);
960 } catch (const std::exception& e) {
961 cache_entry->fall_back_to_python = true;
962 cache_entry->compilation_complete.Notify();
963 throw;
964 }
965 cache_entry->compilation_complete.Notify();
966
967 // We have already computed the result in the miss path so we can return
968 // it. We are even *required* to do so if there are donated arguments,
969 // because any donated buffers will now be invalid.
970 return py::object(out_tuple[0]);
971 } else {
972 // Release the GIL while we wait, making sure the compile thread can
973 // lock it.
974 py::gil_scoped_release release;
975 cache_entry->compilation_complete.WaitForNotification();
976 }
977 }
978 // It's hard to reraise the exact same kind of errors when a compilation error
979 // occurred. If the first compilation failed, other threads will also execute
980 // the Python path.
981 if (cache_entry->fall_back_to_python) {
982 return py::object(
983 py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
984 **kwargs.value_or(py::kwargs())))[0]);
985 }
986
987 status = CopyBuffersToDevice(jax_enable_x64, cache_entry->kept_var_bitvec,
988 arguments);
989 if (!status.ok()) {
990 VLOG(2) << "CopyBuffersToDevice failed: " << status;
991 return py::object(
992 py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
993 **kwargs.value_or(py::kwargs())))[0]);
994 }
995
996 // Executes the computation.
997 std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> output_buffers;
998 {
999 py::gil_scoped_release gil_release;
1000 TF_ASSIGN_OR_RETURN(
1001 output_buffers,
1002 cache_entry->executable->mutable_pjrt_executable()->Execute(
1003 {arguments.arg_buffers}, cache_entry->executable->options()));
1004 }
1005 auto traceback = xla::Traceback::Get();
1006
1007 int num_outputs = output_buffers[0].size();
1008 absl::InlinedVector<py::object, 1> flat_device_arrays;
1009 flat_device_arrays.reserve(num_outputs);
1010 for (int i = 0; i < output_buffers[0].size(); ++i) {
1011 bool last = (i == (num_outputs - 1));
1012 xla::PyBuffer::object buffer = xla::PyBuffer::Make(
1013 cache_entry->executable->client(), std::move(output_buffers[0][i]),
1014 last ? std::move(traceback) : traceback);
1015 buffer.buf()->SetAval(cache_entry->out_avals[i]);
1016 buffer.buf()->set_weak_type(cache_entry->out_weak_types[i]);
1017 if (cache_entry->sticky_device.has_value()) {
1018 TF_RETURN_IF_ERROR(
1019 buffer.buf()->set_sticky_device((*cache_entry->sticky_device).get()));
1020 }
1021 flat_device_arrays.push_back(std::move(buffer));
1022 }
1023 py::object out = cache_entry->out_pytree_def.Unflatten(flat_device_arrays);
1024
1025 // If there is a post-hook function, call it with the inputs and the outputs.
1026 std::optional<py::object> post_hook = GetPostHook();
1027 if (post_hook) {
1028 (*post_hook)(AsPyHandle(), args,
1029 py::cast<py::dict>(kwargs.value_or(py::kwargs())), out);
1030 }
1031 return std::move(out);
1032 }
1033
// In-memory layout of a Python `CompiledFunction` instance. The embedded
// `fun` member is constructed separately with placement new (see
// InitializeCompiledFunction) and destroyed explicitly in tp_dealloc.
struct JaxCompiledFunctionObject {
  PyObject_HEAD;
  PyObject* dict;      // Dictionary for __dict__
  PyObject* weakrefs;  // Weak references; for use by the Python interpreter.
  CompiledFunction fun;  // The C++ dispatch state wrapped by this object.
};
1040
// The Python type object of CompiledFunction instances. Null until
// BuildJaxjitSubmodule() creates the heap type and stores it here.
PyObject* JaxCompiledFunction_Type = nullptr;
1042
IsCompiledFunction(py::handle handle)1043 bool CompiledFunction::IsCompiledFunction(py::handle handle) {
1044 return handle.get_type() == JaxCompiledFunction_Type;
1045 }
1046
AsCompiledFunctionUnchecked(py::handle handle)1047 CompiledFunction* CompiledFunction::AsCompiledFunctionUnchecked(
1048 py::handle handle) {
1049 return &(reinterpret_cast<JaxCompiledFunctionObject*>(handle.ptr())->fun);
1050 }
1051
AsCompiledFunction(py::handle handle)1052 xla::StatusOr<CompiledFunction*> AsCompiledFunction(py::handle handle) {
1053 if (!CompiledFunction::IsCompiledFunction(handle)) {
1054 return xla::InvalidArgument("Expected a CompiledFunction");
1055 }
1056 return CompiledFunction::AsCompiledFunctionUnchecked(handle);
1057 }
1058
AsPyHandle()1059 py::handle CompiledFunction::AsPyHandle() {
1060 return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
1061 offsetof(JaxCompiledFunctionObject, fun));
1062 }
1063
1064 extern "C" {
1065
JaxCompiledFunction_tp_new(PyTypeObject * subtype,PyObject * args,PyObject * kwds)1066 PyObject* JaxCompiledFunction_tp_new(PyTypeObject* subtype, PyObject* args,
1067 PyObject* kwds) {
1068 JaxCompiledFunctionObject* self =
1069 reinterpret_cast<JaxCompiledFunctionObject*>(
1070 subtype->tp_alloc(subtype, 0));
1071 if (!self) return nullptr;
1072 self->dict = nullptr;
1073 self->weakrefs = nullptr;
1074 return reinterpret_cast<PyObject*>(self);
1075 }
1076
JaxCompiledFunction_tp_dealloc(PyObject * self)1077 void JaxCompiledFunction_tp_dealloc(PyObject* self) {
1078 PyTypeObject* tp = Py_TYPE(self);
1079 JaxCompiledFunctionObject* o =
1080 reinterpret_cast<JaxCompiledFunctionObject*>(self);
1081 if (o->weakrefs) {
1082 PyObject_ClearWeakRefs(self);
1083 }
1084 Py_CLEAR(o->dict);
1085 o->fun.~CompiledFunction();
1086 tp->tp_free(self);
1087 Py_DECREF(tp);
1088 }
1089
JaxCompiledFunction_tp_traverse(PyObject * self,visitproc visit,void * arg)1090 int JaxCompiledFunction_tp_traverse(PyObject* self, visitproc visit,
1091 void* arg) {
1092 JaxCompiledFunctionObject* o =
1093 reinterpret_cast<JaxCompiledFunctionObject*>(self);
1094 Py_VISIT(o->dict);
1095 Py_VISIT(o->fun.fun().ptr());
1096 Py_VISIT(o->fun.cache_miss().ptr());
1097 Py_VISIT(o->fun.get_device().ptr());
1098 return 0;
1099 }
1100
JaxCompiledFunction_tp_clear(PyObject * self)1101 int JaxCompiledFunction_tp_clear(PyObject* self) {
1102 JaxCompiledFunctionObject* o =
1103 reinterpret_cast<JaxCompiledFunctionObject*>(self);
1104 Py_CLEAR(o->dict);
1105 o->fun.ClearPythonReferences();
1106 return 0;
1107 }
1108
1109 // Implements the Python descriptor protocol so JIT-compiled functions can be
1110 // used as bound methods. See:
1111 // https://docs.python.org/3/howto/descriptor.html#functions-and-methods
JaxCompiledFunction_tp_descr_get(PyObject * self,PyObject * obj,PyObject * type)1112 PyObject* JaxCompiledFunction_tp_descr_get(PyObject* self, PyObject* obj,
1113 PyObject* type) {
1114 if (obj == nullptr || obj == Py_None) {
1115 Py_INCREF(self);
1116 return self;
1117 }
1118 return PyMethod_New(self, obj);
1119 }
1120
1121 // Support d = instance.__dict__.
JaxCompiledFunction_get_dict(PyObject * self,void *)1122 PyObject* JaxCompiledFunction_get_dict(PyObject* self, void*) {
1123 JaxCompiledFunctionObject* o =
1124 reinterpret_cast<JaxCompiledFunctionObject*>(self);
1125 if (!o->dict) {
1126 o->dict = PyDict_New();
1127 }
1128 Py_XINCREF(o->dict);
1129 return o->dict;
1130 }
1131
JaxCompiledFunction_set_dict(PyObject * self,PyObject * new_dict,void *)1132 int JaxCompiledFunction_set_dict(PyObject* self, PyObject* new_dict, void*) {
1133 JaxCompiledFunctionObject* o =
1134 reinterpret_cast<JaxCompiledFunctionObject*>(self);
1135 if (!PyDict_Check(new_dict)) {
1136 PyErr_Format(PyExc_TypeError,
1137 "__dict__ must be set to a dictionary, not a '%s'",
1138 Py_TYPE(new_dict)->tp_name);
1139 return -1;
1140 }
1141 Py_INCREF(new_dict);
1142 Py_CLEAR(o->dict);
1143 o->dict = new_dict;
1144 return 0;
1145 }
1146
// Attribute table for the CompiledFunction type; terminated by an all-null
// sentinel entry.
static PyGetSetDef JaxCompiledFunction_tp_getset[] = {
    // Having a __dict__ seems necessary to allow functools.wraps to override
    // __doc__.
    {const_cast<char*>("__dict__"), JaxCompiledFunction_get_dict,
     JaxCompiledFunction_set_dict, nullptr, nullptr},
    {nullptr, nullptr, nullptr, nullptr, nullptr}};
1153
JaxCompiledFunction_tp_call(PyObject * self,PyObject * args,PyObject * kwargs)1154 PyObject* JaxCompiledFunction_tp_call(PyObject* self, PyObject* args,
1155 PyObject* kwargs) {
1156 JaxCompiledFunctionObject* o =
1157 reinterpret_cast<JaxCompiledFunctionObject*>(self);
1158 tensorflow::profiler::TraceMe traceme([&] {
1159 return absl::StrCat("JaxCompiledFunction(", o->fun.function_name(), ")");
1160 });
1161 std::optional<py::kwargs> py_kwargs;
1162 if (kwargs) {
1163 py_kwargs = py::reinterpret_borrow<py::kwargs>(kwargs);
1164 }
1165 try {
1166 xla::StatusOr<py::object> out = o->fun.Call(args, std::move(py_kwargs));
1167 if (!out.ok()) {
1168 PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str());
1169 return nullptr;
1170 }
1171 return out.ValueOrDie().release().ptr();
1172 } catch (py::error_already_set& e) {
1173 e.restore();
1174 return nullptr;
1175 } catch (py::cast_error& e) {
1176 PyErr_SetString(PyExc_ValueError, e.what());
1177 return nullptr;
1178 } catch (std::invalid_argument& e) {
1179 PyErr_SetString(PyExc_ValueError, e.what());
1180 return nullptr;
1181 } catch (std::runtime_error& e) {
1182 PyErr_SetString(PyExc_ValueError, e.what());
1183 return nullptr;
1184 }
1185 }
1186
JaxCompiledFunction_tp_repr(PyObject * self)1187 PyObject* JaxCompiledFunction_tp_repr(PyObject* self) {
1188 try {
1189 const std::string& repr = absl::StrFormat(
1190 "<CompiledFunction of %s>",
1191 static_cast<std::string>(py::repr(py::getattr(self, "__wrapped__"))));
1192 return PyUnicode_FromString(repr.c_str());
1193 } catch (...) {
1194 // Ignore all errors when accessing a repr.
1195 return PyUnicode_FromString("<CompiledFunction>");
1196 }
1197 }
1198
InitializeCompiledFunction(JaxCompiledFunctionObject * cfun,py::function fun,py::function cache_miss,py::function get_device,bool has_explicit_device,std::vector<int> static_argnums,std::vector<py::str> static_argnames,std::vector<int> donate_argnums,std::shared_ptr<CompiledFunctionCache> cache)1199 void InitializeCompiledFunction(JaxCompiledFunctionObject* cfun,
1200 py::function fun, py::function cache_miss,
1201 py::function get_device,
1202 bool has_explicit_device,
1203 std::vector<int> static_argnums,
1204 std::vector<py::str> static_argnames,
1205 std::vector<int> donate_argnums,
1206 std::shared_ptr<CompiledFunctionCache> cache) {
1207 new (&cfun->fun) CompiledFunction(
1208 std::move(fun), std::move(cache_miss), std::move(get_device),
1209 has_explicit_device, std::move(static_argnums),
1210 std::move(static_argnames), std::move(donate_argnums), std::move(cache));
1211 }
1212
1213 } // extern "C"
1214
MakeCompiledFunction(py::function fun,py::function cache_miss,py::function get_device,bool has_explicit_device,std::vector<int> static_argnums,std::vector<py::str> static_argnames,std::vector<int> donate_argnums,std::shared_ptr<CompiledFunctionCache> cache)1215 py::object MakeCompiledFunction(py::function fun, py::function cache_miss,
1216 py::function get_device,
1217 bool has_explicit_device,
1218 std::vector<int> static_argnums,
1219 std::vector<py::str> static_argnames,
1220 std::vector<int> donate_argnums,
1221 std::shared_ptr<CompiledFunctionCache> cache) {
1222 py::object obj = py::reinterpret_steal<py::object>(JaxCompiledFunction_tp_new(
1223 reinterpret_cast<PyTypeObject*>(JaxCompiledFunction_Type), nullptr,
1224 nullptr));
1225 JaxCompiledFunctionObject* buf =
1226 reinterpret_cast<JaxCompiledFunctionObject*>(obj.ptr());
1227 if (!cache) {
1228 cache = std::make_shared<CompiledFunctionCache>(
1229 CompiledFunctionCache::kDefaultCapacity);
1230 }
1231 InitializeCompiledFunction(
1232 buf, std::move(fun), std::move(cache_miss), std::move(get_device),
1233 has_explicit_device, std::move(static_argnums),
1234 std::move(static_argnames), std::move(donate_argnums), std::move(cache));
1235 return obj;
1236 }
1237
// Version numbers for the pickled representations of
// CompiledFunction/CompiledFunctionCache. Increment these if changing them.
// Each is written under the "version" key by the corresponding __getstate__
// and checked by __setstate__, which rejects any other value.
const int kCompiledFunctionCachePickleVersion = 1;
const int kCompiledFunctionPickleVersion = 1;
1242
1243 } // namespace
1244
BuildJaxjitSubmodule(py::module & m)1245 void BuildJaxjitSubmodule(py::module& m) {
1246 py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library");
1247
1248 py::class_<CompiledFunctionCache, std::shared_ptr<CompiledFunctionCache>>
1249 cache(jitlib, "CompiledFunctionCache");
1250 cache.def(py::init<int>(),
1251 py::arg("capacity") = CompiledFunctionCache::kDefaultCapacity);
1252 cache.def("size", &CompiledFunctionCache::Size);
1253 cache.def("capacity", &CompiledFunctionCache::Capacity);
1254 cache.def("clear", &CompiledFunctionCache::Clear);
1255 cache.def_static("clear_all", []() {
1256 GetGlobalCompiledFunctionStore().ClearFunctionCache();
1257 });
1258 cache.def(py::pickle(
1259 // __getstate__
1260 // Pickles as an empty cache; the client can repopulate as needed.
1261 [](const CompiledFunctionCache& cache) {
1262 py::dict pickle;
1263 pickle["version"] = kCompiledFunctionCachePickleVersion;
1264 pickle["capacity"] = cache.Capacity();
1265 return pickle;
1266 },
1267 // __setstate__
1268 [](const py::dict& pickle) {
1269 int version = py::cast<int>(pickle["version"]);
1270 if (version != kCompiledFunctionCachePickleVersion) {
1271 throw std::invalid_argument(absl::StrFormat(
1272 "Invalid CompiledFunction pickle version, got %d, expected %d",
1273 version, kCompiledFunctionCachePickleVersion));
1274 }
1275 int capacity = py::cast<int>(pickle["capacity"]);
1276 return std::make_shared<CompiledFunctionCache>(capacity);
1277 }));
1278
1279 // We need to use heap-allocated type objects because we want to add
1280 // additional methods dynamically.
1281 py::object cfun;
1282 {
1283 py::str name = py::str("CompiledFunction");
1284 py::str qualname = py::str("CompiledFunction");
1285 PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
1286 PyType_Type.tp_alloc(&PyType_Type, 0));
1287 // Caution: we must not call any functions that might invoke the GC until
1288 // PyType_Ready() is called. Otherwise the GC might see a half-constructed
1289 // type object.
1290 CHECK(heap_type) << "Unable to create heap type object";
1291 heap_type->ht_name = name.release().ptr();
1292 heap_type->ht_qualname = qualname.release().ptr();
1293 PyTypeObject* type = &heap_type->ht_type;
1294 type->tp_name = "CompiledFunction";
1295 type->tp_basicsize = sizeof(JaxCompiledFunctionObject);
1296 type->tp_flags =
1297 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC;
1298 type->tp_new = JaxCompiledFunction_tp_new;
1299 type->tp_dealloc = JaxCompiledFunction_tp_dealloc;
1300 type->tp_dictoffset = offsetof(JaxCompiledFunctionObject, dict);
1301 type->tp_traverse = JaxCompiledFunction_tp_traverse;
1302 type->tp_clear = JaxCompiledFunction_tp_clear;
1303 type->tp_weaklistoffset = offsetof(JaxCompiledFunctionObject, weakrefs);
1304 type->tp_getset = JaxCompiledFunction_tp_getset;
1305 type->tp_descr_get = JaxCompiledFunction_tp_descr_get;
1306 type->tp_call = JaxCompiledFunction_tp_call;
1307 type->tp_repr = JaxCompiledFunction_tp_repr;
1308 CHECK_EQ(PyType_Ready(type), 0);
1309 JaxCompiledFunction_Type = reinterpret_cast<PyObject*>(type);
1310 cfun = py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
1311 }
1312 py::object cfun_type =
1313 py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
1314
1315 // Add CompiledFunction to the xla_extension module so it can be pickled.
1316 m.attr("CompiledFunction") = cfun_type;
1317 cfun.attr("__module__") = m.attr("__name__");
1318
1319 cfun.attr("__signature__") =
1320 property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
1321 TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1322 return fun->PythonSignature();
1323 });
1324 cfun.attr("_cache_miss") =
1325 property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
1326 TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1327 return fun->cache_miss();
1328 });
1329 cfun.attr("__getstate__") = py::cpp_function(
1330 [](const CompiledFunction::object& self) {
1331 CompiledFunction* fn = self.func();
1332 py::dict pickle;
1333 pickle["version"] = kCompiledFunctionPickleVersion;
1334 pickle["fun"] = fn->fun();
1335 pickle["cache_miss"] = fn->cache_miss();
1336 pickle["get_device"] = fn->get_device();
1337 pickle["has_explicit_device"] = fn->has_explicit_device();
1338 pickle["static_argnums"] = fn->static_argnums();
1339 pickle["static_argnames"] = fn->static_argnames();
1340 pickle["donate_argnums"] = fn->donate_argnums();
1341 pickle["cache"] = fn->cache();
1342 return pickle;
1343 },
1344 py::is_method(cfun_type));
1345 cfun.attr("__setstate__") = py::cpp_function(
1346 [](CompiledFunction::object& self, const py::dict& pickle) {
1347 int version = py::cast<int>(pickle["version"]);
1348 if (version != kCompiledFunctionPickleVersion) {
1349 throw std::invalid_argument(absl::StrFormat(
1350 "Invalid CompiledFunction pickle version, got %d, expected %d. "
1351 "Pickling/Unpickling jitted functions using different JAX "
1352 "versions is not supported.",
1353 version, kCompiledFunctionPickleVersion));
1354 }
1355 py::function fun = py::cast<py::function>(pickle["fun"]);
1356 py::function cache_miss = py::cast<py::function>(pickle["cache_miss"]);
1357 py::function get_device = py::cast<py::function>(pickle["get_device"]);
1358 bool has_explicit_device =
1359 py::cast<bool>(pickle["has_explicit_device"]);
1360 std::vector<int> static_argnums =
1361 py::cast<std::vector<int>>(pickle["static_argnums"]);
1362 std::vector<py::str> static_argnames =
1363 py::cast<std::vector<py::str>>(pickle["static_argnames"]);
1364 std::vector<int> donate_argnums =
1365 py::cast<std::vector<int>>(pickle["donate_argnums"]);
1366 std::shared_ptr<CompiledFunctionCache> cache =
1367 py::cast<std::shared_ptr<CompiledFunctionCache>>(pickle["cache"]);
1368 InitializeCompiledFunction(
1369 reinterpret_cast<JaxCompiledFunctionObject*>(self.ptr()),
1370 std::move(fun), std::move(cache_miss), std::move(get_device),
1371 has_explicit_device, std::move(static_argnums),
1372 std::move(static_argnames), std::move(donate_argnums),
1373 std::move(cache));
1374 },
1375 py::is_method(cfun_type));
1376
1377 py::class_<JitState> jit_state_(jitlib, "JitState");
1378 jit_state_.def_readwrite("disable_jit", &JitState::disable_jit);
1379 jit_state_.def_readwrite("enable_x64", &JitState::enable_x64);
1380 jit_state_.def_readwrite("default_device", &JitState::default_device);
1381 jit_state_.def_readwrite("extra_jit_context", &JitState::extra_jit_context);
1382 jit_state_.def_readwrite("post_hook", &JitState::post_hook);
1383
1384 jitlib.def(
1385 "global_state", [&]() { return &global_state; },
1386 py::return_value_policy::reference);
1387 jitlib.def(
1388 "thread_local_state", [&]() { return &thread_local_state; },
1389 py::return_value_policy::reference);
1390
1391 jitlib.def("jit_is_disabled", &GetDisableJit);
1392 jitlib.def("get_enable_x64", &GetEnableX64);
1393 jitlib.def("set_thread_local_state_initialization_callback",
1394 [](py::object f) { initialize_local_state = f; });
1395
1396 jitlib.def(
1397 "jit",
1398 [](py::function fun, py::function cache_miss, py::function get_device,
1399 std::vector<int> static_argnums, std::vector<py::str> static_argnames,
1400 std::vector<int> donate_argnums, bool has_explicit_device,
1401 std::shared_ptr<CompiledFunctionCache> cache) -> py::object {
1402 return MakeCompiledFunction(
1403 std::move(fun), std::move(cache_miss), std::move(get_device),
1404 has_explicit_device, std::move(static_argnums),
1405 std::move(static_argnames), std::move(donate_argnums),
1406 std::move(cache));
1407 },
1408 py::arg("fun"), py::arg("cache_miss"), py::arg("get_device"),
1409 py::arg("static_argnums"),
1410 py::arg("static_argnames") = std::vector<py::str>(),
1411 py::arg("donate_argnums") = std::vector<int>(),
1412 py::arg("has_explicit_device") = false, py::arg("cache") = nullptr);
1413
1414 // This function is not yet a full replacement for the Python one, because:
1415 // (a) it does not support abstract types,
1416 // (b) it does not set the device stickiness yet.
1417 // TODO(jblespiau): Finish the replacement of the Python feature.
1418 jitlib.def("device_put",
1419 [](py::handle obj, bool jax_enable_x64,
1420 xla::ClientAndPtr<xla::PjRtDevice> to_device)
1421 -> xla::StatusOr<py::object> {
1422 std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
1423 xla::DevicePutOptions options;
1424 options.squash_64bit_types = !jax_enable_x64;
1425 options.allow_zero_copy = true;
1426 xla::StatusOr<xla::DevicePutResult> results =
1427 DevicePut(obj, to_device.contents, options);
1428 if (!results.ok()) {
1429 throw xla::XlaRuntimeError(results.status().error_message());
1430 }
1431 if (results->owned_buffer) {
1432 auto buffer = xla::PyBuffer::Make(
1433 pyclient, std::move(results->owned_buffer),
1434 xla::Traceback::Get());
1435
1436 static const auto* jax_core =
1437 new py::module(py::module::import("jax.core"));
1438 static const auto* shaped_array =
1439 new py::handle(jax_core->attr("ShapedArray"));
1440 buffer.buf()->SetAval((*shaped_array)(
1441 buffer.buf()->python_shape(), buffer.buf()->python_dtype(),
1442 results->weak_type));
1443 TF_RETURN_IF_ERROR(buffer.buf()->set_sticky_device(nullptr));
1444
1445 return std::move(buffer);
1446 } else {
1447 return py::cast<py::object>(obj);
1448 }
1449 });
1450
1451 py::class_<xla::PyArgSignature> arg_signature(jitlib, "PyArgSignature");
1452 arg_signature
1453 .def_property_readonly("dtype",
1454 [](const xla::PyArgSignature& sig) {
1455 return PrimitiveTypeToDtype(sig.dtype);
1456 })
1457 .def_property_readonly(
1458 "shape",
1459 [](const xla::PyArgSignature& sig) {
1460 return xla::SpanToTuple(absl::MakeConstSpan(sig.shape));
1461 })
1462 .def_readonly("weak_type", &xla::PyArgSignature::weak_type);
1463 jitlib.def("_ArgSignatureOfValue", &xla::PyArgSignatureOfValue);
1464
1465 // All private members are only for testing/debugging purposes
1466 cfun.attr("_cache_size") = py::cpp_function(
1467 [](py::handle self) -> xla::StatusOr<int> {
1468 TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1469 return fun->cache_size();
1470 },
1471 py::is_method(cfun));
1472 cfun.attr("_clear_cache") = py::cpp_function(
1473 [](py::handle self) -> xla::Status {
1474 TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1475 fun->ClearCache();
1476 return ::tensorflow::OkStatus();
1477 },
1478 py::is_method(cfun));
1479 jitlib.def("_is_float0", &xla::IsFloat0);
1480 }
1481
1482 } // namespace jax
1483