xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/types.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
18 
#include <memory>
#include <optional>
#include <utility>
#include <vector>
22 
23 #include "absl/container/inlined_vector.h"
24 #include "pybind11/numpy.h"
25 #include "pybind11/pybind11.h"
26 #include "pybind11/pytypes.h"
27 #include "pybind11/stl.h"
28 #include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/python/status_casters.h"
31 #include "tensorflow/compiler/xla/shape.h"
32 #include "tensorflow/compiler/xla/status.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/platform/protobuf.h"
37 
38 namespace xla {
39 
40 // Converts a NumPy dtype to a PrimitiveType.
41 StatusOr<PrimitiveType> DtypeToPrimitiveType(const pybind11::dtype& np_type);
42 
43 // Converts a PrimitiveType to a Numpy dtype.
44 StatusOr<pybind11::dtype> PrimitiveTypeToDtype(PrimitiveType type);
45 
46 // Returns a numpy-style format descriptor string for `type`.
47 StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type);
48 
49 // Returns a numpy-style typestr for `type`, as returned by np.dtype(...).str
50 StatusOr<pybind11::str> TypeDescriptorForPrimitiveType(PrimitiveType type);
51 
52 struct NumpyScalarTypes {
53   pybind11::object np_bool;
54   pybind11::object np_int8;
55   pybind11::object np_int16;
56   pybind11::object np_int32;
57   pybind11::object np_int64;
58   pybind11::object np_uint8;
59   pybind11::object np_uint16;
60   pybind11::object np_uint32;
61   pybind11::object np_uint64;
62   pybind11::object np_bfloat16;
63   pybind11::object np_float16;
64   pybind11::object np_float32;
65   pybind11::object np_float64;
66   pybind11::object np_complex64;
67   pybind11::object np_complex128;
68   pybind11::object np_longlong;
69   pybind11::object np_intc;
70 };
71 const NumpyScalarTypes& GetNumpyScalarTypes();
72 
73 // For S64/U64/F64/C128 types, returns the largest 32-bit equivalent.
74 PrimitiveType Squash64BitTypes(PrimitiveType type);
75 
76 // Returns the strides for `shape`.
77 std::vector<ssize_t> ByteStridesForShape(const Shape& shape);
78 std::vector<int64_t> ByteStridesForShapeInt64(const Shape& shape);
79 
80 // Converts a literal to (possibly-nested tuples of) NumPy arrays.
81 // The literal's leaf arrays are not copied; instead the NumPy arrays share
82 // buffers with the literals. Takes ownership of `literal` and keeps the
83 // necessary pieces alive using Python reference counting.
84 // Requires the GIL.
85 StatusOr<pybind11::object> LiteralToPython(std::shared_ptr<Literal> literal);
86 
87 // Converts a Python object into an XLA shape and a vector of leaf buffers.
88 // The leaf buffers correspond to a depth-first, left-to-right traversal of
89 // the Python value.
90 // Requires the GIL.
91 struct PythonBufferTree {
92   // Holds a reference to the arrays pointed to by `leaves`, since we may
93   // need to make a copy if the array is not in a C-style layout.
94   absl::InlinedVector<pybind11::object, 1> arrays;
95   absl::InlinedVector<BorrowingLiteral, 1> leaves;
96   Shape shape;
97 };
98 StatusOr<PythonBufferTree> GetPythonBufferTree(
99     const pybind11::object& argument);
100 
101 // Converts a sequence of C++ ints to a Python tuple of ints.
102 // Pybind11 by default converts a std::vector<T> to a Python list;
103 // we frequently want a tuple instead e.g. for shapes.
104 template <typename T>
SpanToTuple(absl::Span<T const> xs)105 pybind11::tuple SpanToTuple(absl::Span<T const> xs) {
106   pybind11::tuple out(xs.size());
107   for (int i = 0; i < xs.size(); ++i) {
108     out[i] = pybind11::cast(xs[i]);
109   }
110   return out;
111 }
112 template <>
113 pybind11::tuple SpanToTuple(absl::Span<int const> xs);
114 template <>
115 pybind11::tuple SpanToTuple(absl::Span<int64_t const> xs);
116 
117 // Converts a Python iterable/sequence of T to std::vector<T>
118 template <typename T>
IterableToVector(const pybind11::iterable & iterable)119 std::vector<T> IterableToVector(const pybind11::iterable& iterable) {
120   std::vector<T> output;
121   for (auto item : iterable) {
122     output.push_back(item.cast<T>());
123   }
124   return output;
125 }
126 template <typename T>
SequenceToVector(const pybind11::sequence & sequence)127 std::vector<T> SequenceToVector(const pybind11::sequence& sequence) {
128   std::vector<T> output;
129   output.reserve(sequence.size());
130   for (auto item : sequence) {
131     output.push_back(item.cast<T>());
132   }
133   return output;
134 }
135 
136 // Private helper function used in the implementation of the type caster for
137 // xla::BorrowingLiteral. Converts a Python array-like object into a buffer
138 // pointer and shape.
139 struct CastToArrayResult {
140   pybind11::object array;  // Holds a reference to the array to keep it alive.
141   const char* buf_ptr;
142   xla::Shape shape;
143 };
144 std::optional<CastToArrayResult> CastToArray(pybind11::handle h);
145 
146 }  // namespace xla
147 
148 // This namespace is a documented pybind11 extension point.
149 // Caution: Unusually for Google code, this code uses C++ exceptions because
150 // they are the only mechanism for reporting cast failures to pybind11. However,
151 // the exceptions are local to the binding code.
152 namespace pybind11 {
153 namespace detail {
154 
155 // Literals.
156 // Literal data can be passed to XLA as a NumPy array; its value can be
157 // cast to an xla::BorrowingLiteral or xla::LiteralSlice in a zero-copy way.
158 // We don't have any literal -> numpy conversions here, since all the methods
159 // that want to return arrays build Python objects directly.
160 
161 template <>
162 struct type_caster<xla::BorrowingLiteral> {
163  public:
164   PYBIND11_TYPE_CASTER(xla::BorrowingLiteral, _("xla::BorrowingLiteral"));
165 
166   // Pybind appears to keep type_casters alive until the callee has run.
167   absl::InlinedVector<pybind11::array, 1> arrays;
168 
169   bool load(handle input, bool) {
170     // TODO(b/79707221): support nested tuples if/when XLA adds support for
171     // nested BorrowingLiterals.
172     if (pybind11::isinstance<pybind11::tuple>(input)) {
173       pybind11::tuple tuple =
174           pybind11::reinterpret_borrow<pybind11::tuple>(input);
175       std::vector<xla::Shape> shapes;
176       std::vector<const char*> buffers;
177       arrays.reserve(tuple.size());
178       shapes.reserve(tuple.size());
179       buffers.reserve(tuple.size());
180       for (pybind11::handle entry : tuple) {
181         auto c = xla::CastToArray(entry);
182         if (!c) {
183           return false;
184         }
185         arrays.push_back(c->array);
186         buffers.push_back(c->buf_ptr);
187         shapes.push_back(c->shape);
188       }
189       value = xla::BorrowingLiteral(buffers,
190                                     xla::ShapeUtil::MakeTupleShape(shapes));
191     } else {
192       auto c = xla::CastToArray(input);
193       if (!c) {
194         return false;
195       }
196       arrays.push_back(c->array);
197       value = xla::BorrowingLiteral(c->buf_ptr, c->shape);
198     }
199     return true;
200   }
201 };
202 
203 template <>
204 struct type_caster<xla::LiteralSlice> {
205  public:
206   PYBIND11_TYPE_CASTER(xla::LiteralSlice, _("xla::LiteralSlice"));
207 
208   // Pybind appears to keep type_casters alive until the callee has run.
209   type_caster<xla::BorrowingLiteral> literal_caster;
210 
211   bool load(handle handle, bool convert) {
212     if (!literal_caster.load(handle, convert)) {
213       return false;
214     }
215     value = static_cast<const xla::BorrowingLiteral&>(literal_caster);
216     return true;
217   }
218 };
219 
220 // XLA protocol buffers
221 // We don't actually care that these are the protocol buffers, we merely want
222 // objects that duck type as protocol buffers. The client code currently avoids
223 // depending on Python protocol buffers to avoid conflicting definitions from
224 // different modules that both include XLA.
225 
226 template <>
227 struct type_caster<xla::ConvolutionDimensionNumbers> {
228  public:
229   PYBIND11_TYPE_CASTER(xla::ConvolutionDimensionNumbers,
230                        _("xla::ConvolutionDimensionNumbers"));
231 
232   // PyObject -> C++ conversion.
233   bool load(handle handle, bool) {
234     value.set_input_batch_dimension(
235         getattr(handle, "input_batch_dimension").cast<int64_t>());
236     value.set_input_feature_dimension(
237         getattr(handle, "input_feature_dimension").cast<int64_t>());
238     value.set_output_batch_dimension(
239         getattr(handle, "output_batch_dimension").cast<int64_t>());
240     value.set_output_feature_dimension(
241         getattr(handle, "output_feature_dimension").cast<int64_t>());
242     value.set_kernel_input_feature_dimension(
243         getattr(handle, "kernel_input_feature_dimension").cast<int64_t>());
244     value.set_kernel_output_feature_dimension(
245         getattr(handle, "kernel_output_feature_dimension").cast<int64_t>());
246     std::vector<int64_t> dims;
247     dims = getattr(handle, "input_spatial_dimensions")
248                .cast<std::vector<int64_t>>();
249     std::copy(dims.begin(), dims.end(),
250               tensorflow::protobuf::RepeatedFieldBackInserter(
251                   value.mutable_input_spatial_dimensions()));
252     dims = getattr(handle, "kernel_spatial_dimensions")
253                .cast<std::vector<int64_t>>();
254     std::copy(dims.begin(), dims.end(),
255               tensorflow::protobuf::RepeatedFieldBackInserter(
256                   value.mutable_kernel_spatial_dimensions()));
257     dims = getattr(handle, "output_spatial_dimensions")
258                .cast<std::vector<int64_t>>();
259     std::copy(dims.begin(), dims.end(),
260               tensorflow::protobuf::RepeatedFieldBackInserter(
261                   value.mutable_output_spatial_dimensions()));
262     return true;
263   }
264 };
265 
266 template <>
267 struct type_caster<xla::DotDimensionNumbers> {
268  public:
269   PYBIND11_TYPE_CASTER(xla::DotDimensionNumbers, _("xla::DotDimensionNumbers"));
270 
271   // PyObject -> C++ conversion.
272   bool load(handle handle, bool) {
273     std::vector<int64_t> dims;
274     dims = getattr(handle, "lhs_contracting_dimensions")
275                .cast<std::vector<int64_t>>();
276     std::copy(dims.begin(), dims.end(),
277               tensorflow::protobuf::RepeatedFieldBackInserter(
278                   value.mutable_lhs_contracting_dimensions()));
279     dims = getattr(handle, "rhs_contracting_dimensions")
280                .cast<std::vector<int64_t>>();
281     std::copy(dims.begin(), dims.end(),
282               tensorflow::protobuf::RepeatedFieldBackInserter(
283                   value.mutable_rhs_contracting_dimensions()));
284     dims = getattr(handle, "lhs_batch_dimensions").cast<std::vector<int64_t>>();
285     std::copy(dims.begin(), dims.end(),
286               tensorflow::protobuf::RepeatedFieldBackInserter(
287                   value.mutable_lhs_batch_dimensions()));
288     dims = getattr(handle, "rhs_batch_dimensions").cast<std::vector<int64_t>>();
289     std::copy(dims.begin(), dims.end(),
290               tensorflow::protobuf::RepeatedFieldBackInserter(
291                   value.mutable_rhs_batch_dimensions()));
292     return true;
293   }
294 };
295 
296 template <>
297 struct type_caster<xla::GatherDimensionNumbers> {
298  public:
299   PYBIND11_TYPE_CASTER(xla::GatherDimensionNumbers,
300                        _("xla::GatherDimensionNumbers"));
301 
302   // PyObject -> C++ conversion.
303   bool load(handle handle, bool) {
304     std::vector<int64_t> dims;
305     dims = getattr(handle, "offset_dims").cast<std::vector<int64_t>>();
306     std::copy(dims.begin(), dims.end(),
307               tensorflow::protobuf::RepeatedFieldBackInserter(
308                   value.mutable_offset_dims()));
309     dims = getattr(handle, "collapsed_slice_dims").cast<std::vector<int64_t>>();
310     std::copy(dims.begin(), dims.end(),
311               tensorflow::protobuf::RepeatedFieldBackInserter(
312                   value.mutable_collapsed_slice_dims()));
313     dims = getattr(handle, "start_index_map").cast<std::vector<int64_t>>();
314     std::copy(dims.begin(), dims.end(),
315               tensorflow::protobuf::RepeatedFieldBackInserter(
316                   value.mutable_start_index_map()));
317     value.set_index_vector_dim(
318         getattr(handle, "index_vector_dim").cast<int64_t>());
319     return true;
320   }
321 };
322 
323 template <>
324 struct type_caster<xla::ScatterDimensionNumbers> {
325  public:
326   PYBIND11_TYPE_CASTER(xla::ScatterDimensionNumbers,
327                        _("xla::ScatterDimensionNumbers"));
328 
329   // PyObject -> C++ conversion.
330   bool load(handle handle, bool) {
331     std::vector<int64_t> dims;
332     dims = getattr(handle, "update_window_dims").cast<std::vector<int64_t>>();
333     std::copy(dims.begin(), dims.end(),
334               tensorflow::protobuf::RepeatedFieldBackInserter(
335                   value.mutable_update_window_dims()));
336     dims = getattr(handle, "inserted_window_dims").cast<std::vector<int64_t>>();
337     std::copy(dims.begin(), dims.end(),
338               tensorflow::protobuf::RepeatedFieldBackInserter(
339                   value.mutable_inserted_window_dims()));
340     dims = getattr(handle, "scatter_dims_to_operand_dims")
341                .cast<std::vector<int64_t>>();
342     std::copy(dims.begin(), dims.end(),
343               tensorflow::protobuf::RepeatedFieldBackInserter(
344                   value.mutable_scatter_dims_to_operand_dims()));
345     value.set_index_vector_dim(
346         getattr(handle, "index_vector_dim").cast<int64_t>());
347     return true;
348   }
349 };
350 
351 template <>
352 struct type_caster<xla::ReplicaGroup> {
353  public:
354   PYBIND11_TYPE_CASTER(xla::ReplicaGroup, _("xla::ReplicaGroup"));
355 
356   // PyObject -> C++ conversion.
357   bool load(handle handle, bool) {
358     std::vector<int64_t> dims;
359     dims = getattr(handle, "replica_ids").cast<std::vector<int64_t>>();
360     std::copy(dims.begin(), dims.end(),
361               tensorflow::protobuf::RepeatedFieldBackInserter(
362                   value.mutable_replica_ids()));
363     return true;
364   }
365 };
366 
367 template <>
368 struct type_caster<xla::PaddingConfig> {
369  public:
370   PYBIND11_TYPE_CASTER(xla::PaddingConfig, _("xla::PaddingConfig"));
371 
372   // PyObject -> C++ conversion.
373   bool load(handle handle, bool) {
374     sequence dimensions =
375         reinterpret_borrow<sequence>(getattr(handle, "dimensions"));
376 
377     for (const auto& dimension : dimensions) {
378       xla::PaddingConfig::PaddingConfigDimension* config_dim =
379           value.add_dimensions();
380       config_dim->set_edge_padding_low(
381           getattr(dimension, "edge_padding_low").cast<int64_t>());
382       config_dim->set_edge_padding_high(
383           getattr(dimension, "edge_padding_high").cast<int64_t>());
384       config_dim->set_interior_padding(
385           getattr(dimension, "interior_padding").cast<int64_t>());
386     }
387     return true;
388   }
389 };
390 
391 template <>
392 struct type_caster<xla::OpMetadata> {
393  public:
394   PYBIND11_TYPE_CASTER(xla::OpMetadata, _("xla::OpMetadata"));
395 
396   // PyObject -> C++ conversion.
397   bool load(handle handle, bool) {
398     pybind11::handle op_type = getattr(handle, "op_type");
399     if (!op_type.is_none()) {
400       value.set_op_type(op_type.cast<std::string>());
401     }
402     pybind11::handle op_name = getattr(handle, "op_name");
403     if (!op_name.is_none()) {
404       value.set_op_name(op_name.cast<std::string>());
405     }
406     pybind11::handle source_file = getattr(handle, "source_file");
407     if (!source_file.is_none()) {
408       value.set_source_file(source_file.cast<std::string>());
409     }
410     pybind11::handle source_line = getattr(handle, "source_line");
411     if (!source_line.is_none()) {
412       value.set_source_line(source_line.cast<int32_t>());
413     }
414     return true;
415   }
416 };
417 
418 template <>
419 struct type_caster<xla::PrecisionConfig> {
420  public:
421   PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig"));
422 
423   // PyObject -> C++ conversion.
424   bool load(handle handle, bool) {
425     if (handle.is_none()) {
426       return true;
427     }
428 
429     sequence operand_precisions =
430         reinterpret_borrow<sequence>(getattr(handle, "operand_precision"));
431 
432     for (const auto& operand_precision : operand_precisions) {
433       value.add_operand_precision(
434           operand_precision.cast<xla::PrecisionConfig::Precision>());
435     }
436     return true;
437   }
438 };
439 
440 }  // namespace detail
441 }  // namespace pybind11
442 
443 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
444