/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_values.h"

#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/sharded_device_array.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/python/lib/core/numpy.h"

namespace py = pybind11;

namespace xla {

namespace {

using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
    py::handle, PjRtDevice*, const DevicePutOptions& options)>;

template <typename T, typename SquashedT>
StatusOr<DevicePutResult> HandlePythonScalar(py::handle obj,
                                             PjRtDevice* to_device,
                                             const DevicePutOptions& options) {
  T data;

  try {
    data = py::cast<T>(obj);
  } catch (const std::exception& e) {
    return InvalidArgument(
        "Unable to convert Python scalar to %s. This most likely means the "
        "value (%s) overflows the range of the type.",
        PrimitiveType_Name(primitive_util::NativeToPrimitiveType<T>()),
        py::repr(obj));
  }

  void* ptr;
  SquashedT squashed_data;
  Shape shape;
  PrimitiveType type;
  if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
    ptr = &data;
    type = primitive_util::NativeToPrimitiveType<T>();
  } else {
    // TODO(phawkins): we should check for overflow here, e.g., because of bugs
    // like https://github.com/google/jax/issues/2006
    squashed_data = static_cast<SquashedT>(data);
    ptr = &squashed_data;
    type = primitive_util::NativeToPrimitiveType<SquashedT>();
  }
  // Must release the GIL before BufferFromHostBuffer because backends may
  // decide to block/sleep for device buffer allocation.
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/true);
}
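
// Illustrative sketch (not part of the original file): with 64-bit squashing
// enabled (i.e. when jax_enable_x64 is off), a Python float is first cast to
// a C++ double and then narrowed to float, so it reaches the device as an F32
// scalar with a weak type:
//
//   DevicePutOptions options;
//   options.squash_64bit_types = true;
//   // For a py::handle obj wrapping the Python value 3.5:
//   auto result = HandlePythonScalar<double, float>(obj, device, options);
//   // On success, the resulting buffer is an F32 scalar and weak_type is
//   // true; no overflow check is performed on the narrowing cast (see the
//   // TODO above).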

StatusOr<DevicePutResult> HandlePythonInt(py::handle obj, PjRtDevice* to_device,
                                          const DevicePutOptions& options) {
  void* ptr;
  PrimitiveType type;
  int64_t data_int64;
  int32_t data_int32;

  if (options.squash_64bit_types) {
    try {
      data_int32 = py::cast<int32_t>(obj);
    } catch (const std::exception& e) {
      return InvalidArgument(
          "Unable to convert Python scalar to %s. This most likely means the "
          "value (%s) overflows the range of the type.",
          PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int32_t>()),
          py::repr(obj));
    }
    ptr = &data_int32;
    type = S32;
  } else {
    try {
      data_int64 = py::cast<int64_t>(obj);
    } catch (const std::exception& e) {
      return InvalidArgument(
          "Unable to convert Python scalar to %s. This most likely means the "
          "value (%s) overflows the range of the type.",
          PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int64_t>()),
          py::repr(obj));
    }
    ptr = &data_int64;
    type = S64;
  }
  // Must release the GIL before BufferFromHostBuffer because backends may
  // decide to block/sleep for device buffer allocation.
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/true);
}
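
// Illustrative sketch (not part of the original file): when 64-bit types are
// squashed, a Python int that does not fit in int32 is rejected rather than
// silently truncated. For example, with options.squash_64bit_types == true:
//
//   2**31 - 1  -> S32 scalar, weak_type=true
//   2**40      -> InvalidArgument("Unable to convert Python scalar to S32...")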

template <typename T, typename SquashedT = T>
StatusOr<DevicePutResult> HandleNumpyScalar(py::handle h, PjRtDevice* to_device,
                                            const DevicePutOptions& options) {
  T data;
  SquashedT data_squashed;
  void* ptr;
  PrimitiveType type;
  if (std::is_same<T, bfloat16>()) {
    // For extension types, ScalarAsCtype returns a pointer to the data.
    PyArray_ScalarAsCtype(h.ptr(), &ptr);
    type = BF16;
  } else if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
    PyArray_ScalarAsCtype(h.ptr(), &data);
    ptr = &data;
    type = primitive_util::NativeToPrimitiveType<T>();
  } else {
    PyArray_ScalarAsCtype(h.ptr(), &data);
    data_squashed = static_cast<SquashedT>(data);
    ptr = &data_squashed;
    type = primitive_util::NativeToPrimitiveType<SquashedT>();
  }
  // Must release the GIL before BufferFromHostBuffer because backends may
  // decide to block/sleep for device buffer allocation.
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer> buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}
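
// Illustrative note (not part of the original file): bfloat16 is a NumPy
// extension type, so PyArray_ScalarAsCtype hands back a pointer to the
// scalar's storage rather than copying the value into a local; that is why
// the bfloat16 branch assigns `ptr` directly. Unlike Python scalars, NumPy
// scalars carry an explicit dtype, so the result is strongly typed, e.g.
// np.float32(1.0) -> F32 with weak_type=false.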

StatusOr<DevicePutResult> HandleNumpyArray(py::handle h, PjRtDevice* to_device,
                                           const DevicePutOptions& options) {
  py::array array = py::cast<py::array>(h);
  TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));

  PrimitiveType squashed_type;
  if (options.squash_64bit_types) {
    squashed_type = Squash64BitTypes(type);
    if (squashed_type != type) {
      TF_ASSIGN_OR_RETURN(py::dtype squashed_dtype,
                          PrimitiveTypeToDtype(squashed_type));
      array = py::reinterpret_steal<py::array>(PyArray_CastToType(
          reinterpret_cast<PyArrayObject*>(array.ptr()),
          reinterpret_cast<PyArray_Descr*>(squashed_dtype.release().ptr()),
          /*fortran=*/0));
    }
  } else {
    squashed_type = type;
  }

  absl::InlinedVector<int64_t, 4> dims(array.ndim());
  absl::InlinedVector<int64_t, 4> byte_strides(array.ndim());
  for (int i = 0; i < array.ndim(); ++i) {
    dims[i] = array.shape(i);
    byte_strides[i] = array.strides(i);
  }
  const void* data = array.data();
  PjRtClient::HostBufferSemantics host_buffer_semantics =
      PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall;
  std::function<void()> on_done_with_host_buffer;
  if (options.allow_zero_copy) {
    std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
        GlobalPyRefManager()->ManageReference(std::move(array));
    on_done_with_host_buffer =
        [py_buffer_ref{
            std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
    host_buffer_semantics = PjRtClient::HostBufferSemantics::kZeroCopy;
  }
  // Must release the GIL before BufferFromHostBuffer because backends may
  // decide to block/sleep for device buffer allocation.
  py::gil_scoped_release gil_release;
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          data, squashed_type, dims, byte_strides, host_buffer_semantics,
          std::move(on_done_with_host_buffer), to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}
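
// Illustrative sketch (not part of the original file): allow_zero_copy lets
// the backend alias the NumPy array's host memory instead of copying it; the
// ManagedPyObjects reference keeps the ndarray alive until the runtime calls
// on_done_with_host_buffer:
//
//   DevicePutOptions options;
//   options.allow_zero_copy = true;
//   auto result = HandleNumpyArray(numpy_handle, device, options);
//   // With kZeroCopy semantics the device may keep reading the host buffer
//   // after this call returns; the closure above owns the array until then.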

StatusOr<DevicePutResult> PyBufferHelper(py::handle obj, py::handle py_buffer,
                                         PyBuffer* buffer,
                                         PjRtDevice* to_device) {
  bool weak_type = buffer->weak_type()
                       ? *buffer->weak_type()
                       : py::cast<bool>(obj.attr("aval").attr("weak_type"));
  if (buffer->buffer()->device() == to_device) {
    return DevicePutResult(
        buffer->buffer(), weak_type,
        /*owning_pybuffer=*/py::reinterpret_borrow<py::object>(py_buffer));
  } else {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> copied_buffer,
                        buffer->buffer()->CopyToDevice(to_device));
    return DevicePutResult(std::move(copied_buffer), weak_type);
  }
}
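
// Illustrative note (not part of the original file): when the buffer already
// lives on `to_device`, PyBufferHelper avoids any transfer and merely keeps
// the Python object alive through owning_pybuffer; only a cross-device put
// pays for CopyToDevice. The weak_type lookup falls back to the Python-level
// `aval.weak_type` attribute when the C++ PyBuffer does not cache the value.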

StatusOr<DevicePutResult> HandlePyBuffer(py::handle obj, PjRtDevice* to_device,
                                         const DevicePutOptions& options) {
  return PyBufferHelper(obj, obj, PyBuffer::AsPyBufferUnchecked(obj),
                        to_device);
}

StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
                                            PjRtDevice* to_device,
                                            const DevicePutOptions& options) {
  // Handle Python DeviceArray objects provided they have a .device_buffer
  // field. Otherwise, fall back to handling the object as a NumPy array,
  // since we do not know how to extract a buffer from it. For example,
  // ShardedDeviceArray in JAX is handled by this fallback path.
  py::object buffer = py::getattr(obj, "device_buffer", py::none());
  if (buffer.is_none()) {
    return HandleNumpyArray(obj, to_device, options);
  }

  return PyBufferHelper(obj, buffer, py::cast<PyBuffer*>(buffer), to_device);
}
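
// Illustrative note (not part of the original file): a
// jax.interpreters.xla._DeviceArray exposes its underlying PjRt buffer via
// `device_buffer` and therefore takes the PyBufferHelper path above, whereas
// an object registered for this handler without that attribute (such as a
// ShardedDeviceArray) is converted through HandleNumpyArray instead.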

}  // namespace

StatusOr<DevicePutResult> DevicePut(py::handle arg, PjRtDevice* to_device,
                                    const DevicePutOptions& options) {
  tensorflow::profiler::TraceMe traceme("DevicePut");
  static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
      [] {
        auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
        // Python scalar types.
        static_assert(sizeof(bool) == 1,
                      "Conversion code assumes bool is 1 byte");
        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] =
            HandlePythonScalar<bool, bool>;
        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandlePythonInt;
        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] =
            HandlePythonScalar<double, float>;
        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
            HandlePythonScalar<complex128, complex64>;

        // Generic subclasses of DeviceArray, e.g., ShardedDeviceArray.
        (*p)[PyBuffer::base_type()] = HandleDeviceArray;

        try {
          py::object xla_module = py::module::import("jax.interpreters.xla");
          py::object device_array =
              py::getattr(xla_module, "_DeviceArray", py::none());
          if (!device_array.is_none()) {
            (*p)[device_array.ptr()] = HandleDeviceArray;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        try {
          py::object pxla_module = py::module::import("jax.interpreters.pxla");
          py::object sda =
              py::getattr(pxla_module, "ShardedDeviceArray", py::none());
          if (!sda.is_none()) {
            (*p)[sda.ptr()] = HandleDeviceArray;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        const auto numpy = py::module::import("numpy");
        (*p)[numpy.attr("ndarray").ptr()] = HandleNumpyArray;

        // Numpy scalar types. For some of them, we share the handler with
        // Python types (np_int64, np_float64, np_complex128).
        (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar<bool>;
        (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar<int8_t>;
        (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar<int16_t>;
        (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar<int32_t>;
        (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
        (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar<uint8_t>;
        (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
        (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar<uint32_t>;
        (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar<uint64_t, uint32_t>;
        (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar<bfloat16>;
        (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar<half>;
        (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar<float>;
        (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar<double, float>;
        (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar<complex64>;
        (*p)[dtypes.np_complex128.ptr()] =
            HandleNumpyScalar<complex128, complex64>;
        static_assert(sizeof(long long) == sizeof(int64_t),  // NOLINT
                      "long long must be the same size as int64_t");
        (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
        static_assert(sizeof(int) == sizeof(int32_t),
                      "int must be the same size as int32_t");
        (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar<int32_t>;

        return p;
      }();

  // Fast-path for the most common case of PyBuffer.
  if (arg.get_type().ptr() == PyBuffer::type()) {
    return HandlePyBuffer(arg, to_device, options);
  }

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    for (auto base_class : arg.get_type().attr("__mro__")) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, to_device, options);
      }
    }
    return InvalidArgument(
        "%s", absl::StrCat(
                  "Not supported: The C++ jax jit execution path only accepts "
                  "DeviceArray, Numpy arrays, scalars of supported types "
                  "(see implementation), or Python scalars. Got type ",
                  py::cast<std::string>(py::str(arg.get_type()))));
  }
  return res->second(arg, to_device, options);
}
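
// Illustrative usage sketch (not part of the original file; names follow the
// declarations in py_values.h):
//
//   DevicePutOptions options;
//   options.squash_64bit_types = !jax_enable_x64;
//   options.allow_zero_copy = true;
//   TF_ASSIGN_OR_RETURN(DevicePutResult put,
//                       DevicePut(arg_handle, device, options));
//
// Dispatch order: the exact PyBuffer type first, then an exact hit in the
// handler map, then each class in the argument's __mro__, and finally an
// InvalidArgument error if nothing matches.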

bool IsFloat0(py::array arg) {
  static const auto* dtypes_module =
      new py::module(py::module::import("jax.dtypes"));
  static const auto* float0_dtype =
      new py::handle(dtypes_module->attr("float0"));
  return float0_dtype->is(arg.attr("dtype"));
}
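
// Illustrative note (not part of the original file): jax.dtypes.float0 is the
// dtype JAX uses for tangents of non-differentiable (e.g. integer or boolean)
// values, so this predicate should return true only for arrays created with
// that exact dtype object.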

std::string PyArgSignature::DebugString() const {
  std::string result = "";
  if (weak_type) {
    absl::StrAppend(&result, "weak_");
  }
  absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
  absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
  return result;
}
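
// Example output (not part of the original file): a weakly typed S32 scalar
// renders as "weak_S32[]", and a non-weak F32 array of shape (2, 3) renders
// as "F32[2,3]".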

using ToPyArgSignatureHandler =
    std::function<StatusOr<PyArgSignature>(py::handle, bool)>;

StatusOr<PyArgSignature> PyArgSignatureOfValue(py::handle arg,
                                               bool jax_enable_x64) {
  static const absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>* const
      handlers = [] {
        auto p = new absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>();

        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();

        // The 4 Python native types.
        ToPyArgSignatureHandler bool_handler =
            [](py::handle, bool) -> StatusOr<PyArgSignature> {
          return PyArgSignature(PrimitiveType::PRED, {}, true);
        };
        ToPyArgSignatureHandler int_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // TODO(phawkins): we should consider checking for integer overflow.
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::S64, {}, true);
          } else {
            return PyArgSignature(PrimitiveType::S32, {}, true);
          }
        };
        ToPyArgSignatureHandler float_handler =
            [&dtypes](py::handle h,
                      bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // Only Python native types have a True weak_type.
          bool weak_type = !py::isinstance(h, dtypes.np_float64);
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::F64, {}, weak_type);
          } else {
            return PyArgSignature(PrimitiveType::F32, {}, weak_type);
          }
        };
        ToPyArgSignatureHandler complex_handler =
            [&dtypes](py::handle h,
                      bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // Note that this branch is also taken for np.complex128:
          // isinstance(np.complex128(3), complex) returns True
          // isinstance(np.complex64(3), complex) returns False
          bool weak_type = !py::isinstance(h, dtypes.np_complex128);
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::C128, {}, weak_type);
          } else {
            return PyArgSignature(PrimitiveType::C64, {}, weak_type);
          }
        };

        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = bool_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = int_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;

        // The Buffer types, except for the fast-path PyBuffer.
        ToPyArgSignatureHandler device_array_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          py::handle aval = h.attr("aval");
          TF_ASSIGN_OR_RETURN(auto dtype,
                              DtypeToPrimitiveType(aval.attr("dtype")));
          return PyArgSignature(
              dtype, py::cast<std::vector<int64_t>>(aval.attr("shape")),
              py::cast<py::bool_>(aval.attr("weak_type")));
        };
        (*p)[PyBuffer::base_type()] = device_array_handler;

        try {
          py::object xla_module = py::module::import("jax.interpreters.xla");
          py::object device_array =
              py::getattr(xla_module, "_DeviceArray", py::none());
          if (!device_array.is_none()) {
            (*p)[device_array.ptr()] = device_array_handler;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        try {
          py::object pxla_module = py::module::import("jax.interpreters.pxla");
          py::object sda =
              py::getattr(pxla_module, "ShardedDeviceArray", py::none());
          if (!sda.is_none()) {
            (*p)[sda.ptr()] = device_array_handler;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        ToPyArgSignatureHandler numpy_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          py::array numpy_array = py::cast<py::array>(h);
          TF_ASSIGN_OR_RETURN(PrimitiveType dtype,
                              DtypeToPrimitiveType(numpy_array.dtype()));
          if (!jax_enable_x64) {
            dtype = Squash64BitTypes(dtype);
          }
          // We use reinterpret_cast<> to defend against environments where
          // ssize_t may not be precisely the same type as int64_t, even if it
          // is the same size (long vs long long).
          static_assert(sizeof(int64_t) == sizeof(ssize_t),
                        "Code assumes ssize_t is the same as int64_t");
          return PyArgSignature(
              dtype,
              absl::MakeConstSpan(
                  reinterpret_cast<const int64_t*>(numpy_array.shape()),
                  numpy_array.ndim()),
              /*weak_type=*/false);
        };
        const auto numpy = py::module::import("numpy");
        (*p)[numpy.attr("ndarray").ptr()] = numpy_handler;

        ToPyArgSignatureHandler np_uint64_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
          } else {
            return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
          }
        };
        ToPyArgSignatureHandler np_int_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
          } else {
            return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
          }
        };
        ToPyArgSignatureHandler numpy_scalar_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // Handles all remaining numpy scalar types; np_int64, np_float64
          // and np_complex128 are covered by the dedicated handlers above.
          TF_ASSIGN_OR_RETURN(auto dtype,
                              DtypeToPrimitiveType(h.attr("dtype")));
          return PyArgSignature(dtype, {}, /*weak_type=*/false);
        };

        // Registrations for the numpy scalar types. np_int64, np_float64 and
        // np_complex128 reuse the x64-aware handlers defined above.
        (*p)[dtypes.np_bool.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_int8.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_int16.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_int32.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_int64.ptr()] = np_int_handler;
        (*p)[dtypes.np_uint8.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_uint16.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_uint32.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
        (*p)[dtypes.np_float16.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_bfloat16.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_float32.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_float64.ptr()] = float_handler;
        (*p)[dtypes.np_complex64.ptr()] = numpy_scalar_handler;
        (*p)[dtypes.np_complex128.ptr()] = complex_handler;
        (*p)[dtypes.np_longlong.ptr()] = np_int_handler;
        (*p)[dtypes.np_intc.ptr()] = numpy_scalar_handler;

        return p;
      }();

  // Fast-path for the most common case of PyBuffer.
  if (arg.get_type().ptr() == PyBuffer::type()) {
    TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(arg));
    bool weak_type = buffer->weak_type().has_value()
                         ? *buffer->weak_type()
                         : py::cast<bool>(arg.attr("aval").attr("weak_type"));
    return PyArgSignature(buffer->buffer()->on_device_shape().element_type(),
                          buffer->buffer()->on_device_shape().dimensions(),
                          weak_type);
  }

  // Fast-path for ShardedDeviceArray.
  if (jax::ShardedDeviceArray::IsShardedDeviceArray(arg)) {
    jax::ShardedDeviceArray* sda =
        jax::ShardedDeviceArray::AsShardedDeviceArrayUnchecked(arg);

    // TODO(jblespiau): See if we can be faster by not accessing the aval
    // attribute and storing these values directly.
    py::handle aval = arg.attr("aval");
    TF_ASSIGN_OR_RETURN(auto dtype, DtypeToPrimitiveType(aval.attr("dtype")));
    return PyArgSignature(dtype,
                          py::cast<std::vector<int64_t>>(aval.attr("shape")),
                          sda->weak_type());
  }

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    // Fall back to searching the type's MRO for a registered base class.
    for (auto base_class : arg.get_type().attr("__mro__")) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, jax_enable_x64);
      }
    }
    return InvalidArgument(
        "%s",
        absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts "
                     "Buffer/DeviceArray/ShardedDeviceArray, Numpy arrays, "
                     "scalars of supported types "
                     "(see implementation), or Python scalars. Got type ",
                     py::cast<std::string>(py::str(arg.get_type()))));
  }
  return res->second(arg, jax_enable_x64);
}
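
// Illustrative usage sketch (not part of the original file):
//
//   TF_ASSIGN_OR_RETURN(PyArgSignature sig,
//                       PyArgSignatureOfValue(arg, /*jax_enable_x64=*/false));
//
// With x64 disabled, a Python float yields weak_F32[], np.float64(1.0) yields
// F32[] (the 64-bit dtype is narrowed), and np.ones((2, 3), np.float32)
// yields F32[2,3].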

}  // namespace xla