1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/types.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/compiler/xla/python/exceptions.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21 #include "tensorflow/python/lib/core/bfloat16.h"
22
23 namespace xla {
24
25 namespace py = pybind11;
26
DtypeToPrimitiveType(const py::dtype & np_type)27 xla::StatusOr<PrimitiveType> DtypeToPrimitiveType(const py::dtype& np_type) {
28 static auto* types =
29 new absl::flat_hash_map<std::pair<char, int>, PrimitiveType>({
30 {{'b', 1}, PRED},
31 {{'i', 1}, S8},
32 {{'i', 2}, S16},
33 {{'i', 4}, S32},
34 {{'i', 8}, S64},
35 {{'u', 1}, U8},
36 {{'u', 2}, U16},
37 {{'u', 4}, U32},
38 {{'u', 8}, U64},
39 {{'V', 2}, BF16}, // array protocol code for raw data (void*)
40 {{'f', 2}, F16},
41 {{'f', 4}, F32},
42 {{'f', 8}, F64},
43 {{'c', 8}, C64},
44 {{'c', 16}, C128},
45 });
46 auto it = types->find({np_type.kind(), np_type.itemsize()});
47 if (it == types->end()) {
48 return InvalidArgument("Unknown NumPy type %c size %d", np_type.kind(),
49 np_type.itemsize());
50 }
51 return it->second;
52 }
53
// Returns the NumPy dtype corresponding to an XLA PrimitiveType.
// Returns Unimplemented for types with no NumPy equivalent (e.g. TUPLE).
xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
  switch (type) {
    case PRED:
      return py::dtype::of<bool>();
    case S8:
      return py::dtype::of<int8_t>();
    case S16:
      return py::dtype::of<int16_t>();
    case S32:
      return py::dtype::of<int32_t>();
    case S64:
      return py::dtype::of<int64_t>();
    case U8:
      return py::dtype::of<uint8_t>();
    case U16:
      return py::dtype::of<uint16_t>();
    case U32:
      return py::dtype::of<uint32_t>();
    case U64:
      return py::dtype::of<uint64_t>();
    case BF16: {
      // bfloat16 has no built-in NumPy dtype; borrow the custom dtype object
      // that TensorFlow registers (tensorflow/python/lib/core/bfloat16.h).
      py::handle bfloat16(tensorflow::Bfloat16Dtype());
      return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
    }
    case F16:
      return py::dtype("e");  // PEP 3118 code for "float16".
    case F32:
      return py::dtype::of<float>();
    case F64:
      return py::dtype::of<double>();
    case C64:
      return py::dtype::of<std::complex<float>>();
    case C128:
      return py::dtype::of<std::complex<double>>();
    default:
      return Unimplemented("Unimplemented primitive type %s",
                           PrimitiveType_Name(type));
  }
}
93
GetNumpyScalarTypes()94 const NumpyScalarTypes& GetNumpyScalarTypes() {
95 static const NumpyScalarTypes* singleton = []() {
96 NumpyScalarTypes* dtypes = new NumpyScalarTypes();
97 const auto numpy = py::module::import("numpy");
98 dtypes->np_bool = py::object(numpy.attr("bool_"));
99 dtypes->np_int8 = py::object(numpy.attr("int8"));
100 dtypes->np_int16 = py::object(numpy.attr("int16"));
101 dtypes->np_int32 = py::object(numpy.attr("int32"));
102 dtypes->np_int64 = py::object(numpy.attr("int64"));
103 dtypes->np_uint8 = py::object(numpy.attr("uint8"));
104 dtypes->np_uint16 = py::object(numpy.attr("uint16"));
105 dtypes->np_uint32 = py::object(numpy.attr("uint32"));
106 dtypes->np_uint64 = py::object(numpy.attr("uint64"));
107 dtypes->np_bfloat16 =
108 py::reinterpret_borrow<py::object>(tensorflow::Bfloat16Dtype());
109 dtypes->np_float16 = py::object(numpy.attr("float16"));
110 dtypes->np_float32 = py::object(numpy.attr("float32"));
111 dtypes->np_float64 = py::object(numpy.attr("float64"));
112 dtypes->np_complex64 = py::object(numpy.attr("complex64"));
113 dtypes->np_complex128 = py::object(numpy.attr("complex128"));
114 dtypes->np_longlong = py::object(numpy.attr("longlong"));
115 dtypes->np_intc = py::object(numpy.attr("intc"));
116 return dtypes;
117 }();
118 return *singleton;
119 }
120
121 // Returns a numpy-style format descriptor string for `type`.
FormatDescriptorForPrimitiveType(PrimitiveType type)122 StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type) {
123 // We use an "=" prefix to indicate that we prefer "standard" types like
124 // np.int32 rather than "native" types like np.cint. pybind11 does not qualify
125 // its format descriptors.
126 switch (type) {
127 case PRED:
128 return std::string("?");
129 case S8:
130 return std::string("=b");
131 case S16:
132 return std::string("=h");
133 case S32:
134 return std::string("=i");
135 case S64:
136 return std::string("=q");
137 case U8:
138 return std::string("=B");
139 case U16:
140 return std::string("=H");
141 case U32:
142 return std::string("=I");
143 case U64:
144 return std::string("=Q");
145 case F16:
146 return std::string("=e");
147 case F32:
148 return std::string("=f");
149 case F64:
150 return std::string("=d");
151 case C64:
152 return std::string("=Zf");
153 case C128:
154 return std::string("=Zd");
155 default:
156 return Unimplemented("Unimplemented primitive type %s",
157 PrimitiveType_Name(type));
158 }
159 }
160
TypeDescriptorForPrimitiveType(PrimitiveType type)161 StatusOr<py::str> TypeDescriptorForPrimitiveType(PrimitiveType type) {
162 static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__,
163 "Big endian support not implemented");
164 switch (type) {
165 case PRED:
166 return py::str("|b1");
167 case S8:
168 return py::str("|i1");
169 case S16:
170 return py::str("<i2");
171 case S32:
172 return py::str("<i4");
173 case S64:
174 return py::str("<i8");
175 case U8:
176 return py::str("|u1");
177 case U16:
178 return py::str("<u2");
179 case U32:
180 return py::str("<u4");
181 case U64:
182 return py::str("<u8");
183 case BF16:
184 return py::str("<V2");
185 case F16:
186 return py::str("<f2");
187 case F32:
188 return py::str("<f4");
189 case F64:
190 return py::str("<f8");
191 case C64:
192 return py::str("<c8");
193 case C128:
194 return py::str("<c16");
195 default:
196 return Unimplemented("Unimplemented primitive type %s",
197 PrimitiveType_Name(type));
198 }
199 }
200
Squash64BitTypes(PrimitiveType type)201 PrimitiveType Squash64BitTypes(PrimitiveType type) {
202 switch (type) {
203 case S64:
204 return S32;
205 case U64:
206 return U32;
207 case F64:
208 return F32;
209 case C128:
210 return C64;
211 default:
212 return type;
213 }
214 }
215
216 // Returns the strides for `shape`.
ByteStridesForShape(const Shape & shape)217 std::vector<ssize_t> ByteStridesForShape(const Shape& shape) {
218 std::vector<ssize_t> strides;
219 CHECK(shape.IsArray());
220 CHECK(shape.has_layout());
221
222 strides.resize(shape.dimensions_size());
223 ssize_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
224 for (int i : shape.layout().minor_to_major()) {
225 strides.at(i) = stride;
226 stride *= shape.dimensions(i);
227 }
228 return strides;
229 }
230
ByteStridesForShapeInt64(const Shape & shape)231 std::vector<int64_t> ByteStridesForShapeInt64(const Shape& shape) {
232 std::vector<int64_t> strides;
233 CHECK(shape.IsArray());
234 CHECK(shape.has_layout());
235
236 strides.resize(shape.dimensions_size());
237 int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
238 for (int i : shape.layout().minor_to_major()) {
239 strides.at(i) = stride;
240 stride *= shape.dimensions(i);
241 }
242 return strides;
243 }
244
LiteralToPython(std::shared_ptr<xla::Literal> literal)245 StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
246 xla::Literal& m = *literal;
247 if (m.shape().IsTuple()) {
248 std::vector<Literal> elems = m.DecomposeTuple();
249 std::vector<py::object> arrays(elems.size());
250 for (int i = 0; i < elems.size(); ++i) {
251 TF_ASSIGN_OR_RETURN(
252 arrays[i],
253 LiteralToPython(std::make_unique<Literal>(std::move(elems[i]))));
254 }
255 py::tuple result(elems.size());
256 for (int i = 0; i < elems.size(); ++i) {
257 PyTuple_SET_ITEM(result.ptr(), i, arrays[i].release().ptr());
258 }
259 return result;
260 }
261 TF_RET_CHECK(m.shape().IsArray());
262
263 py::object literal_object = py::cast(literal);
264 TF_ASSIGN_OR_RETURN(py::dtype dtype,
265 PrimitiveTypeToDtype(m.shape().element_type()));
266 return py::array(dtype, m.shape().dimensions(),
267 ByteStridesForShape(m.shape()), m.untyped_data(),
268 literal_object);
269 }
270
// Flattens a Python value (an array-convertible object, or an arbitrarily
// nested tuple of such objects) into a PythonBufferTree: a flat list of
// BorrowingLiteral leaves in depth-first order, the Python arrays that own
// the memory those leaves borrow, and the overall (possibly tuple) shape.
StatusOr<PythonBufferTree> GetPythonBufferTree(const py::object& argument) {
  PythonBufferTree tree;
  if (py::isinstance<py::tuple>(argument)) {
    py::tuple tuple = py::reinterpret_borrow<py::tuple>(argument);
    std::vector<Shape> host_shapes(tuple.size());
    for (int i = 0; i < host_shapes.size(); ++i) {
      // Recurse into each element and splice its leaves/arrays onto our flat
      // lists so overall leaf order matches a depth-first traversal.
      TF_ASSIGN_OR_RETURN(PythonBufferTree subtree,
                          GetPythonBufferTree(tuple[i]));
      tree.leaves.reserve(tree.leaves.size() + subtree.leaves.size());
      std::move(subtree.leaves.begin(), subtree.leaves.end(),
                std::back_inserter(tree.leaves));
      tree.arrays.reserve(tree.arrays.size() + subtree.arrays.size());
      std::move(subtree.arrays.begin(), subtree.arrays.end(),
                std::back_inserter(tree.arrays));
      host_shapes[i] = std::move(subtree.shape);
    }
    tree.shape = ShapeUtil::MakeTupleShape(host_shapes);
  } else {
    // Leaf: use XLA's pybind type_caster to view `argument` as a
    // BorrowingLiteral without copying the buffer.
    pybind11::detail::type_caster<BorrowingLiteral> caster;
    if (!caster.load(argument, /*convert=*/true)) {
      return InvalidArgument("Invalid array value.");
    }
    // NOTE(review): `caster.arrays` appears to hold the Python array(s)
    // backing the borrowed literal; we retain them in `tree.arrays` so the
    // leaf stays valid. Confirm against the caster's definition in types.h.
    DCHECK_EQ(caster.arrays.size(), 1);
    tree.arrays.push_back(std::move(caster.arrays.front()));
    tree.leaves.push_back(std::move(*caster));
    tree.shape = tree.leaves.front().shape();
  }
  return tree;
}
300
301 template <typename IntType>
IntSpanToTupleHelper(absl::Span<IntType const> xs)302 static py::tuple IntSpanToTupleHelper(absl::Span<IntType const> xs) {
303 py::tuple out(xs.size());
304 for (int i = 0; i < xs.size(); ++i) {
305 out[i] = py::int_(xs[i]);
306 }
307 return out;
308 }
309
// Explicit specializations of SpanToTuple for the integer element types this
// file supports; both forward to the shared helper above.
template <>
pybind11::tuple SpanToTuple(absl::Span<int const> xs) {
  return IntSpanToTupleHelper(xs);
}
template <>
pybind11::tuple SpanToTuple(absl::Span<int64_t const> xs) {
  return IntSpanToTupleHelper(xs);
}
318
// Attempts to view `h` as a dense, C-contiguous, aligned NumPy array and
// derive the matching XLA shape. Returns std::nullopt if `h` cannot be
// converted to an array at all; throws XlaRuntimeError if the dtype has no
// XLA equivalent or the buffer size disagrees with the computed shape.
std::optional<CastToArrayResult> CastToArray(py::handle h) {
  // ensure() may convert/copy `h` to obtain a C-style, aligned array.
  py::array array = py::array::ensure(
      h, py::array::c_style | py::detail::npy_api::NPY_ARRAY_ALIGNED_);
  if (!array) {
    return std::nullopt;
  }
  auto type_or_status = DtypeToPrimitiveType(array.dtype());
  if (!type_or_status.ok()) {
    throw xla::XlaRuntimeError(type_or_status.status());
  }
  PrimitiveType type = type_or_status.ValueOrDie();

  absl::InlinedVector<int64_t, 4> dims(array.ndim());
  for (int i = 0; i < array.ndim(); ++i) {
    dims[i] = array.shape(i);
  }
  Shape shape = ShapeUtil::MakeShape(type, dims);
  // Sanity check: the NumPy buffer must be exactly as large as the XLA shape
  // it claims to describe.
  if (array.size() * array.itemsize() != ShapeUtil::ByteSizeOf(shape)) {
    throw xla::XlaRuntimeError(absl::StrCat(
        "Size mismatch for buffer: ", array.size() * array.itemsize(), " vs. ",
        ShapeUtil::ByteSizeOf(shape)));
  }
  // The result retains `array`, keeping the returned data pointer valid for
  // the caller.
  return CastToArrayResult{array, static_cast<const char*>(array.data()),
                           shape};
}
344
345 } // namespace xla
346