1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ 18 19 #include <string> 20 21 #include "absl/container/inlined_vector.h" 22 #include "absl/strings/str_cat.h" 23 #include "absl/strings/str_join.h" 24 #include "pybind11/pybind11.h" 25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" 26 #include "tensorflow/compiler/xla/python/py_client.h" 27 #include "tensorflow/compiler/xla/python/py_values.h" 28 #include "tensorflow/compiler/xla/python/python_ref_manager.h" 29 #include "tensorflow/compiler/xla/python/pytree.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 33 namespace jax { 34 35 // Flags, such as JIT disable and the x64 mode, are controlled by: 36 // - a global flag value, e.g., associated to --jax_enable_x64 37 // - possibly a thread-local value, which initially is std::nullopt and 38 // overrides the global value if set. The thread-local state is 39 // used to implement context managers that locally override the global state. 40 struct JitState { ~JitStateJitState41 ~JitState() { 42 if (extra_jit_context) { 43 // We likely do not hold the GIL if this JitState is thread-local, so we 44 // hand the Python object to the global reference manager to destroy. 45 pybind11::object o = std::move(*extra_jit_context); 46 xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1)); 47 extra_jit_context = std::nullopt; 48 } 49 } 50 51 std::optional<bool> disable_jit; 52 std::optional<bool> enable_x64; 53 54 // Used to manually set the default device jax should use. May be unset even 55 // in global state, indicating there is no manual override. 56 // TODO(skyewm): make this a C++ type when all JAX backends support a single 57 // C++ device interface 58 std::optional<pybind11::object> default_device; 59 60 // Extra context that should be included in the JIT cache key. Must be 61 // hashable and have an equality defined. 62 std::optional<pybind11::object> extra_jit_context; 63 64 // A callback that, if present, is called when a JITted function is executed 65 // from cache. May be unset even in global state. 66 std::optional<pybind11::function> post_hook; 67 }; 68 69 JitState& GetGlobalState(); 70 JitState& GetLocalState(); 71 72 // Getters for JitState fields that first look in thread-local state, then 73 // fallback to global state. 74 bool GetDisableJit(); 75 bool GetEnableX64(); 76 // TODO(skyewm): return a C++ type when all JAX backends support a single C++ 77 // device interface 78 std::optional<pybind11::object> GetDefaultDevice(); 79 std::optional<pybind11::function> GetPostHook(); 80 81 // The signature of Python jitted function call, partitioned into: 82 // - dynamic positional arguments (i.e. positional args which are not static) 83 // - static positional arguments (i.e. the args associated to static_argnums) 84 // - keyword arguments 85 // The CallSignature should unambiguously identify a function call, thus, 86 // equality is based on: 87 // (a) Same PyTree for all dynamic positional arguments and keyword arguments 88 // (a) equality of the arguments and keyword arguments ArgSignature 89 // (a) equality (delegated to Python) of the static arguments. 90 struct CallSignature { 91 // Not part of the signature, but we need it for error messages. 92 absl::string_view function_name; 93 94 // A PyTreeDef for each dynamic argument, positional arguments first 95 // followed by keyword arguments. Keyword arguments are in the order given 96 // by dynamic_arg_names. 97 absl::InlinedVector<xla::PyTreeDef, 2> dynamic_arg_treedefs; 98 // Dynamic keyword argument names. Interned, and sorted by the keyword 99 // name. 100 std::vector<pybind11::object> dynamic_arg_names; 101 // Shape and dtype for both the dynamic positional arguments and the keyword 102 // arguments (sorted by keyword name). 103 absl::InlinedVector<xla::PyArgSignature, 2> dynamic_arg_signatures; 104 105 // Static arguments. Contains the positional arguments sorted in argument 106 // order, followed by static keyword arguments in the order given by 107 // `static_arg_names`. 108 std::vector<pybind11::object> static_args; 109 // Static keyword argument names. Interned, and sorted by keyword name. 110 std::vector<pybind11::object> static_arg_names; 111 112 // For JIT, we need this in the key because computation follows the data, so 113 // we may have multiple executables depending on the devices the data is on. 114 // This is not the case for PMAP, and is set to `nullptr`. 115 xla::PjRtDevice* device = nullptr; 116 bool jax_enable_x64; 117 118 // Opaque additional context that should be included as part of the cache key. 119 std::optional<pybind11::object> global_extra_jit_context; 120 std::optional<pybind11::object> thread_local_extra_jit_context; 121 122 bool operator==(const CallSignature& other) const; 123 bool operator!=(const CallSignature& other) const { 124 return !(*this == other); 125 } 126 127 std::string DebugString() const; 128 }; 129 130 template <typename H> 131 H AbslHashValue(H h, const CallSignature& s); 132 133 // The resulting information of the parsing and conversion of the arguments. 134 struct ParsedArgumentsAsBuffers { 135 // The call signature will be filled during 2 steps: 136 // - `ParseArguments` will fill the static arguments and the pytree 137 // structures 138 // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`. 139 CallSignature signature; 140 // The concatenation of the dynamic positional arguments and the sorted 141 // keyword arguments. 142 absl::InlinedVector<pybind11::object, 2> flat_dynamic_args; 143 std::vector<pybind11::object> keep_alive_objects; 144 145 // The following is only valid if the parsing succeeds. 146 std::vector<xla::PjRtBuffer*> arg_buffers; 147 // We may need to keep these objects around, because: 148 // (a) we need to extend the lifetime of objects created within 149 // `CopyBuffersToDevice` 150 // (b) `arg_buffers` do not maintain ownership 151 std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive; 152 }; 153 154 // Filter out static arguments, flatten and concatenate other arguments (i.e. 155 // dynamic positional and keyword arguments), filling `arguments` in place. 156 xla::Status ParseArguments(pybind11::handle args, 157 const std::optional<pybind11::kwargs>& py_kwargs, 158 absl::Span<int const> static_argnums, 159 absl::Span<pybind11::str const> static_argnames, 160 ParsedArgumentsAsBuffers& arguments); 161 162 // The function to call in `xla.cc` to add the bindings for this module. 163 void BuildJaxjitSubmodule(pybind11::module& m); 164 165 } // namespace jax 166 167 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ 168