xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/jax_jit.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
18 
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/py_values.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/pytree.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
32 
33 namespace jax {
34 
35 // Flags, such as JIT disable and the x64 mode, are controlled by:
36 // - a global flag value, e.g., associated to --jax_enable_x64
37 // - possibly a thread-local value, which initially is std::nullopt and
38 //   overrides the global value if set. The thread-local state is
39 //   used to implement context managers that locally override the global state.
40 struct JitState {
~JitStateJitState41   ~JitState() {
42     if (extra_jit_context) {
43       // We likely do not hold the GIL if this JitState is thread-local, so we
44       // hand the Python object to the global reference manager to destroy.
45       pybind11::object o = std::move(*extra_jit_context);
46       xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1));
47       extra_jit_context = std::nullopt;
48     }
49   }
50 
51   std::optional<bool> disable_jit;
52   std::optional<bool> enable_x64;
53 
54   // Used to manually set the default device jax should use. May be unset even
55   // in global state, indicating there is no manual override.
56   // TODO(skyewm): make this a C++ type when all JAX backends support a single
57   // C++ device interface
58   std::optional<pybind11::object> default_device;
59 
60   // Extra context that should be included in the JIT cache key. Must be
61   // hashable and have an equality defined.
62   std::optional<pybind11::object> extra_jit_context;
63 
64   // A callback that, if present, is called when a JITted function is executed
65   // from cache. May be unset even in global state.
66   std::optional<pybind11::function> post_hook;
67 };
68 
69 JitState& GetGlobalState();
70 JitState& GetLocalState();
71 
72 // Getters for JitState fields that first look in thread-local state, then
73 // fallback to global state.
74 bool GetDisableJit();
75 bool GetEnableX64();
76 // TODO(skyewm): return a C++ type when all JAX backends support a single C++
77 // device interface
78 std::optional<pybind11::object> GetDefaultDevice();
79 std::optional<pybind11::function> GetPostHook();
80 
81 // The signature of Python jitted function call, partitioned into:
82 // - dynamic positional arguments (i.e. positional args which are not static)
83 // - static positional arguments (i.e. the args associated to static_argnums)
84 // - keyword arguments
85 // The CallSignature should unambiguously identify a function call, thus,
86 // equality is based on:
87 // (a) Same PyTree for all dynamic positional arguments and keyword arguments
88 // (a) equality of the arguments and keyword arguments ArgSignature
89 // (a) equality (delegated to Python) of the static arguments.
90 struct CallSignature {
91   // Not part of the signature, but we need it for error messages.
92   absl::string_view function_name;
93 
94   // A PyTreeDef for each dynamic argument, positional arguments first
95   // followed by keyword arguments. Keyword arguments are in the order given
96   // by dynamic_arg_names.
97   absl::InlinedVector<xla::PyTreeDef, 2> dynamic_arg_treedefs;
98   // Dynamic keyword argument names. Interned, and sorted by the keyword
99   // name.
100   std::vector<pybind11::object> dynamic_arg_names;
101   // Shape and dtype for both the dynamic positional arguments and the keyword
102   // arguments (sorted by keyword name).
103   absl::InlinedVector<xla::PyArgSignature, 2> dynamic_arg_signatures;
104 
105   // Static arguments. Contains the positional arguments sorted in argument
106   // order, followed by static keyword arguments in the order given by
107   // `static_arg_names`.
108   std::vector<pybind11::object> static_args;
109   // Static keyword argument names. Interned, and sorted by keyword name.
110   std::vector<pybind11::object> static_arg_names;
111 
112   // For JIT, we need this in the key because computation follows the data, so
113   // we may have multiple executables depending on the devices the data is on.
114   // This is not the case for PMAP, and is set to `nullptr`.
115   xla::PjRtDevice* device = nullptr;
116   bool jax_enable_x64;
117 
118   // Opaque additional context that should be included as part of the cache key.
119   std::optional<pybind11::object> global_extra_jit_context;
120   std::optional<pybind11::object> thread_local_extra_jit_context;
121 
122   bool operator==(const CallSignature& other) const;
123   bool operator!=(const CallSignature& other) const {
124     return !(*this == other);
125   }
126 
127   std::string DebugString() const;
128 };
129 
130 template <typename H>
131 H AbslHashValue(H h, const CallSignature& s);
132 
133 // The resulting information of the parsing and conversion of the arguments.
134 struct ParsedArgumentsAsBuffers {
135   // The call signature will be filled during 2 steps:
136   // - `ParseArguments` will fill the static arguments and the pytree
137   //    structures
138   // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`.
139   CallSignature signature;
140   // The concatenation of the dynamic positional arguments and the sorted
141   // keyword arguments.
142   absl::InlinedVector<pybind11::object, 2> flat_dynamic_args;
143   std::vector<pybind11::object> keep_alive_objects;
144 
145   // The following is only valid if the parsing succeeds.
146   std::vector<xla::PjRtBuffer*> arg_buffers;
147   // We may need to keep these objects around, because:
148   // (a) we need to extend the lifetime of objects created within
149   //    `CopyBuffersToDevice`
150   // (b) `arg_buffers` do not maintain ownership
151   std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive;
152 };
153 
154 // Filter out static arguments, flatten and concatenate other arguments (i.e.
155 // dynamic positional and keyword arguments), filling `arguments` in place.
156 xla::Status ParseArguments(pybind11::handle args,
157                            const std::optional<pybind11::kwargs>& py_kwargs,
158                            absl::Span<int const> static_argnums,
159                            absl::Span<pybind11::str const> static_argnames,
160                            ParsedArgumentsAsBuffers& arguments);
161 
162 // The function to call in `xla.cc` to add the bindings for this module.
163 void BuildJaxjitSubmodule(pybind11::module& m);
164 
165 }  // namespace jax
166 
167 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
168