/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_values.h"

#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/sharded_device_array.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/python/lib/core/numpy.h"

namespace py = pybind11;

namespace xla {

namespace {

using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
    py::handle, PjRtDevice*, const DevicePutOptions& options)>;

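// Converts a Python scalar (bool, float, or complex) into a zero-dimensional
// device buffer. T is the full-precision native type; SquashedT is the
// narrower type used when options.squash_64bit_types is set (e.g. double is
// stored as float).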
template <typename T, typename SquashedT>
StatusOr<DevicePutResult> HandlePythonScalar(py::handle obj,
                                             PjRtDevice* to_device,
                                             const DevicePutOptions& options) {
  T data;

  try {
    data = py::cast<T>(obj);
  } catch (const std::exception& e) {
    return InvalidArgument(
        "Unable to convert Python scalar to %s. This most likely means the "
        "value (%s) overflows the range of the type.",
        PrimitiveType_Name(primitive_util::NativeToPrimitiveType<T>()),
        py::repr(obj));
  }

  void* ptr;
  SquashedT squashed_data;
  Shape shape;
  PrimitiveType type;
  if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
    ptr = &data;
    type = primitive_util::NativeToPrimitiveType<T>();
  } else {
    // TODO(phawkins): we should check for overflow here, e.g., because of bugs
    // like https://github.com/google/jax/issues/2006
    squashed_data = static_cast<SquashedT>(data);
    ptr = &squashed_data;
    type = primitive_util::NativeToPrimitiveType<SquashedT>();
  }
  // Must release the GIL before BufferFromHostBuffer because backends may
  // decide to block/sleep for device buffer allocation.
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/true);
}

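// Python ints are handled separately from the template above because they are
// arbitrary precision: when squashing, we cast directly to int32_t so that a
// value that fits in int64_t but not int32_t is reported as an overflow rather
// than silently truncated (see the TODO in HandlePythonScalar).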
StatusOr<DevicePutResult> HandlePythonInt(py::handle obj, PjRtDevice* to_device,
                                          const DevicePutOptions& options) {
  void* ptr;
  PrimitiveType type;
  int64_t data_int64;
  int32_t data_int32;

  if (options.squash_64bit_types) {
    try {
      data_int32 = py::cast<int32_t>(obj);
    } catch (const std::exception& e) {
      return InvalidArgument(
          "Unable to convert Python scalar to %s. This most likely means the "
          "value (%s) overflows the range of the type.",
          PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int32_t>()),
          py::repr(obj));
    }
    ptr = &data_int32;
    type = S32;
  } else {
    try {
      data_int64 = py::cast<int64_t>(obj);
    } catch (const std::exception& e) {
      return InvalidArgument(
          "Unable to convert Python scalar to %s. This most likely means the "
          "value (%s) overflows the range of the type.",
          PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int64_t>()),
          py::repr(obj));
    }
    ptr = &data_int64;
    type = S64;
  }
  // Must release the GIL before BufferFromHostBuffer because backends may
  // decide to block/sleep for device buffer allocation.
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/true);
}

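// Converts a NumPy scalar (e.g. np.int32(1) or np.float64(2.0)) into a
// zero-dimensional device buffer. The value is read with
// PyArray_ScalarAsCtype; bfloat16 is special-cased because, as an extension
// type, ScalarAsCtype yields a pointer to the data rather than the value
// itself.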
template <typename T, typename SquashedT = T>
StatusOr<DevicePutResult> HandleNumpyScalar(py::handle h, PjRtDevice* to_device,
                                            const DevicePutOptions& options) {
  T data;
  SquashedT data_squashed;
  void* ptr;
  PrimitiveType type;
  if (std::is_same<T, bfloat16>()) {
    // For extension types, ScalarAsCtype returns a pointer to the data.
    PyArray_ScalarAsCtype(h.ptr(), &ptr);
    type = BF16;
  } else if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
    PyArray_ScalarAsCtype(h.ptr(), &data);
    ptr = &data;
    type = primitive_util::NativeToPrimitiveType<T>();
  } else {
    PyArray_ScalarAsCtype(h.ptr(), &data);
    data_squashed = static_cast<SquashedT>(data);
    ptr = &data_squashed;
    type = primitive_util::NativeToPrimitiveType<SquashedT>();
  }
  // Must release the GIL before BufferFromHostBuffer because backends may
  // decide to block/sleep for device buffer allocation.
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer> buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}

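// Converts a NumPy ndarray into a device buffer. 64-bit dtypes are cast to
// their 32-bit counterparts when squashing is requested. If zero-copy is
// allowed, the host copy is skipped (kZeroCopy) and a reference to the array
// is kept alive until the runtime no longer needs the host buffer.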
StatusOr<DevicePutResult> HandleNumpyArray(py::handle h, PjRtDevice* to_device,
                                           const DevicePutOptions& options) {
  py::array array = py::cast<py::array>(h);
  TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));

  PrimitiveType squashed_type;
  if (options.squash_64bit_types) {
    squashed_type = Squash64BitTypes(type);
    if (squashed_type != type) {
      TF_ASSIGN_OR_RETURN(py::dtype squashed_dtype,
                          PrimitiveTypeToDtype(squashed_type));
      array = py::reinterpret_steal<py::array>(PyArray_CastToType(
          reinterpret_cast<PyArrayObject*>(array.ptr()),
          reinterpret_cast<PyArray_Descr*>(squashed_dtype.release().ptr()),
          /*fortran=*/0));
    }
  } else {
    squashed_type = type;
  }

  absl::InlinedVector<int64_t, 4> dims(array.ndim());
  absl::InlinedVector<int64_t, 4> byte_strides(array.ndim());
  for (int i = 0; i < array.ndim(); ++i) {
    dims[i] = array.shape(i);
    byte_strides[i] = array.strides(i);
  }
  const void* data = array.data();
  PjRtClient::HostBufferSemantics host_buffer_semantics =
      PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall;
  std::function<void()> on_done_with_host_buffer;
  if (options.allow_zero_copy) {
    std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
        GlobalPyRefManager()->ManageReference(std::move(array));
    on_done_with_host_buffer =
        [py_buffer_ref{
            std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
    host_buffer_semantics = PjRtClient::HostBufferSemantics::kZeroCopy;
  }
  // Must release the GIL before BufferFromHostBuffer because backends may
  // decide to block/sleep for device buffer allocation.
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          data, squashed_type, dims, byte_strides, host_buffer_semantics,
          std::move(on_done_with_host_buffer), to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}

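// Reuses an existing PyBuffer: if it already resides on to_device, the
// underlying PjRtBuffer is returned directly (keeping the Python object
// alive); otherwise the buffer is copied to the target device.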
StatusOr<DevicePutResult> PyBufferHelper(py::handle obj, py::handle py_buffer,
                                         PyBuffer* buffer,
                                         PjRtDevice* to_device) {
  bool weak_type = buffer->weak_type()
                       ? *buffer->weak_type()
                       : py::cast<bool>(obj.attr("aval").attr("weak_type"));
  if (buffer->buffer()->device() == to_device) {
    return DevicePutResult(
        buffer->buffer(), weak_type,
        /*owning_pybuffer=*/py::reinterpret_borrow<py::object>(py_buffer));
  } else {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> copied_buffer,
                        buffer->buffer()->CopyToDevice(to_device));
    return DevicePutResult(std::move(copied_buffer), weak_type);
  }
}

StatusOr<DevicePutResult> HandlePyBuffer(py::handle obj, PjRtDevice* to_device,
                                         const DevicePutOptions& options) {
  return PyBufferHelper(obj, obj, PyBuffer::AsPyBufferUnchecked(obj),
                        to_device);
}

StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
                                            PjRtDevice* to_device,
                                            const DevicePutOptions& options) {
  // Handle Python DeviceArray objects provided they have a .device_buffer
  // field. Otherwise, fall back to handling as a NumPy array, since we do not
  // understand how to get a buffer object out. For example, ShardedDeviceArray
  // in JAX is handled by this path.
  py::object buffer = py::getattr(obj, "device_buffer", py::none());
  if (buffer.is_none()) {
    return HandleNumpyArray(obj, to_device, options);
  }

  return PyBufferHelper(obj, buffer, py::cast<PyBuffer*>(buffer), to_device);
}

}  // namespace

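// Dispatches on the concrete Python type of `arg` via a lazily initialized
// handler table. PyBuffer takes a dedicated fast path; types without a direct
// entry fall back to a lookup along their __mro__.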
StatusOr<DevicePutResult> DevicePut(py::handle arg, PjRtDevice* to_device,
                                    const DevicePutOptions& options) {
  tensorflow::profiler::TraceMe traceme("DevicePut");
  static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
      [] {
        auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
        // Python scalar types.
        static_assert(sizeof(bool) == 1,
                      "Conversion code assumes bool is 1 byte");
        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] =
            HandlePythonScalar<bool, bool>;
        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandlePythonInt;
        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] =
            HandlePythonScalar<double, float>;
        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
            HandlePythonScalar<complex128, complex64>;

        // Generic subclasses of DeviceArray, e.g., ShardedDeviceArray.
        (*p)[PyBuffer::base_type()] = HandleDeviceArray;

        try {
          py::object xla_module = py::module::import("jax.interpreters.xla");
          py::object device_array =
              py::getattr(xla_module, "_DeviceArray", py::none());
          if (!device_array.is_none()) {
            (*p)[device_array.ptr()] = HandleDeviceArray;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        try {
          py::object pxla_module = py::module::import("jax.interpreters.pxla");
          py::object sda =
              py::getattr(pxla_module, "ShardedDeviceArray", py::none());
          if (!sda.is_none()) {
            (*p)[sda.ptr()] = HandleDeviceArray;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        const auto numpy = py::module::import("numpy");
        (*p)[numpy.attr("ndarray").ptr()] = HandleNumpyArray;

        // Numpy scalar types. For some of them, we share the handler with
        // Python types (np_int64, np_float64, np_complex128).
        (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar<bool>;
        (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar<int8_t>;
        (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar<int16_t>;
        (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar<int32_t>;
        (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
        (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar<uint8_t>;
        (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
        (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar<uint32_t>;
        (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar<uint64_t, uint32_t>;
        (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar<bfloat16>;
        (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar<half>;
        (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar<float>;
        (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar<double, float>;
        (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar<complex64>;
        (*p)[dtypes.np_complex128.ptr()] =
            HandleNumpyScalar<complex128, complex64>;
        static_assert(sizeof(long long) == sizeof(int64_t),  // NOLINT
                      "long long must be the same size as int64_t");
        (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
        static_assert(sizeof(int) == sizeof(int32_t),
                      "int must be the same size as int32_t");
        (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar<int32_t>;

        return p;
      }();

  // Fast-path for the most common case of PyBuffer.
  if (arg.get_type().ptr() == PyBuffer::type()) {
    return HandlePyBuffer(arg, to_device, options);
  }

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    for (auto base_class : arg.get_type().attr("__mro__")) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, to_device, options);
      }
    }
    return InvalidArgument(
        "%s",
        absl::StrCat("Not supported: The C++ jax jit execution path only "
                     "accepts DeviceArray, NumPy arrays, scalars of supported "
                     "types (see implementation), or Python scalars. Got type ",
                     py::cast<std::string>(py::str(arg.get_type()))));
  }
  return res->second(arg, to_device, options);
}

bool IsFloat0(py::array arg) {
  static const auto* dtypes_module =
      new py::module(py::module::import("jax.dtypes"));
  static const auto* float0_dtype =
      new py::handle(dtypes_module->attr("float0"));
  return float0_dtype->is(arg.attr("dtype"));
}

std::string PyArgSignature::DebugString() const {
  std::string result = "";
  if (weak_type) {
    absl::StrAppend(&result, "weak_");
  }
  absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
  absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
  return result;
}

using ToPyArgSignatureHandler =
    std::function<StatusOr<PyArgSignature>(py::handle, bool)>;

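// Computes the (dtype, shape, weak_type) signature of `arg` without moving any
// data to the device, mirroring the dispatch structure used by DevicePut.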
StatusOr<PyArgSignature> PyArgSignatureOfValue(py::handle arg,
                                               bool jax_enable_x64) {
  static const absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>* const
      handlers = [] {
        auto p = new absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>();

        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();

        // The 4 Python native types.
        ToPyArgSignatureHandler bool_handler =
            [](py::handle, bool) -> StatusOr<PyArgSignature> {
          return PyArgSignature(PrimitiveType::PRED, {}, true);
        };
        ToPyArgSignatureHandler int_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // TODO(phawkins): we should consider checking for integer overflow.
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::S64, {}, true);
          } else {
            return PyArgSignature(PrimitiveType::S32, {}, true);
          }
        };
        ToPyArgSignatureHandler float_handler =
            [&dtypes](py::handle h,
                      bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // Only Python native types have a True weak_type.
          bool weak_type = !py::isinstance(h, dtypes.np_float64);
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::F64, {}, weak_type);
          } else {
            return PyArgSignature(PrimitiveType::F32, {}, weak_type);
          }
        };
        ToPyArgSignatureHandler complex_handler =
            [&dtypes](py::handle h,
                      bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // Note that this branch is also taken for np.complex128:
          // isinstance(np.complex128(3), complex) returns True
          // isinstance(np.complex64(3), complex) returns False
          bool weak_type = !py::isinstance(h, dtypes.np_complex128);
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::C128, {}, weak_type);
          } else {
            return PyArgSignature(PrimitiveType::C64, {}, weak_type);
          }
        };

        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = bool_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = int_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;

        // The Buffer types except for fast-path PyBuffer.
        ToPyArgSignatureHandler device_array_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          py::handle aval = h.attr("aval");
          TF_ASSIGN_OR_RETURN(auto dtype,
                              DtypeToPrimitiveType(aval.attr("dtype")));
          return PyArgSignature(
              dtype, py::cast<std::vector<int64_t>>(aval.attr("shape")),
              py::cast<py::bool_>(aval.attr("weak_type")));
        };
        (*p)[PyBuffer::base_type()] = device_array_handler;

        try {
          py::object xla_module = py::module::import("jax.interpreters.xla");
          py::object device_array =
              py::getattr(xla_module, "_DeviceArray", py::none());
          if (!device_array.is_none()) {
            (*p)[device_array.ptr()] = device_array_handler;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        try {
          py::object pxla_module = py::module::import("jax.interpreters.pxla");
          py::object sda =
              py::getattr(pxla_module, "ShardedDeviceArray", py::none());
          if (!sda.is_none()) {
            (*p)[sda.ptr()] = device_array_handler;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        ToPyArgSignatureHandler numpy_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          py::array numpy_array = py::cast<py::array>(h);
          TF_ASSIGN_OR_RETURN(PrimitiveType dtype,
                              DtypeToPrimitiveType(numpy_array.dtype()));
          if (!jax_enable_x64) {
            dtype = Squash64BitTypes(dtype);
          }
          // We use reinterpret_cast<> to defend against environments where
          // ssize_t may not be precisely the same type as int64_t, even if it
          // is the same size (long vs long long).
          static_assert(sizeof(int64_t) == sizeof(ssize_t),
                        "Code assumes ssize_t is the same as int64_t");
          return PyArgSignature(
              dtype,
              absl::MakeConstSpan(
                  reinterpret_cast<const int64_t*>(numpy_array.shape()),
                  numpy_array.ndim()),
              /*weak_type=*/false);
        };
        const auto numpy = py::module::import("numpy");
        (*p)[numpy.attr("ndarray").ptr()] = numpy_handler;

        ToPyArgSignatureHandler np_uint64_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
          } else {
            return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
          }
        };
        ToPyArgSignatureHandler np_int_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
          } else {
            return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
          }
        };
        ToPyArgSignatureHandler numpy_array_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // Handles the NumPy scalar types whose signature does not depend on
          // jax_enable_x64; the 64-bit types (np_int64, np_uint64, np_float64,
          // np_complex128, np_longlong) use the dedicated handlers above.
          TF_ASSIGN_OR_RETURN(auto dtype,
                              DtypeToPrimitiveType(h.attr("dtype")));
          return PyArgSignature(dtype, {}, /*weak_type=*/false);
        };

        // NumPy scalar types. The 64-bit ones are mapped to the handlers above
        // so they are squashed to 32 bits when jax_enable_x64 is false.
        (*p)[dtypes.np_bool.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_int8.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_int16.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_int32.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_int64.ptr()] = np_int_handler;
        (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
        (*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_float64.ptr()] = float_handler;
        (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_complex128.ptr()] = complex_handler;
        (*p)[dtypes.np_longlong.ptr()] = np_int_handler;
        (*p)[dtypes.np_intc.ptr()] = numpy_array_handler;

        return p;
      }();

  // Fast-path for the most common case of PyBuffer.
  if (arg.get_type().ptr() == PyBuffer::type()) {
    TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(arg));
    bool weak_type = buffer->weak_type().has_value()
                         ? *buffer->weak_type()
                         : py::cast<bool>(arg.attr("aval").attr("weak_type"));
    return PyArgSignature(buffer->buffer()->on_device_shape().element_type(),
                          buffer->buffer()->on_device_shape().dimensions(),
                          weak_type);
  }

  // Fast-path for ShardedDeviceArray.
  if (jax::ShardedDeviceArray::IsShardedDeviceArray(arg)) {
    jax::ShardedDeviceArray* sda =
        jax::ShardedDeviceArray::AsShardedDeviceArrayUnchecked(arg);

    // TODO(jblespiau): See if we can be faster not accessing the aval attribute
    // and storing these directly.
    py::handle aval = arg.attr("aval");
    TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(aval.attr("dtype")));
    return PyArgSignature(dtype,
                          py::cast<std::vector<int64_t>>(aval.attr("shape")),
                          sda->weak_type());
  }

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    // Fall back to searching the type's MRO for a registered handler.
    for (auto base_class : arg.get_type().attr("__mro__")) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, jax_enable_x64);
      }
    }
    return InvalidArgument(
        "%s",
        absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts "
                     "Buffer/DeviceArray/ShardedDeviceArray, NumPy arrays, "
                     "scalars of supported types "
                     "(see implementation), or Python scalars. Got type ",
                     py::cast<std::string>(py::str(arg.get_type()))));
  }
  return res->second(arg, jax_enable_x64);
}

}  // namespace xla