xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/py_values.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Helpers for converting Python values into buffers.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_VALUES_H_
19 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_VALUES_H_
20 
21 #include <memory>
22 
23 #include "pybind11/numpy.h"
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
26 #include "tensorflow/compiler/xla/python/py_client.h"
27 
28 namespace xla {
29 
30 struct DevicePutResult {
DevicePutResultDevicePutResult31   explicit DevicePutResult(PjRtBuffer* b, bool weak_type,
32                            pybind11::object owning_pybuffer)
33       : buffer(b), weak_type(weak_type), owning_pybuffer(owning_pybuffer) {}
DevicePutResultDevicePutResult34   explicit DevicePutResult(std::unique_ptr<PjRtBuffer> new_buffer,
35                            bool weak_type)
36       : buffer(new_buffer.get()),
37         weak_type(weak_type),
38         owned_buffer(std::move(new_buffer)) {}
39 
40   // Points to the on-device buffer. Not owned.
41   PjRtBuffer* buffer;
42   bool weak_type;
43 
44   // One of owned_buffer or owning_pybuffer is valid. If owned_buffer is
45   // non-null, it holds ownership of the buffer. Otherwise owning_pybuffer is
46   // the PyBuffer object that owns the buffer.
47   std::unique_ptr<PjRtBuffer> owned_buffer;
48   pybind11::object owning_pybuffer;
49 };
50 
51 // Copies a buffer-like object to be on device.
52 //
53 // If `arg` is not convertible to a `PjRtBuffer` from C++, an error will be
54 // returned; float0s are not supported yet.
55 // If the value is known to be a PyBuffer object, py_buffer can be passed as
56 // an optimization to avoid a Python->C++ cast.
57 //
58 // May throw exceptions from pybind11 in addition to failing via an error
59 // Status. (We could catch these if needed, but there seems little point.)
60 struct DevicePutOptions {
61   bool squash_64bit_types = false;
62   bool allow_zero_copy = true;
63 };
64 StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
65                                     const DevicePutOptions& options);
66 
67 // Returns `true` if `arg` is a JAX float0 array.
68 bool IsFloat0(pybind11::array arg);
69 
70 // Describes the abstract shape and dtype of an argument.
71 struct PyArgSignature {
PyArgSignaturePyArgSignature72   PyArgSignature(PrimitiveType dtype, absl::Span<const int64_t> shape,
73                  bool weak_type)
74       : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
75   // This is the XLA dtype of the object.
76   const PrimitiveType dtype;
77   const absl::InlinedVector<int64_t, 4> shape;
78   // JAX arguments can be of weak type, if and only if they are Python scalars
79   // or `DeviceArray` values such that `aval.weak_type` is true.
80   const bool weak_type;
81   bool operator==(const PyArgSignature& other) const {
82     return std::tie(dtype, weak_type, shape) ==
83            std::tie(other.dtype, other.weak_type, other.shape);
84   }
85   bool operator!=(const PyArgSignature& other) const {
86     return !(*this == other);
87   }
88   std::string DebugString() const;
89 };
90 
91 // Returns the PyArgSignature associated with an argument. Returns an error if
92 // the argument is not supported.
93 StatusOr<PyArgSignature> PyArgSignatureOfValue(pybind11::handle arg,
94                                                bool jax_enable_x64);
95 
96 template <typename H>
AbslHashValue(H h,const xla::PyArgSignature & s)97 H AbslHashValue(H h, const xla::PyArgSignature& s) {
98   h = H::combine(std::move(h), s.dtype);
99   h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size());
100   return h;
101 }
102 }  // namespace xla
103 
104 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PY_VALUES_H_
105