1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
18
19 #include <memory>
20 #include <optional>
21 #include <vector>
22
23 #include "absl/container/inlined_vector.h"
24 #include "pybind11/numpy.h"
25 #include "pybind11/pybind11.h"
26 #include "pybind11/pytypes.h"
27 #include "pybind11/stl.h"
28 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/python/status_casters.h"
31 #include "tensorflow/compiler/xla/shape.h"
32 #include "tensorflow/compiler/xla/status.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/platform/protobuf.h"
37
38 namespace xla {
39
// Converts a NumPy dtype to a PrimitiveType.
StatusOr<PrimitiveType> DtypeToPrimitiveType(const pybind11::dtype& np_type);

// Converts a PrimitiveType to a NumPy dtype.
StatusOr<pybind11::dtype> PrimitiveTypeToDtype(PrimitiveType type);

// Returns a numpy-style format descriptor string for `type` (cf. the format
// strings used by the Python buffer protocol).
StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type);

// Returns a numpy-style typestr for `type`, as returned by np.dtype(...).str
StatusOr<pybind11::str> TypeDescriptorForPrimitiveType(PrimitiveType type);
51
// Python type objects for the NumPy scalar types (np.int8, np.float32, ...)
// that XLA cares about, bundled so callers can fetch them all at once rather
// than performing repeated NumPy attribute lookups.
struct NumpyScalarTypes {
  pybind11::object np_bool;
  pybind11::object np_int8;
  pybind11::object np_int16;
  pybind11::object np_int32;
  pybind11::object np_int64;
  pybind11::object np_uint8;
  pybind11::object np_uint16;
  pybind11::object np_uint32;
  pybind11::object np_uint64;
  pybind11::object np_bfloat16;
  pybind11::object np_float16;
  pybind11::object np_float32;
  pybind11::object np_float64;
  pybind11::object np_complex64;
  pybind11::object np_complex128;
  pybind11::object np_longlong;
  pybind11::object np_intc;
};
// Returns the NumPy scalar type objects. NOTE(review): presumably a
// lazily-initialized singleton, and using the contained objects requires the
// GIL -- confirm against the implementation.
const NumpyScalarTypes& GetNumpyScalarTypes();
72
// For S64/U64/F64/C128 types, returns the largest 32-bit equivalent
// (e.g. S64 -> S32). NOTE(review): presumably an identity mapping for all
// other types -- confirm against the implementation.
PrimitiveType Squash64BitTypes(PrimitiveType type);

// Returns the byte strides for `shape`, one entry per dimension.
std::vector<ssize_t> ByteStridesForShape(const Shape& shape);
// As above, but returns int64_t strides rather than ssize_t.
std::vector<int64_t> ByteStridesForShapeInt64(const Shape& shape);

// Converts a literal to (possibly-nested tuples of) NumPy arrays.
// The literal's leaf arrays are not copied; instead the NumPy arrays share
// buffers with the literals. Takes ownership of `literal` and keeps the
// necessary pieces alive using Python reference counting.
// Requires the GIL.
StatusOr<pybind11::object> LiteralToPython(std::shared_ptr<Literal> literal);
86
// Converts a Python object into an XLA shape and a vector of leaf buffers.
// The leaf buffers correspond to a depth-first, left-to-right traversal of
// the Python value.
// Requires the GIL.
struct PythonBufferTree {
  // Holds a reference to the arrays pointed to by `leaves`, since we may
  // need to make a copy if the array is not in a C-style layout.
  absl::InlinedVector<pybind11::object, 1> arrays;
  // Borrowing views of the leaf buffers, in traversal order.
  absl::InlinedVector<BorrowingLiteral, 1> leaves;
  // XLA shape of `argument`; presumably a tuple shape when the input was a
  // (possibly nested) tuple -- TODO confirm.
  Shape shape;
};
StatusOr<PythonBufferTree> GetPythonBufferTree(
    const pybind11::object& argument);
100
101 // Converts a sequence of C++ ints to a Python tuple of ints.
102 // Pybind11 by default converts a std::vector<T> to a Python list;
103 // we frequently want a tuple instead e.g. for shapes.
104 template <typename T>
SpanToTuple(absl::Span<T const> xs)105 pybind11::tuple SpanToTuple(absl::Span<T const> xs) {
106 pybind11::tuple out(xs.size());
107 for (int i = 0; i < xs.size(); ++i) {
108 out[i] = pybind11::cast(xs[i]);
109 }
110 return out;
111 }
112 template <>
113 pybind11::tuple SpanToTuple(absl::Span<int const> xs);
114 template <>
115 pybind11::tuple SpanToTuple(absl::Span<int64_t const> xs);
116
117 // Converts a Python iterable/sequence of T to std::vector<T>
118 template <typename T>
IterableToVector(const pybind11::iterable & iterable)119 std::vector<T> IterableToVector(const pybind11::iterable& iterable) {
120 std::vector<T> output;
121 for (auto item : iterable) {
122 output.push_back(item.cast<T>());
123 }
124 return output;
125 }
126 template <typename T>
SequenceToVector(const pybind11::sequence & sequence)127 std::vector<T> SequenceToVector(const pybind11::sequence& sequence) {
128 std::vector<T> output;
129 output.reserve(sequence.size());
130 for (auto item : sequence) {
131 output.push_back(item.cast<T>());
132 }
133 return output;
134 }
135
// Private helper function used in the implementation of the type caster for
// xla::BorrowingLiteral. Converts a Python array-like object into a buffer
// pointer and shape.
struct CastToArrayResult {
  pybind11::object array;  // Holds a reference to the array to keep it alive.
  const char* buf_ptr;     // Non-owning; presumably points into `array`'s data.
  xla::Shape shape;        // XLA shape describing the array's contents.
};
// Returns std::nullopt if `h` cannot be converted.
std::optional<CastToArrayResult> CastToArray(pybind11::handle h);
145
146 } // namespace xla
147
148 // This namespace is a documented pybind11 extension point.
149 // Caution: Unusually for Google code, this code uses C++ exceptions because
150 // they are the only mechanism for reporting cast failures to pybind11. However,
151 // the exceptions are local to the binding code.
152 namespace pybind11 {
153 namespace detail {
154
155 // Literals.
156 // Literal data can be passed to XLA as a NumPy array; its value can be
157 // cast to an xla::BorrowingLiteral or xla::LiteralSlice in a zero-copy way.
158 // We don't have any literal -> numpy conversions here, since all the methods
159 // that want to return arrays build Python objects directly.
160
161 template <>
162 struct type_caster<xla::BorrowingLiteral> {
163 public:
164 PYBIND11_TYPE_CASTER(xla::BorrowingLiteral, _("xla::BorrowingLiteral"));
165
166 // Pybind appears to keep type_casters alive until the callee has run.
167 absl::InlinedVector<pybind11::array, 1> arrays;
168
169 bool load(handle input, bool) {
170 // TODO(b/79707221): support nested tuples if/when XLA adds support for
171 // nested BorrowingLiterals.
172 if (pybind11::isinstance<pybind11::tuple>(input)) {
173 pybind11::tuple tuple =
174 pybind11::reinterpret_borrow<pybind11::tuple>(input);
175 std::vector<xla::Shape> shapes;
176 std::vector<const char*> buffers;
177 arrays.reserve(tuple.size());
178 shapes.reserve(tuple.size());
179 buffers.reserve(tuple.size());
180 for (pybind11::handle entry : tuple) {
181 auto c = xla::CastToArray(entry);
182 if (!c) {
183 return false;
184 }
185 arrays.push_back(c->array);
186 buffers.push_back(c->buf_ptr);
187 shapes.push_back(c->shape);
188 }
189 value = xla::BorrowingLiteral(buffers,
190 xla::ShapeUtil::MakeTupleShape(shapes));
191 } else {
192 auto c = xla::CastToArray(input);
193 if (!c) {
194 return false;
195 }
196 arrays.push_back(c->array);
197 value = xla::BorrowingLiteral(c->buf_ptr, c->shape);
198 }
199 return true;
200 }
201 };
202
203 template <>
204 struct type_caster<xla::LiteralSlice> {
205 public:
206 PYBIND11_TYPE_CASTER(xla::LiteralSlice, _("xla::LiteralSlice"));
207
208 // Pybind appears to keep type_casters alive until the callee has run.
209 type_caster<xla::BorrowingLiteral> literal_caster;
210
211 bool load(handle handle, bool convert) {
212 if (!literal_caster.load(handle, convert)) {
213 return false;
214 }
215 value = static_cast<const xla::BorrowingLiteral&>(literal_caster);
216 return true;
217 }
218 };
219
220 // XLA protocol buffers
221 // We don't actually care that these are the protocol buffers, we merely want
222 // objects that duck type as protocol buffers. The client code currently avoids
223 // depending on Python protocol buffers to avoid conflicting definitions from
224 // different modules that both include XLA.
225
226 template <>
227 struct type_caster<xla::ConvolutionDimensionNumbers> {
228 public:
229 PYBIND11_TYPE_CASTER(xla::ConvolutionDimensionNumbers,
230 _("xla::ConvolutionDimensionNumbers"));
231
232 // PyObject -> C++ conversion.
233 bool load(handle handle, bool) {
234 value.set_input_batch_dimension(
235 getattr(handle, "input_batch_dimension").cast<int64_t>());
236 value.set_input_feature_dimension(
237 getattr(handle, "input_feature_dimension").cast<int64_t>());
238 value.set_output_batch_dimension(
239 getattr(handle, "output_batch_dimension").cast<int64_t>());
240 value.set_output_feature_dimension(
241 getattr(handle, "output_feature_dimension").cast<int64_t>());
242 value.set_kernel_input_feature_dimension(
243 getattr(handle, "kernel_input_feature_dimension").cast<int64_t>());
244 value.set_kernel_output_feature_dimension(
245 getattr(handle, "kernel_output_feature_dimension").cast<int64_t>());
246 std::vector<int64_t> dims;
247 dims = getattr(handle, "input_spatial_dimensions")
248 .cast<std::vector<int64_t>>();
249 std::copy(dims.begin(), dims.end(),
250 tensorflow::protobuf::RepeatedFieldBackInserter(
251 value.mutable_input_spatial_dimensions()));
252 dims = getattr(handle, "kernel_spatial_dimensions")
253 .cast<std::vector<int64_t>>();
254 std::copy(dims.begin(), dims.end(),
255 tensorflow::protobuf::RepeatedFieldBackInserter(
256 value.mutable_kernel_spatial_dimensions()));
257 dims = getattr(handle, "output_spatial_dimensions")
258 .cast<std::vector<int64_t>>();
259 std::copy(dims.begin(), dims.end(),
260 tensorflow::protobuf::RepeatedFieldBackInserter(
261 value.mutable_output_spatial_dimensions()));
262 return true;
263 }
264 };
265
266 template <>
267 struct type_caster<xla::DotDimensionNumbers> {
268 public:
269 PYBIND11_TYPE_CASTER(xla::DotDimensionNumbers, _("xla::DotDimensionNumbers"));
270
271 // PyObject -> C++ conversion.
272 bool load(handle handle, bool) {
273 std::vector<int64_t> dims;
274 dims = getattr(handle, "lhs_contracting_dimensions")
275 .cast<std::vector<int64_t>>();
276 std::copy(dims.begin(), dims.end(),
277 tensorflow::protobuf::RepeatedFieldBackInserter(
278 value.mutable_lhs_contracting_dimensions()));
279 dims = getattr(handle, "rhs_contracting_dimensions")
280 .cast<std::vector<int64_t>>();
281 std::copy(dims.begin(), dims.end(),
282 tensorflow::protobuf::RepeatedFieldBackInserter(
283 value.mutable_rhs_contracting_dimensions()));
284 dims = getattr(handle, "lhs_batch_dimensions").cast<std::vector<int64_t>>();
285 std::copy(dims.begin(), dims.end(),
286 tensorflow::protobuf::RepeatedFieldBackInserter(
287 value.mutable_lhs_batch_dimensions()));
288 dims = getattr(handle, "rhs_batch_dimensions").cast<std::vector<int64_t>>();
289 std::copy(dims.begin(), dims.end(),
290 tensorflow::protobuf::RepeatedFieldBackInserter(
291 value.mutable_rhs_batch_dimensions()));
292 return true;
293 }
294 };
295
296 template <>
297 struct type_caster<xla::GatherDimensionNumbers> {
298 public:
299 PYBIND11_TYPE_CASTER(xla::GatherDimensionNumbers,
300 _("xla::GatherDimensionNumbers"));
301
302 // PyObject -> C++ conversion.
303 bool load(handle handle, bool) {
304 std::vector<int64_t> dims;
305 dims = getattr(handle, "offset_dims").cast<std::vector<int64_t>>();
306 std::copy(dims.begin(), dims.end(),
307 tensorflow::protobuf::RepeatedFieldBackInserter(
308 value.mutable_offset_dims()));
309 dims = getattr(handle, "collapsed_slice_dims").cast<std::vector<int64_t>>();
310 std::copy(dims.begin(), dims.end(),
311 tensorflow::protobuf::RepeatedFieldBackInserter(
312 value.mutable_collapsed_slice_dims()));
313 dims = getattr(handle, "start_index_map").cast<std::vector<int64_t>>();
314 std::copy(dims.begin(), dims.end(),
315 tensorflow::protobuf::RepeatedFieldBackInserter(
316 value.mutable_start_index_map()));
317 value.set_index_vector_dim(
318 getattr(handle, "index_vector_dim").cast<int64_t>());
319 return true;
320 }
321 };
322
323 template <>
324 struct type_caster<xla::ScatterDimensionNumbers> {
325 public:
326 PYBIND11_TYPE_CASTER(xla::ScatterDimensionNumbers,
327 _("xla::ScatterDimensionNumbers"));
328
329 // PyObject -> C++ conversion.
330 bool load(handle handle, bool) {
331 std::vector<int64_t> dims;
332 dims = getattr(handle, "update_window_dims").cast<std::vector<int64_t>>();
333 std::copy(dims.begin(), dims.end(),
334 tensorflow::protobuf::RepeatedFieldBackInserter(
335 value.mutable_update_window_dims()));
336 dims = getattr(handle, "inserted_window_dims").cast<std::vector<int64_t>>();
337 std::copy(dims.begin(), dims.end(),
338 tensorflow::protobuf::RepeatedFieldBackInserter(
339 value.mutable_inserted_window_dims()));
340 dims = getattr(handle, "scatter_dims_to_operand_dims")
341 .cast<std::vector<int64_t>>();
342 std::copy(dims.begin(), dims.end(),
343 tensorflow::protobuf::RepeatedFieldBackInserter(
344 value.mutable_scatter_dims_to_operand_dims()));
345 value.set_index_vector_dim(
346 getattr(handle, "index_vector_dim").cast<int64_t>());
347 return true;
348 }
349 };
350
351 template <>
352 struct type_caster<xla::ReplicaGroup> {
353 public:
354 PYBIND11_TYPE_CASTER(xla::ReplicaGroup, _("xla::ReplicaGroup"));
355
356 // PyObject -> C++ conversion.
357 bool load(handle handle, bool) {
358 std::vector<int64_t> dims;
359 dims = getattr(handle, "replica_ids").cast<std::vector<int64_t>>();
360 std::copy(dims.begin(), dims.end(),
361 tensorflow::protobuf::RepeatedFieldBackInserter(
362 value.mutable_replica_ids()));
363 return true;
364 }
365 };
366
367 template <>
368 struct type_caster<xla::PaddingConfig> {
369 public:
370 PYBIND11_TYPE_CASTER(xla::PaddingConfig, _("xla::PaddingConfig"));
371
372 // PyObject -> C++ conversion.
373 bool load(handle handle, bool) {
374 sequence dimensions =
375 reinterpret_borrow<sequence>(getattr(handle, "dimensions"));
376
377 for (const auto& dimension : dimensions) {
378 xla::PaddingConfig::PaddingConfigDimension* config_dim =
379 value.add_dimensions();
380 config_dim->set_edge_padding_low(
381 getattr(dimension, "edge_padding_low").cast<int64_t>());
382 config_dim->set_edge_padding_high(
383 getattr(dimension, "edge_padding_high").cast<int64_t>());
384 config_dim->set_interior_padding(
385 getattr(dimension, "interior_padding").cast<int64_t>());
386 }
387 return true;
388 }
389 };
390
391 template <>
392 struct type_caster<xla::OpMetadata> {
393 public:
394 PYBIND11_TYPE_CASTER(xla::OpMetadata, _("xla::OpMetadata"));
395
396 // PyObject -> C++ conversion.
397 bool load(handle handle, bool) {
398 pybind11::handle op_type = getattr(handle, "op_type");
399 if (!op_type.is_none()) {
400 value.set_op_type(op_type.cast<std::string>());
401 }
402 pybind11::handle op_name = getattr(handle, "op_name");
403 if (!op_name.is_none()) {
404 value.set_op_name(op_name.cast<std::string>());
405 }
406 pybind11::handle source_file = getattr(handle, "source_file");
407 if (!source_file.is_none()) {
408 value.set_source_file(source_file.cast<std::string>());
409 }
410 pybind11::handle source_line = getattr(handle, "source_line");
411 if (!source_line.is_none()) {
412 value.set_source_line(source_line.cast<int32_t>());
413 }
414 return true;
415 }
416 };
417
418 template <>
419 struct type_caster<xla::PrecisionConfig> {
420 public:
421 PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig"));
422
423 // PyObject -> C++ conversion.
424 bool load(handle handle, bool) {
425 if (handle.is_none()) {
426 return true;
427 }
428
429 sequence operand_precisions =
430 reinterpret_borrow<sequence>(getattr(handle, "operand_precision"));
431
432 for (const auto& operand_precision : operand_precisions) {
433 value.add_operand_precision(
434 operand_precision.cast<xla::PrecisionConfig::Precision>());
435 }
436 return true;
437 }
438 };
439
440 } // namespace detail
441 } // namespace pybind11
442
443 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
444