/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/types.h"

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/python/exceptions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/python/lib/core/bfloat16.h"

namespace xla {

namespace py = pybind11;

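// Converts a NumPy dtype to the corresponding XLA PrimitiveType, keyed on the
// dtype's (kind, itemsize) pair. Returns InvalidArgument for dtypes with no
// XLA equivalent. Illustrative (hypothetical) use:
//   TF_ASSIGN_OR_RETURN(PrimitiveType t,
//                       DtypeToPrimitiveType(py::dtype::of<float>()));  // F32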
xla::StatusOr<PrimitiveType> DtypeToPrimitiveType(const py::dtype& np_type) {
  static auto* types =
      new absl::flat_hash_map<std::pair<char, int>, PrimitiveType>({
          {{'b', 1}, PRED},
          {{'i', 1}, S8},
          {{'i', 2}, S16},
          {{'i', 4}, S32},
          {{'i', 8}, S64},
          {{'u', 1}, U8},
          {{'u', 2}, U16},
          {{'u', 4}, U32},
          {{'u', 8}, U64},
          {{'V', 2}, BF16},  // array protocol code for raw data (void*)
          {{'f', 2}, F16},
          {{'f', 4}, F32},
          {{'f', 8}, F64},
          {{'c', 8}, C64},
          {{'c', 16}, C128},
      });
  auto it = types->find({np_type.kind(), np_type.itemsize()});
  if (it == types->end()) {
    return InvalidArgument("Unknown NumPy type %c size %d", np_type.kind(),
                           np_type.itemsize());
  }
  return it->second;
}

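// Converts an XLA PrimitiveType to the corresponding NumPy dtype. BF16 maps to
// the custom bfloat16 dtype registered by tensorflow::Bfloat16Dtype().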
xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
  switch (type) {
    case PRED:
      return py::dtype::of<bool>();
    case S8:
      return py::dtype::of<int8_t>();
    case S16:
      return py::dtype::of<int16_t>();
    case S32:
      return py::dtype::of<int32_t>();
    case S64:
      return py::dtype::of<int64_t>();
    case U8:
      return py::dtype::of<uint8_t>();
    case U16:
      return py::dtype::of<uint16_t>();
    case U32:
      return py::dtype::of<uint32_t>();
    case U64:
      return py::dtype::of<uint64_t>();
    case BF16: {
      py::handle bfloat16(tensorflow::Bfloat16Dtype());
      return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
    }
    case F16:
      return py::dtype("e");  // PEP 3118 code for "float16"
    case F32:
      return py::dtype::of<float>();
    case F64:
      return py::dtype::of<double>();
    case C64:
      return py::dtype::of<std::complex<float>>();
    case C128:
      return py::dtype::of<std::complex<double>>();
    default:
      return Unimplemented("Unimplemented primitive type %s",
                           PrimitiveType_Name(type));
  }
}

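// Returns a process-wide singleton holding the NumPy scalar type objects
// (np.int8, np.float32, the bfloat16 extension type, ...) used by the
// bindings. The singleton is heap-allocated once and never freed.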
const NumpyScalarTypes& GetNumpyScalarTypes() {
  static const NumpyScalarTypes* singleton = []() {
    NumpyScalarTypes* dtypes = new NumpyScalarTypes();
    const auto numpy = py::module::import("numpy");
    dtypes->np_bool = py::object(numpy.attr("bool_"));
    dtypes->np_int8 = py::object(numpy.attr("int8"));
    dtypes->np_int16 = py::object(numpy.attr("int16"));
    dtypes->np_int32 = py::object(numpy.attr("int32"));
    dtypes->np_int64 = py::object(numpy.attr("int64"));
    dtypes->np_uint8 = py::object(numpy.attr("uint8"));
    dtypes->np_uint16 = py::object(numpy.attr("uint16"));
    dtypes->np_uint32 = py::object(numpy.attr("uint32"));
    dtypes->np_uint64 = py::object(numpy.attr("uint64"));
    dtypes->np_bfloat16 =
        py::reinterpret_borrow<py::object>(tensorflow::Bfloat16Dtype());
    dtypes->np_float16 = py::object(numpy.attr("float16"));
    dtypes->np_float32 = py::object(numpy.attr("float32"));
    dtypes->np_float64 = py::object(numpy.attr("float64"));
    dtypes->np_complex64 = py::object(numpy.attr("complex64"));
    dtypes->np_complex128 = py::object(numpy.attr("complex128"));
    dtypes->np_longlong = py::object(numpy.attr("longlong"));
    dtypes->np_intc = py::object(numpy.attr("intc"));
    return dtypes;
  }();
  return *singleton;
}

// Returns a numpy-style format descriptor string for `type`.
StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type) {
  // We use an "=" prefix to indicate that we prefer "standard" types like
  // np.int32 rather than "native" types like np.cint. pybind11 does not qualify
  // its format descriptors.
  switch (type) {
    case PRED:
      return std::string("?");
    case S8:
      return std::string("=b");
    case S16:
      return std::string("=h");
    case S32:
      return std::string("=i");
    case S64:
      return std::string("=q");
    case U8:
      return std::string("=B");
    case U16:
      return std::string("=H");
    case U32:
      return std::string("=I");
    case U64:
      return std::string("=Q");
    case F16:
      return std::string("=e");
    case F32:
      return std::string("=f");
    case F64:
      return std::string("=d");
    case C64:
      return std::string("=Zf");
    case C128:
      return std::string("=Zd");
    default:
      return Unimplemented("Unimplemented primitive type %s",
                           PrimitiveType_Name(type));
  }
}

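// Returns a NumPy array-interface style typestr (e.g. "<f4") for `type`. Only
// little-endian hosts are supported; BF16 is described as two bytes of raw
// data ("<V2").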
StatusOr<py::str> TypeDescriptorForPrimitiveType(PrimitiveType type) {
  static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__,
                "Big endian support not implemented");
  switch (type) {
    case PRED:
      return py::str("|b1");
    case S8:
      return py::str("|i1");
    case S16:
      return py::str("<i2");
    case S32:
      return py::str("<i4");
    case S64:
      return py::str("<i8");
    case U8:
      return py::str("|u1");
    case U16:
      return py::str("<u2");
    case U32:
      return py::str("<u4");
    case U64:
      return py::str("<u8");
    case BF16:
      return py::str("<V2");
    case F16:
      return py::str("<f2");
    case F32:
      return py::str("<f4");
    case F64:
      return py::str("<f8");
    case C64:
      return py::str("<c8");
    case C128:
      return py::str("<c16");
    default:
      return Unimplemented("Unimplemented primitive type %s",
                           PrimitiveType_Name(type));
  }
}

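// Narrows 64-bit types to their 32-bit counterparts (S64->S32, U64->U32,
// F64->F32, C128->C64); all other types are returned unchanged.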
PrimitiveType Squash64BitTypes(PrimitiveType type) {
  switch (type) {
    case S64:
      return S32;
    case U64:
      return U32;
    case F64:
      return F32;
    case C128:
      return C64;
    default:
      return type;
  }
}

// Returns the strides for `shape`.
std::vector<ssize_t> ByteStridesForShape(const Shape& shape) {
  std::vector<ssize_t> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  ssize_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

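// Identical to ByteStridesForShape, but returns int64_t strides. For example,
// a dense row-major F32 array of shape [3, 4] has byte strides [16, 4].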
std::vector<int64_t> ByteStridesForShapeInt64(const Shape& shape) {
  std::vector<int64_t> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

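// Converts a Literal to a NumPy array, or to a (possibly nested) Python tuple
// of arrays for tuple shapes. The arrays alias the literal's buffers; the
// literal is kept alive by attaching it as the array's base object.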
StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
  xla::Literal& m = *literal;
  if (m.shape().IsTuple()) {
    std::vector<Literal> elems = m.DecomposeTuple();
    std::vector<py::object> arrays(elems.size());
    for (int i = 0; i < elems.size(); ++i) {
      TF_ASSIGN_OR_RETURN(
          arrays[i],
          LiteralToPython(std::make_unique<Literal>(std::move(elems[i]))));
    }
    py::tuple result(elems.size());
    for (int i = 0; i < elems.size(); ++i) {
      PyTuple_SET_ITEM(result.ptr(), i, arrays[i].release().ptr());
    }
    return result;
  }
  TF_RET_CHECK(m.shape().IsArray());

  py::object literal_object = py::cast(literal);
  TF_ASSIGN_OR_RETURN(py::dtype dtype,
                      PrimitiveTypeToDtype(m.shape().element_type()));
  return py::array(dtype, m.shape().dimensions(),
                   ByteStridesForShape(m.shape()), m.untyped_data(),
                   literal_object);
}

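// Flattens a Python value (an array-like, or an arbitrarily nested tuple of
// array-likes) into a PythonBufferTree: the overall XLA shape plus the leaf
// BorrowingLiterals and the NumPy arrays backing them.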
StatusOr<PythonBufferTree> GetPythonBufferTree(const py::object& argument) {
  PythonBufferTree tree;
  if (py::isinstance<py::tuple>(argument)) {
    py::tuple tuple = py::reinterpret_borrow<py::tuple>(argument);
    std::vector<Shape> host_shapes(tuple.size());
    for (int i = 0; i < host_shapes.size(); ++i) {
      TF_ASSIGN_OR_RETURN(PythonBufferTree subtree,
                          GetPythonBufferTree(tuple[i]));
      tree.leaves.reserve(tree.leaves.size() + subtree.leaves.size());
      std::move(subtree.leaves.begin(), subtree.leaves.end(),
                std::back_inserter(tree.leaves));
      tree.arrays.reserve(tree.arrays.size() + subtree.arrays.size());
      std::move(subtree.arrays.begin(), subtree.arrays.end(),
                std::back_inserter(tree.arrays));
      host_shapes[i] = std::move(subtree.shape);
    }
    tree.shape = ShapeUtil::MakeTupleShape(host_shapes);
  } else {
    pybind11::detail::type_caster<BorrowingLiteral> caster;
    if (!caster.load(argument, /*convert=*/true)) {
      return InvalidArgument("Invalid array value.");
    }
    DCHECK_EQ(caster.arrays.size(), 1);
    tree.arrays.push_back(std::move(caster.arrays.front()));
    tree.leaves.push_back(std::move(*caster));
    tree.shape = tree.leaves.front().shape();
  }
  return tree;
}

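// Shared helper for the integer SpanToTuple specializations below.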
template <typename IntType>
static py::tuple IntSpanToTupleHelper(absl::Span<IntType const> xs) {
  py::tuple out(xs.size());
  for (int i = 0; i < xs.size(); ++i) {
    out[i] = py::int_(xs[i]);
  }
  return out;
}

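// Converts a span of C++ integers into a Python tuple of ints.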
template <>
pybind11::tuple SpanToTuple(absl::Span<int const> xs) {
  return IntSpanToTupleHelper(xs);
}
template <>
pybind11::tuple SpanToTuple(absl::Span<int64_t const> xs) {
  return IntSpanToTupleHelper(xs);
}

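// Converts `h` to a C-contiguous, aligned NumPy array and returns it together
// with a pointer to its raw data and the equivalent XLA Shape. Returns
// std::nullopt if `h` is not convertible to an array; throws XlaRuntimeError
// for unsupported dtypes or buffer size mismatches.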
std::optional<CastToArrayResult> CastToArray(py::handle h) {
  py::array array = py::array::ensure(
      h, py::array::c_style | py::detail::npy_api::NPY_ARRAY_ALIGNED_);
  if (!array) {
    return std::nullopt;
  }
  auto type_or_status = DtypeToPrimitiveType(array.dtype());
  if (!type_or_status.ok()) {
    throw xla::XlaRuntimeError(type_or_status.status());
  }
  PrimitiveType type = type_or_status.ValueOrDie();

  absl::InlinedVector<int64_t, 4> dims(array.ndim());
  for (int i = 0; i < array.ndim(); ++i) {
    dims[i] = array.shape(i);
  }
  Shape shape = ShapeUtil::MakeShape(type, dims);
  if (array.size() * array.itemsize() != ShapeUtil::ByteSizeOf(shape)) {
    throw xla::XlaRuntimeError(absl::StrCat(
        "Size mismatch for buffer: ", array.size() * array.itemsize(), " vs. ",
        ShapeUtil::ByteSizeOf(shape)));
  }
  return CastToArrayResult{array, static_cast<const char*>(array.data()),
                           shape};
}

}  // namespace xla