1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/lib/core/bfloat16.h"
17
18 #include <array>
19 #include <cmath>
20 #include <limits>
21 #include <locale>
22
23 #include "tensorflow/python/lib/core/float8_e4m3b11.h"
24 // Place `<locale>` before <Python.h> to avoid a build failure in macOS.
25 #include <Python.h>
26
27 #include "absl/strings/str_cat.h"
28 #include "third_party/eigen3/Eigen/Core"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/python/lib/core/numpy.h"
31
32 namespace tensorflow {
33 namespace {
34
35 struct PyDecrefDeleter {
operator ()tensorflow::__anon7a1ed7ad0111::PyDecrefDeleter36 void operator()(PyObject* p) const { Py_DECREF(p); }
37 };
38
39 // Safe container for an owned PyObject. On destruction, the reference count of
40 // the contained object will be decremented.
41 using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
make_safe(PyObject * object)42 Safe_PyObjectPtr make_safe(PyObject* object) {
43 return Safe_PyObjectPtr(object);
44 }
45
PyLong_CheckNoOverflow(PyObject * object)46 bool PyLong_CheckNoOverflow(PyObject* object) {
47 if (!PyLong_Check(object)) {
48 return false;
49 }
50 int overflow = 0;
51 PyLong_AsLongAndOverflow(object, &overflow);
52 return (overflow == 0);
53 }
54
// Traits template describing how a C++ type T is represented in NumPy. The
// primary template has no members; the commented-out lines below sketch the
// members that the specializations in this file provide.
template <typename T, typename Enable = void>
struct TypeDescriptor {
  // typedef ... T;  // Representation type in memory for NumPy values of type
  //                 // T.
  // static int Dtype() { return NPY_...; }  // Numpy type number for T.
};
60
// Holds the static, per-type state used to expose a custom float type T to
// Python and NumPy: the NumPy type number, the Python type object, and the
// method tables / descriptor that refer to it.
template <typename T>
struct CustomFloatTypeDescriptor {
  // Returns the NumPy type number currently stored in npy_type.
  static int Dtype() { return npy_type; }

  // Registered numpy type ID. Global variable populated by the registration
  // code. Protected by the GIL.
  static int npy_type;

  // Python type object defined in this file for T.
  static PyTypeObject type;
  // Pointer to the python type object we are using. This is either a pointer
  // to type, if we choose to register it, or to the python type
  // registered by another system into NumPy.
  static PyTypeObject* type_ptr;

  // Number protocol slots (add, subtract, ...) referenced by `type`.
  static PyNumberMethods number_methods;

  // NumPy array function table referenced by npy_descr.
  static PyArray_ArrFuncs arr_funcs;

  // NumPy descriptor for T.
  static PyArray_Descr npy_descr;
};
// npy_type holds NPY_NOTYPE until another value is assigned to it.
template <typename T>
int CustomFloatTypeDescriptor<T>::npy_type = NPY_NOTYPE;
// type_ptr is null until a Python type object is assigned to it.
template <typename T>
PyTypeObject* CustomFloatTypeDescriptor<T>::type_ptr = nullptr;
85
// Representation of a Python custom float object: the standard Python object
// header followed by the wrapped value of type T.
template <typename T>
struct PyCustomFloat {
  PyObject_HEAD;  // Python object header
  T value;        // The wrapped custom float value.
};
92
93 // Returns true if 'object' is a PyCustomFloat.
94 template <typename T>
PyCustomFloat_Check(PyObject * object)95 bool PyCustomFloat_Check(PyObject* object) {
96 return PyObject_IsInstance(
97 object, reinterpret_cast<PyObject*>(&TypeDescriptor<T>::type));
98 }
99
100 // Extracts the value of a PyCustomFloat object.
101 template <typename T>
PyCustomFloat_CustomFloat(PyObject * object)102 T PyCustomFloat_CustomFloat(PyObject* object) {
103 return reinterpret_cast<PyCustomFloat<T>*>(object)->value;
104 }
105
106 // Constructs a PyCustomFloat object from PyCustomFloat<T>::T.
107 template <typename T>
PyCustomFloat_FromT(T x)108 Safe_PyObjectPtr PyCustomFloat_FromT(T x) {
109 Safe_PyObjectPtr ref =
110 make_safe(TypeDescriptor<T>::type.tp_alloc(&TypeDescriptor<T>::type, 0));
111 PyCustomFloat<T>* p = reinterpret_cast<PyCustomFloat<T>*>(ref.get());
112 if (p) {
113 p->value = x;
114 }
115 return ref;
116 }
117
118 // Converts a Python object to a reduced float value. Returns true on success,
119 // returns false and reports a Python error on failure.
120 template <typename T>
CastToCustomFloat(PyObject * arg,T * output)121 bool CastToCustomFloat(PyObject* arg, T* output) {
122 if (PyCustomFloat_Check<T>(arg)) {
123 *output = PyCustomFloat_CustomFloat<T>(arg);
124 return true;
125 }
126 if (PyFloat_Check(arg)) {
127 double d = PyFloat_AsDouble(arg);
128 if (PyErr_Occurred()) {
129 return false;
130 }
131 // TODO(phawkins): check for overflow
132 *output = T(d);
133 return true;
134 }
135 if (PyLong_CheckNoOverflow(arg)) {
136 long l = PyLong_AsLong(arg); // NOLINT
137 if (PyErr_Occurred()) {
138 return false;
139 }
140 // TODO(phawkins): check for overflow
141 *output = T(static_cast<float>(l));
142 return true;
143 }
144 if (PyArray_IsScalar(arg, Half)) {
145 Eigen::half f;
146 PyArray_ScalarAsCtype(arg, &f);
147 *output = T(f);
148 return true;
149 }
150 if (PyArray_IsScalar(arg, Float)) {
151 float f;
152 PyArray_ScalarAsCtype(arg, &f);
153 *output = T(f);
154 return true;
155 }
156 if (PyArray_IsScalar(arg, Double)) {
157 double f;
158 PyArray_ScalarAsCtype(arg, &f);
159 *output = T(f);
160 return true;
161 }
162 if (PyArray_IsScalar(arg, LongDouble)) {
163 long double f;
164 PyArray_ScalarAsCtype(arg, &f);
165 *output = T(f);
166 return true;
167 }
168 if (PyArray_IsZeroDim(arg)) {
169 Safe_PyObjectPtr ref;
170 PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
171 if (PyArray_TYPE(arr) != TypeDescriptor<T>::Dtype()) {
172 ref = make_safe(PyArray_Cast(arr, TypeDescriptor<T>::Dtype()));
173 if (PyErr_Occurred()) {
174 return false;
175 }
176 arg = ref.get();
177 arr = reinterpret_cast<PyArrayObject*>(arg);
178 }
179 *output = *reinterpret_cast<T*>(PyArray_DATA(arr));
180 return true;
181 }
182 return false;
183 }
184
185 template <typename T>
SafeCastToCustomFloat(PyObject * arg,T * output)186 bool SafeCastToCustomFloat(PyObject* arg, T* output) {
187 if (PyCustomFloat_Check<T>(arg)) {
188 *output = PyCustomFloat_CustomFloat<T>(arg);
189 return true;
190 }
191 return false;
192 }
193
194 // Converts a PyReduceFloat into a PyFloat.
195 template <typename T>
PyCustomFloat_Float(PyObject * self)196 PyObject* PyCustomFloat_Float(PyObject* self) {
197 T x = PyCustomFloat_CustomFloat<T>(self);
198 return PyFloat_FromDouble(static_cast<double>(static_cast<float>(x)));
199 }
200
201 // Converts a PyReduceFloat into a PyInt.
202 template <typename T>
PyCustomFloat_Int(PyObject * self)203 PyObject* PyCustomFloat_Int(PyObject* self) {
204 T x = PyCustomFloat_CustomFloat<T>(self);
205 long y = static_cast<long>(static_cast<float>(x)); // NOLINT
206 return PyLong_FromLong(y);
207 }
208
209 // Negates a PyCustomFloat.
210 template <typename T>
PyCustomFloat_Negative(PyObject * self)211 PyObject* PyCustomFloat_Negative(PyObject* self) {
212 T x = PyCustomFloat_CustomFloat<T>(self);
213 return PyCustomFloat_FromT<T>(-x).release();
214 }
215
216 template <typename T>
PyCustomFloat_Add(PyObject * a,PyObject * b)217 PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) {
218 T x, y;
219 if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
220 return PyCustomFloat_FromT<T>(x + y).release();
221 }
222 return PyArray_Type.tp_as_number->nb_add(a, b);
223 }
224
225 template <typename T>
PyCustomFloat_Subtract(PyObject * a,PyObject * b)226 PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) {
227 T x, y;
228 if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
229 return PyCustomFloat_FromT<T>(x - y).release();
230 }
231 return PyArray_Type.tp_as_number->nb_subtract(a, b);
232 }
233
234 template <typename T>
PyCustomFloat_Multiply(PyObject * a,PyObject * b)235 PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) {
236 T x, y;
237 if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
238 return PyCustomFloat_FromT<T>(x * y).release();
239 }
240 return PyArray_Type.tp_as_number->nb_multiply(a, b);
241 }
242
243 template <typename T>
PyCustomFloat_TrueDivide(PyObject * a,PyObject * b)244 PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) {
245 T x, y;
246 if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
247 return PyCustomFloat_FromT<T>(x / y).release();
248 }
249 return PyArray_Type.tp_as_number->nb_true_divide(a, b);
250 }
251
252 // Python number methods for PyCustomFloat objects.
253 template <typename T>
254 PyNumberMethods CustomFloatTypeDescriptor<T>::number_methods = {
255 PyCustomFloat_Add<T>, // nb_add
256 PyCustomFloat_Subtract<T>, // nb_subtract
257 PyCustomFloat_Multiply<T>, // nb_multiply
258 nullptr, // nb_remainder
259 nullptr, // nb_divmod
260 nullptr, // nb_power
261 PyCustomFloat_Negative<T>, // nb_negative
262 nullptr, // nb_positive
263 nullptr, // nb_absolute
264 nullptr, // nb_nonzero
265 nullptr, // nb_invert
266 nullptr, // nb_lshift
267 nullptr, // nb_rshift
268 nullptr, // nb_and
269 nullptr, // nb_xor
270 nullptr, // nb_or
271 PyCustomFloat_Int<T>, // nb_int
272 nullptr, // reserved
273 PyCustomFloat_Float<T>, // nb_float
274
275 nullptr, // nb_inplace_add
276 nullptr, // nb_inplace_subtract
277 nullptr, // nb_inplace_multiply
278 nullptr, // nb_inplace_remainder
279 nullptr, // nb_inplace_power
280 nullptr, // nb_inplace_lshift
281 nullptr, // nb_inplace_rshift
282 nullptr, // nb_inplace_and
283 nullptr, // nb_inplace_xor
284 nullptr, // nb_inplace_or
285
286 nullptr, // nb_floor_divide
287 PyCustomFloat_TrueDivide<T>, // nb_true_divide
288 nullptr, // nb_inplace_floor_divide
289 nullptr, // nb_inplace_true_divide
290 nullptr, // nb_index
291 };
292
293 // Constructs a new PyCustomFloat.
294 template <typename T>
PyCustomFloat_New(PyTypeObject * type,PyObject * args,PyObject * kwds)295 PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args,
296 PyObject* kwds) {
297 if (kwds && PyDict_Size(kwds)) {
298 PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
299 return nullptr;
300 }
301 Py_ssize_t size = PyTuple_Size(args);
302 if (size != 1) {
303 PyErr_Format(PyExc_TypeError,
304 "expected number as argument to %s constructor",
305 TypeDescriptor<T>::kTypeName);
306 return nullptr;
307 }
308 PyObject* arg = PyTuple_GetItem(args, 0);
309
310 T value;
311 if (PyCustomFloat_Check<T>(arg)) {
312 Py_INCREF(arg);
313 return arg;
314 } else if (CastToCustomFloat<T>(arg, &value)) {
315 return PyCustomFloat_FromT<T>(value).release();
316 } else if (PyArray_Check(arg)) {
317 PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
318 if (PyArray_TYPE(arr) != TypeDescriptor<T>::Dtype()) {
319 return PyArray_Cast(arr, TypeDescriptor<T>::Dtype());
320 } else {
321 Py_INCREF(arg);
322 return arg;
323 }
324 }
325 PyErr_Format(PyExc_TypeError, "expected number, got %s",
326 Py_TYPE(arg)->tp_name);
327 return nullptr;
328 }
329
330 // Comparisons on PyCustomFloats.
331 template <typename T>
PyCustomFloat_RichCompare(PyObject * a,PyObject * b,int op)332 PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) {
333 T x, y;
334 if (!SafeCastToCustomFloat<T>(a, &x) || !SafeCastToCustomFloat<T>(b, &y)) {
335 return PyGenericArrType_Type.tp_richcompare(a, b, op);
336 }
337 bool result;
338 switch (op) {
339 case Py_LT:
340 result = x < y;
341 break;
342 case Py_LE:
343 result = x <= y;
344 break;
345 case Py_EQ:
346 result = x == y;
347 break;
348 case Py_NE:
349 result = x != y;
350 break;
351 case Py_GT:
352 result = x > y;
353 break;
354 case Py_GE:
355 result = x >= y;
356 break;
357 default:
358 LOG(FATAL) << "Invalid op type " << op;
359 }
360 return PyBool_FromLong(result);
361 }
362
363 // Implementation of repr() for PyCustomFloat.
364 template <typename T>
PyCustomFloat_Repr(PyObject * self)365 PyObject* PyCustomFloat_Repr(PyObject* self) {
366 T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
367 std::string v = absl::StrCat(static_cast<float>(x));
368 return PyUnicode_FromString(v.c_str());
369 }
370
371 // Implementation of str() for PyCustomFloat.
372 template <typename T>
PyCustomFloat_Str(PyObject * self)373 PyObject* PyCustomFloat_Str(PyObject* self) {
374 T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
375 std::string v = absl::StrCat(static_cast<float>(x));
376 return PyUnicode_FromString(v.c_str());
377 }
378
379 // _Py_HashDouble changed its prototype for Python 3.10 so we use an overload to
380 // handle the two possibilities.
381 // NOLINTNEXTLINE(clang-diagnostic-unused-function)
HashImpl(Py_hash_t (* hash_double)(PyObject *,double),PyObject * self,double value)382 Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double), PyObject* self,
383 double value) {
384 return hash_double(self, value);
385 }
386
387 // NOLINTNEXTLINE(clang-diagnostic-unused-function)
HashImpl(Py_hash_t (* hash_double)(double),PyObject * self,double value)388 Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self,
389 double value) {
390 return hash_double(value);
391 }
392
393 // Hash function for PyCustomFloat.
394 template <typename T>
PyCustomFloat_Hash(PyObject * self)395 Py_hash_t PyCustomFloat_Hash(PyObject* self) {
396 T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
397 return HashImpl(&_Py_HashDouble, self, static_cast<double>(x));
398 }
399
400 // Python type for PyCustomFloat objects.
401 template <typename T>
402 PyTypeObject CustomFloatTypeDescriptor<T>::type = {
403 PyVarObject_HEAD_INIT(nullptr, 0) TypeDescriptor<T>::kTypeName, // tp_name
404 sizeof(PyCustomFloat<T>), // tp_basicsize
405 0, // tp_itemsize
406 nullptr, // tp_dealloc
407 #if PY_VERSION_HEX < 0x03080000
408 nullptr, // tp_print
409 #else
410 0, // tp_vectorcall_offset
411 #endif
412 nullptr, // tp_getattr
413 nullptr, // tp_setattr
414 nullptr, // tp_compare / tp_reserved
415 PyCustomFloat_Repr<T>, // tp_repr
416 &CustomFloatTypeDescriptor<T>::number_methods, // tp_as_number
417 nullptr, // tp_as_sequence
418 nullptr, // tp_as_mapping
419 PyCustomFloat_Hash<T>, // tp_hash
420 nullptr, // tp_call
421 PyCustomFloat_Str<T>, // tp_str
422 nullptr, // tp_getattro
423 nullptr, // tp_setattro
424 nullptr, // tp_as_buffer
425 // tp_flags
426 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
427 TypeDescriptor<T>::kTpDoc, // tp_doc
428 nullptr, // tp_traverse
429 nullptr, // tp_clear
430 PyCustomFloat_RichCompare<T>, // tp_richcompare
431 0, // tp_weaklistoffset
432 nullptr, // tp_iter
433 nullptr, // tp_iternext
434 nullptr, // tp_methods
435 nullptr, // tp_members
436 nullptr, // tp_getset
437 nullptr, // tp_base
438 nullptr, // tp_dict
439 nullptr, // tp_descr_get
440 nullptr, // tp_descr_set
441 0, // tp_dictoffset
442 nullptr, // tp_init
443 nullptr, // tp_alloc
444 PyCustomFloat_New<T>, // tp_new
445 nullptr, // tp_free
446 nullptr, // tp_is_gc
447 nullptr, // tp_bases
448 nullptr, // tp_mro
449 nullptr, // tp_cache
450 nullptr, // tp_subclasses
451 nullptr, // tp_weaklist
452 nullptr, // tp_del
453 0, // tp_version_tag
454 };
455
456 // Numpy support
// Definition of the NumPy array function table for T. It has no explicit
// initializer here; npy_descr below points at it through its `f` field.
template <typename T>
PyArray_ArrFuncs CustomFloatTypeDescriptor<T>::arr_funcs;
459
// NumPy descriptor for T. It ties together the Python type object, the
// kind/type/byteorder characters supplied by TypeDescriptor<T>, the element
// size and alignment of T, and the arr_funcs table.
template <typename T>
PyArray_Descr CustomFloatTypeDescriptor<T>::npy_descr = {
    PyObject_HEAD_INIT(nullptr)  //
                                 /*typeobj=*/
    (&TypeDescriptor<T>::type),
    /*kind=*/TypeDescriptor<T>::kNpyDescrKind,
    /*type=*/TypeDescriptor<T>::kNpyDescrType,
    /*byteorder=*/TypeDescriptor<T>::kNpyDescrByteorder,
    /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM,
    /*type_num=*/0,
    /*elsize=*/sizeof(T),
    /*alignment=*/alignof(T),
    /*subarray=*/nullptr,
    /*fields=*/nullptr,
    /*names=*/nullptr,
    /*f=*/&CustomFloatTypeDescriptor<T>::arr_funcs,
    /*metadata=*/nullptr,
    /*c_metadata=*/nullptr,
    /*hash=*/-1,  // -1 means "not computed yet".
};
480
481 // Implementations of NumPy array methods.
482
483 template <typename T>
NPyCustomFloat_GetItem(void * data,void * arr)484 PyObject* NPyCustomFloat_GetItem(void* data, void* arr) {
485 T x;
486 memcpy(&x, data, sizeof(T));
487 return PyFloat_FromDouble(static_cast<float>(x));
488 }
489
490 template <typename T>
NPyCustomFloat_SetItem(PyObject * item,void * data,void * arr)491 int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) {
492 T x;
493 if (!CastToCustomFloat<T>(item, &x)) {
494 PyErr_Format(PyExc_TypeError, "expected number, got %s",
495 Py_TYPE(item)->tp_name);
496 return -1;
497 }
498 memcpy(data, &x, sizeof(T));
499 return 0;
500 }
501
// Exchanges the two bytes of the 16-bit quantity that `value` points to.
void ByteSwap16(void* value) {
  char* const bytes = static_cast<char*>(value);
  const char first = bytes[0];
  bytes[0] = bytes[1];
  bytes[1] = first;
}
506
507 template <typename T>
NPyCustomFloat_Compare(const void * a,const void * b,void * arr)508 int NPyCustomFloat_Compare(const void* a, const void* b, void* arr) {
509 T x;
510 memcpy(&x, a, sizeof(T));
511
512 T y;
513 memcpy(&y, b, sizeof(T));
514 float fy(y);
515 float fx(x);
516
517 if (fx < fy) {
518 return -1;
519 }
520 if (fy < fx) {
521 return 1;
522 }
523 // NaNs sort to the end.
524 if (!Eigen::numext::isnan(fx) && Eigen::numext::isnan(fy)) {
525 return -1;
526 }
527 if (Eigen::numext::isnan(fx) && !Eigen::numext::isnan(fy)) {
528 return 1;
529 }
530 return 0;
531 }
532
533 template <typename T>
NPyCustomFloat_CopySwapN(void * dstv,npy_intp dstride,void * srcv,npy_intp sstride,npy_intp n,int swap,void * arr)534 void NPyCustomFloat_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
535 npy_intp sstride, npy_intp n, int swap,
536 void* arr) {
537 static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t),
538 "Not supported");
539 char* dst = reinterpret_cast<char*>(dstv);
540 char* src = reinterpret_cast<char*>(srcv);
541 if (!src) {
542 return;
543 }
544 if (swap && sizeof(T) == sizeof(int16_t)) {
545 for (npy_intp i = 0; i < n; i++) {
546 char* r = dst + dstride * i;
547 memcpy(r, src + sstride * i, sizeof(T));
548 ByteSwap16(r);
549 }
550 } else if (dstride == sizeof(T) && sstride == sizeof(T)) {
551 memcpy(dst, src, n * sizeof(T));
552 } else {
553 for (npy_intp i = 0; i < n; i++) {
554 memcpy(dst + dstride * i, src + sstride * i, sizeof(T));
555 }
556 }
557 }
558
559 template <typename T>
NPyCustomFloat_CopySwap(void * dst,void * src,int swap,void * arr)560 void NPyCustomFloat_CopySwap(void* dst, void* src, int swap, void* arr) {
561 if (!src) {
562 return;
563 }
564 memcpy(dst, src, sizeof(T));
565 static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t),
566 "Not supported");
567 if (swap && sizeof(T) == sizeof(int16_t)) {
568 ByteSwap16(dst);
569 }
570 }
571
572 template <typename T>
NPyCustomFloat_NonZero(void * data,void * arr)573 npy_bool NPyCustomFloat_NonZero(void* data, void* arr) {
574 T x;
575 memcpy(&x, data, sizeof(x));
576 return x != static_cast<T>(0);
577 }
578
579 template <typename T>
NPyCustomFloat_Fill(void * buffer_raw,npy_intp length,void * ignored)580 int NPyCustomFloat_Fill(void* buffer_raw, npy_intp length, void* ignored) {
581 T* const buffer = reinterpret_cast<T*>(buffer_raw);
582 const float start(buffer[0]);
583 const float delta = static_cast<float>(buffer[1]) - start;
584 for (npy_intp i = 2; i < length; ++i) {
585 buffer[i] = static_cast<T>(start + i * delta);
586 }
587 return 0;
588 }
589
590 template <typename T>
NPyCustomFloat_DotFunc(void * ip1,npy_intp is1,void * ip2,npy_intp is2,void * op,npy_intp n,void * arr)591 void NPyCustomFloat_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
592 void* op, npy_intp n, void* arr) {
593 char* c1 = reinterpret_cast<char*>(ip1);
594 char* c2 = reinterpret_cast<char*>(ip2);
595 float acc = 0.0f;
596 for (npy_intp i = 0; i < n; ++i) {
597 T* const b1 = reinterpret_cast<T*>(c1);
598 T* const b2 = reinterpret_cast<T*>(c2);
599 acc += static_cast<float>(*b1) * static_cast<float>(*b2);
600 c1 += is1;
601 c2 += is2;
602 }
603 T* out = reinterpret_cast<T*>(op);
604 *out = static_cast<T>(acc);
605 }
606
// Three-way comparison of the T values at `v1` and `v2` using T's own
// operators: -1 if the first is smaller, 1 if it is larger, 0 otherwise.
// `arr` is unused.
template <typename T>
int NPyCustomFloat_CompareFunc(const void* v1, const void* v2, void* arr) {
  T lhs = *reinterpret_cast<const T*>(v1);
  T rhs = *reinterpret_cast<const T*>(v2);
  return (lhs < rhs) ? -1 : ((lhs > rhs) ? 1 : 0);
}
619
620 template <typename T>
NPyCustomFloat_ArgMaxFunc(void * data,npy_intp n,npy_intp * max_ind,void * arr)621 int NPyCustomFloat_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
622 void* arr) {
623 const T* bdata = reinterpret_cast<const T*>(data);
624 // Start with a max_val of NaN, this results in the first iteration preferring
625 // bdata[0].
626 float max_val = std::numeric_limits<float>::quiet_NaN();
627 for (npy_intp i = 0; i < n; ++i) {
628 // This condition is chosen so that NaNs are always considered "max".
629 if (!(static_cast<float>(bdata[i]) <= max_val)) {
630 max_val = static_cast<float>(bdata[i]);
631 *max_ind = i;
632 // NumPy stops at the first NaN.
633 if (Eigen::numext::isnan(max_val)) {
634 break;
635 }
636 }
637 }
638 return 0;
639 }
640
641 template <typename T>
NPyCustomFloat_ArgMinFunc(void * data,npy_intp n,npy_intp * min_ind,void * arr)642 int NPyCustomFloat_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
643 void* arr) {
644 const T* bdata = reinterpret_cast<const T*>(data);
645 float min_val = std::numeric_limits<float>::quiet_NaN();
646 // Start with a min_val of NaN, this results in the first iteration preferring
647 // bdata[0].
648 for (npy_intp i = 0; i < n; ++i) {
649 // This condition is chosen so that NaNs are always considered "min".
650 if (!(static_cast<float>(bdata[i]) >= min_val)) {
651 min_val = static_cast<float>(bdata[i]);
652 *min_ind = i;
653 // NumPy stops at the first NaN.
654 if (Eigen::numext::isnan(min_val)) {
655 break;
656 }
657 }
658 }
659 return 0;
660 }
661
662 template <>
663 struct TypeDescriptor<unsigned char> {
664 typedef unsigned char T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor665 static int Dtype() { return NPY_UBYTE; }
666 };
667
668 template <>
669 struct TypeDescriptor<unsigned short> { // NOLINT
670 typedef unsigned short T; // NOLINT
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor671 static int Dtype() { return NPY_USHORT; }
672 };
673
// We register "int", "long", and "long long" types for portability across
// Linux, where "long" and "long long" are the same size, and Windows, where
// "int" and "long" are the same size.
677 template <>
678 struct TypeDescriptor<unsigned int> {
679 typedef unsigned int T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor680 static int Dtype() { return NPY_UINT; }
681 };
682
683 template <>
684 struct TypeDescriptor<unsigned long> { // NOLINT
685 typedef unsigned long T; // NOLINT
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor686 static int Dtype() { return NPY_ULONG; }
687 };
688
689 template <>
690 struct TypeDescriptor<unsigned long long> { // NOLINT
691 typedef unsigned long long T; // NOLINT
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor692 static int Dtype() { return NPY_ULONGLONG; }
693 };
694
695 template <>
696 struct TypeDescriptor<signed char> {
697 typedef signed char T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor698 static int Dtype() { return NPY_BYTE; }
699 };
700
701 template <>
702 struct TypeDescriptor<short> { // NOLINT
703 typedef short T; // NOLINT
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor704 static int Dtype() { return NPY_SHORT; }
705 };
706
707 template <>
708 struct TypeDescriptor<int> {
709 typedef int T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor710 static int Dtype() { return NPY_INT; }
711 };
712
713 template <>
714 struct TypeDescriptor<long> { // NOLINT
715 typedef long T; // NOLINT
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor716 static int Dtype() { return NPY_LONG; }
717 };
718
719 template <>
720 struct TypeDescriptor<long long> { // NOLINT
721 typedef long long T; // NOLINT
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor722 static int Dtype() { return NPY_LONGLONG; }
723 };
724
725 template <>
726 struct TypeDescriptor<bool> {
727 typedef unsigned char T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor728 static int Dtype() { return NPY_BOOL; }
729 };
730
731 template <>
732 struct TypeDescriptor<Eigen::half> {
733 typedef Eigen::half T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor734 static int Dtype() { return NPY_HALF; }
735 };
736
737 template <>
738 struct TypeDescriptor<float> {
739 typedef float T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor740 static int Dtype() { return NPY_FLOAT; }
741 };
742
743 template <>
744 struct TypeDescriptor<double> {
745 typedef double T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor746 static int Dtype() { return NPY_DOUBLE; }
747 };
748
749 template <>
750 struct TypeDescriptor<long double> {
751 typedef long double T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor752 static int Dtype() { return NPY_LONGDOUBLE; }
753 };
754
755 template <>
756 struct TypeDescriptor<std::complex<float>> {
757 typedef std::complex<float> T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor758 static int Dtype() { return NPY_CFLOAT; }
759 };
760
761 template <>
762 struct TypeDescriptor<std::complex<double>> {
763 typedef std::complex<double> T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor764 static int Dtype() { return NPY_CDOUBLE; }
765 };
766
767 template <>
768 struct TypeDescriptor<std::complex<long double>> {
769 typedef std::complex<long double> T;
Dtypetensorflow::__anon7a1ed7ad0111::TypeDescriptor770 static int Dtype() { return NPY_CLONGDOUBLE; }
771 };
772
// Converts `value` to float with a static_cast.
template <typename T>
float CastToFloat(T value) {
  const float converted = static_cast<float>(value);
  return converted;
}

// Complex overload: converts only the real part; the imaginary part is
// dropped.
template <typename T>
float CastToFloat(std::complex<T> value) {
  const T real_part = value.real();
  return CastToFloat(real_part);
}
782
783 // Performs a NumPy array cast from type 'From' to 'To'.
784 template <typename From, typename To>
NPyCast(void * from_void,void * to_void,npy_intp n,void * fromarr,void * toarr)785 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
786 void* toarr) {
787 const auto* from =
788 reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
789 auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
790 for (npy_intp i = 0; i < n; ++i) {
791 to[i] = static_cast<typename TypeDescriptor<To>::T>(
792 static_cast<To>(CastToFloat(from[i])));
793 }
794 }
795
796 // Registers a cast between T (a reduced float) and type 'OtherT'. 'numpy_type'
797 // is the NumPy type corresponding to 'OtherT'.
798 template <typename T, typename OtherT>
RegisterCustomFloatCast(int numpy_type=TypeDescriptor<OtherT>::Dtype ())799 bool RegisterCustomFloatCast(int numpy_type = TypeDescriptor<OtherT>::Dtype()) {
800 PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
801 if (PyArray_RegisterCastFunc(descr, TypeDescriptor<T>::Dtype(),
802 NPyCast<OtherT, T>) < 0) {
803 return false;
804 }
805 if (PyArray_RegisterCastFunc(&CustomFloatTypeDescriptor<T>::npy_descr,
806 numpy_type, NPyCast<T, OtherT>) < 0) {
807 return false;
808 }
809 return true;
810 }
811
812 template <typename T>
RegisterCasts()813 bool RegisterCasts() {
814 if (!RegisterCustomFloatCast<T, Eigen::half>(NPY_HALF)) {
815 return false;
816 }
817
818 if (!RegisterCustomFloatCast<T, float>(NPY_FLOAT)) {
819 return false;
820 }
821 if (!RegisterCustomFloatCast<T, double>(NPY_DOUBLE)) {
822 return false;
823 }
824 if (!RegisterCustomFloatCast<T, long double>(NPY_LONGDOUBLE)) {
825 return false;
826 }
827 if (!RegisterCustomFloatCast<T, bool>(NPY_BOOL)) {
828 return false;
829 }
830 if (!RegisterCustomFloatCast<T, unsigned char>(NPY_UBYTE)) {
831 return false;
832 }
833 if (!RegisterCustomFloatCast<T, unsigned short>(NPY_USHORT)) { // NOLINT
834 return false;
835 }
836 if (!RegisterCustomFloatCast<T, unsigned int>(NPY_UINT)) {
837 return false;
838 }
839 if (!RegisterCustomFloatCast<T, unsigned long>(NPY_ULONG)) { // NOLINT
840 return false;
841 }
842 if (!RegisterCustomFloatCast<T, unsigned long long>( // NOLINT
843 NPY_ULONGLONG)) {
844 return false;
845 }
846 if (!RegisterCustomFloatCast<T, signed char>(NPY_BYTE)) {
847 return false;
848 }
849 if (!RegisterCustomFloatCast<T, short>(NPY_SHORT)) { // NOLINT
850 return false;
851 }
852 if (!RegisterCustomFloatCast<T, int>(NPY_INT)) {
853 return false;
854 }
855 if (!RegisterCustomFloatCast<T, long>(NPY_LONG)) { // NOLINT
856 return false;
857 }
858 if (!RegisterCustomFloatCast<T, long long>(NPY_LONGLONG)) { // NOLINT
859 return false;
860 }
861 // Following the numpy convention. imag part is dropped when converting to
862 // float.
863 if (!RegisterCustomFloatCast<T, std::complex<float>>(NPY_CFLOAT)) {
864 return false;
865 }
866 if (!RegisterCustomFloatCast<T, std::complex<double>>(NPY_CDOUBLE)) {
867 return false;
868 }
869 if (!RegisterCustomFloatCast<T, std::complex<long double>>(NPY_CLONGDOUBLE)) {
870 return false;
871 }
872
873 // Safe casts from T to other types
874 if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
875 NPY_NOSCALAR) < 0) {
876 return false;
877 }
878 if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_DOUBLE,
879 NPY_NOSCALAR) < 0) {
880 return false;
881 }
882 if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_LONGDOUBLE,
883 NPY_NOSCALAR) < 0) {
884 return false;
885 }
886 if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CFLOAT,
887 NPY_NOSCALAR) < 0) {
888 return false;
889 }
890 if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CDOUBLE,
891 NPY_NOSCALAR) < 0) {
892 return false;
893 }
894 if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CLONGDOUBLE,
895 NPY_NOSCALAR) < 0) {
896 return false;
897 }
898
899 // Safe casts to T from other types
900 if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL),
901 TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
902 return false;
903 }
904 if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE),
905 TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
906 return false;
907 }
908 if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE),
909 TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
910 return false;
911 }
912
913 return true;
914 }
915
916 template <typename InType, typename OutType, typename Functor>
917 struct UnaryUFunc {
Typestensorflow::__anon7a1ed7ad0111::UnaryUFunc918 static std::vector<int> Types() {
919 return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype()};
920 }
Calltensorflow::__anon7a1ed7ad0111::UnaryUFunc921 static void Call(char** args, const npy_intp* dimensions,
922 const npy_intp* steps, void* data) {
923 const char* i0 = args[0];
924 char* o = args[1];
925 for (npy_intp k = 0; k < *dimensions; k++) {
926 auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
927 *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) = Functor()(x);
928 i0 += steps[0];
929 o += steps[1];
930 }
931 }
932 };
933
934 template <typename InType, typename OutType, typename OutType2,
935 typename Functor>
936 struct UnaryUFunc2 {
Typestensorflow::__anon7a1ed7ad0111::UnaryUFunc2937 static std::vector<int> Types() {
938 return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype(),
939 TypeDescriptor<OutType2>::Dtype()};
940 }
Calltensorflow::__anon7a1ed7ad0111::UnaryUFunc2941 static void Call(char** args, const npy_intp* dimensions,
942 const npy_intp* steps, void* data) {
943 const char* i0 = args[0];
944 char* o0 = args[1];
945 char* o1 = args[2];
946 for (npy_intp k = 0; k < *dimensions; k++) {
947 auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
948 std::tie(*reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o0),
949 *reinterpret_cast<typename TypeDescriptor<OutType2>::T*>(o1)) =
950 Functor()(x);
951 i0 += steps[0];
952 o0 += steps[1];
953 o1 += steps[2];
954 }
955 }
956 };
957
958 template <typename InType, typename OutType, typename Functor>
959 struct BinaryUFunc {
Typestensorflow::__anon7a1ed7ad0111::BinaryUFunc960 static std::vector<int> Types() {
961 return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType>::Dtype(),
962 TypeDescriptor<OutType>::Dtype()};
963 }
Calltensorflow::__anon7a1ed7ad0111::BinaryUFunc964 static void Call(char** args, const npy_intp* dimensions,
965 const npy_intp* steps, void* data) {
966 const char* i0 = args[0];
967 const char* i1 = args[1];
968 char* o = args[2];
969 for (npy_intp k = 0; k < *dimensions; k++) {
970 auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
971 auto y = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i1);
972 *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
973 Functor()(x, y);
974 i0 += steps[0];
975 i1 += steps[1];
976 o += steps[2];
977 }
978 }
979 };
980
981 template <typename InType, typename InType2, typename OutType, typename Functor>
982 struct BinaryUFunc2 {
Typestensorflow::__anon7a1ed7ad0111::BinaryUFunc2983 static std::vector<int> Types() {
984 return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType2>::Dtype(),
985 TypeDescriptor<OutType>::Dtype()};
986 }
Calltensorflow::__anon7a1ed7ad0111::BinaryUFunc2987 static void Call(char** args, const npy_intp* dimensions,
988 const npy_intp* steps, void* data) {
989 const char* i0 = args[0];
990 const char* i1 = args[1];
991 char* o = args[2];
992 for (npy_intp k = 0; k < *dimensions; k++) {
993 auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
994 auto y =
995 *reinterpret_cast<const typename TypeDescriptor<InType2>::T*>(i1);
996 *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
997 Functor()(x, y);
998 i0 += steps[0];
999 i1 += steps[1];
1000 o += steps[2];
1001 }
1002 }
1003 };
1004
1005 template <typename UFunc, typename CustomFloatT>
RegisterUFunc(PyObject * numpy,const char * name)1006 bool RegisterUFunc(PyObject* numpy, const char* name) {
1007 std::vector<int> types = UFunc::Types();
1008 PyUFuncGenericFunction fn =
1009 reinterpret_cast<PyUFuncGenericFunction>(UFunc::Call);
1010 Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name));
1011 if (!ufunc_obj) {
1012 return false;
1013 }
1014 PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
1015 if (static_cast<int>(types.size()) != ufunc->nargs) {
1016 PyErr_Format(PyExc_AssertionError,
1017 "ufunc %s takes %d arguments, loop takes %lu", name,
1018 ufunc->nargs, types.size());
1019 return false;
1020 }
1021 if (PyUFunc_RegisterLoopForType(ufunc, TypeDescriptor<CustomFloatT>::Dtype(),
1022 fn, const_cast<int*>(types.data()),
1023 nullptr) < 0) {
1024 return false;
1025 }
1026 return true;
1027 }
1028
1029 namespace ufuncs {
1030
// Basic arithmetic functors. Each one applies the corresponding built-in
// operator of `T` to its two operands and returns the result as `T`.

// lhs + rhs
template <typename T>
struct Add {
  T operator()(T lhs, T rhs) {
    T sum = lhs + rhs;
    return sum;
  }
};

// lhs - rhs
template <typename T>
struct Subtract {
  T operator()(T lhs, T rhs) {
    T difference = lhs - rhs;
    return difference;
  }
};

// lhs * rhs
template <typename T>
struct Multiply {
  T operator()(T lhs, T rhs) {
    T product = lhs * rhs;
    return product;
  }
};

// lhs / rhs
template <typename T>
struct TrueDivide {
  T operator()(T lhs, T rhs) {
    T quotient = lhs / rhs;
    return quotient;
  }
};
1047
// Floored division with remainder on floats.
//
// Returns {floor_quotient, remainder}. A non-zero remainder takes the sign of
// the divisor `b`; a zero remainder is a zero signed like `b`. If `b` is zero
// both results are NaN.
inline std::pair<float, float> divmod(float a, float b) {
  if (b == 0.0f) {
    const float not_a_number = std::numeric_limits<float>::quiet_NaN();
    return std::make_pair(not_a_number, not_a_number);
  }

  float remainder = std::fmod(a, b);
  float quotient = (a - remainder) / b;
  if (remainder == 0.0f) {
    // Exact division: keep a zero remainder, signed like the divisor.
    remainder = std::copysign(0.0f, b);
  } else if ((b < 0.0f) != (remainder < 0.0f)) {
    // fmod() follows the sign of the dividend; shift into the divisor's sign.
    remainder += b;
    quotient -= 1.0f;
  }

  if (quotient == 0.0f) {
    // A zero quotient carries the sign of the true quotient a / b.
    return std::make_pair(std::copysign(0.0f, a / b), remainder);
  }

  // Snap the quotient to the nearest integer to absorb rounding error.
  float floored = std::floor(quotient);
  if (quotient - floored > 0.5f) {
    floored += 1.0f;
  }
  return std::make_pair(floored, remainder);
}
1075
1076 template <typename T>
1077 struct FloorDivide {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::FloorDivide1078 T operator()(T a, T b) {
1079 return T(divmod(static_cast<float>(a), static_cast<float>(b)).first);
1080 }
1081 };
1082 template <typename T>
1083 struct Remainder {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Remainder1084 T operator()(T a, T b) {
1085 return T(divmod(static_cast<float>(a), static_cast<float>(b)).second);
1086 }
1087 };
1088 template <typename T>
1089 struct DivmodUFunc {
Typestensorflow::__anon7a1ed7ad0111::ufuncs::DivmodUFunc1090 static std::vector<int> Types() {
1091 return {TypeDescriptor<T>::Dtype(), TypeDescriptor<T>::Dtype(),
1092 TypeDescriptor<T>::Dtype(), TypeDescriptor<T>::Dtype()};
1093 }
Calltensorflow::__anon7a1ed7ad0111::ufuncs::DivmodUFunc1094 static void Call(char** args, npy_intp* dimensions, npy_intp* steps,
1095 void* data) {
1096 const char* i0 = args[0];
1097 const char* i1 = args[1];
1098 char* o0 = args[2];
1099 char* o1 = args[3];
1100 for (npy_intp k = 0; k < *dimensions; k++) {
1101 T x = *reinterpret_cast<const T*>(i0);
1102 T y = *reinterpret_cast<const T*>(i1);
1103 float floordiv, mod;
1104 std::tie(floordiv, mod) =
1105 divmod(static_cast<float>(x), static_cast<float>(y));
1106 *reinterpret_cast<T*>(o0) = T(floordiv);
1107 *reinterpret_cast<T*>(o1) = T(mod);
1108 i0 += steps[0];
1109 i1 += steps[1];
1110 o0 += steps[2];
1111 o1 += steps[3];
1112 }
1113 }
1114 };
// C-style remainder (std::fmod), evaluated in float and converted back to T.
template <typename T>
struct Fmod {
  T operator()(T a, T b) {
    const float dividend = static_cast<float>(a);
    const float divisor = static_cast<float>(b);
    return T(std::fmod(dividend, divisor));
  }
};
// Unary minus: returns -value.
template <typename T>
struct Negative {
  T operator()(T value) {
    T negated = -value;
    return negated;
  }
};

// Unary plus: returns the argument unchanged.
template <typename T>
struct Positive {
  T operator()(T value) {
    T same = value;
    return same;
  }
};
// a raised to the power b (std::pow), evaluated in float and converted back
// to T.
template <typename T>
struct Power {
  T operator()(T a, T b) {
    const float base = static_cast<float>(a);
    const float exponent = static_cast<float>(b);
    return T(std::pow(base, exponent));
  }
};
// Absolute value, evaluated in float.
template <typename T>
struct Abs {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::abs(widened));
  }
};

// Cube root, evaluated in float.
template <typename T>
struct Cbrt {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::cbrt(widened));
  }
};

// Smallest integral value not less than the argument, evaluated in float.
template <typename T>
struct Ceil {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::ceil(widened));
  }
};
1147 template <typename T>
1148 struct CopySign;
1149
// e^a, evaluated in float.
template <typename T>
struct Exp {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::exp(widened));
  }
};

// 2^a, evaluated in float.
template <typename T>
struct Exp2 {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::exp2(widened));
  }
};

// e^a - 1, evaluated in float.
template <typename T>
struct Expm1 {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::expm1(widened));
  }
};

// Largest integral value not greater than the argument, evaluated in float.
template <typename T>
struct Floor {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::floor(widened));
  }
};
// Decomposes the argument with std::frexp (in float).
// Returns {mantissa as T, binary exponent}.
template <typename T>
struct Frexp {
  std::pair<T, int> operator()(T a) {
    int exponent;
    const float mantissa = std::frexp(static_cast<float>(a), &exponent);
    return std::pair<T, int>(T(mantissa), exponent);
  }
};
1174 template <typename T>
1175 struct Heaviside {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Heaviside1176 T operator()(T bx, T h0) {
1177 float x = static_cast<float>(bx);
1178 if (Eigen::numext::isnan(x)) {
1179 return bx;
1180 }
1181 if (x < 0) {
1182 return T(0.0f);
1183 }
1184 if (x > 0) {
1185 return T(1.0f);
1186 }
1187 return h0; // x == 0
1188 }
1189 };
// Complex conjugate. Returns the argument unchanged.
template <typename T>
struct Conjugate {
  T operator()(T value) {
    T unchanged = value;
    return unchanged;
  }
};
// True if the argument (converted to float) is neither infinite nor NaN.
template <typename T>
struct IsFinite {
  bool operator()(T a) {
    const float widened = static_cast<float>(a);
    return std::isfinite(widened);
  }
};

// True if the argument (converted to float) is +inf or -inf.
template <typename T>
struct IsInf {
  bool operator()(T a) {
    const float widened = static_cast<float>(a);
    return std::isinf(widened);
  }
};
1202 template <typename T>
1203 struct IsNan {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::IsNan1204 bool operator()(T a) { return Eigen::numext::isnan(static_cast<float>(a)); }
1205 };
// a * 2^exp (std::ldexp), evaluated in float and converted back to T.
template <typename T>
struct Ldexp {
  T operator()(T a, int exp) {
    const float mantissa = static_cast<float>(a);
    return T(std::ldexp(mantissa, exp));
  }
};
// Natural logarithm, evaluated in float.
template <typename T>
struct Log {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::log(widened));
  }
};

// Base-2 logarithm, evaluated in float.
template <typename T>
struct Log2 {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::log2(widened));
  }
};

// Base-10 logarithm, evaluated in float.
template <typename T>
struct Log10 {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::log10(widened));
  }
};

// log(1 + a), evaluated in float.
template <typename T>
struct Log1p {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::log1p(widened));
  }
};
// log(exp(x) + exp(y)), evaluated in float.
//
// The larger operand is factored out so that only the exponential of a
// non-positive number is taken. Unordered operands (NaN) produce NaN.
template <typename T>
struct LogAddExp {
  T operator()(T bx, T by) {
    const float x = static_cast<float>(bx);
    const float y = static_cast<float>(by);
    if (x == y) {
      // Also covers infinities of the same sign, where y - x would be NaN.
      return T(x + std::log(2.0f));
    }
    if (x > y) {
      return T(x + std::log1p(std::exp(y - x)));
    }
    if (x < y) {
      return T(y + std::log1p(std::exp(x - y)));
    }
    // Neither equal nor ordered: at least one operand is NaN.
    return T(std::numeric_limits<float>::quiet_NaN());
  }
};
// log2(2^x + 2^y), evaluated in float.
//
// The larger operand is factored out so that only a power of two with a
// non-positive exponent is taken. Unordered operands (NaN) produce NaN.
template <typename T>
struct LogAddExp2 {
  T operator()(T bx, T by) {
    const float x = static_cast<float>(bx);
    const float y = static_cast<float>(by);
    if (x == y) {
      // Also covers infinities of the same sign, where y - x would be NaN.
      return T(x + 1.0f);
    }
    if (x > y) {
      return T(x + std::log1p(std::exp2(y - x)) / std::log(2.0f));
    }
    if (x < y) {
      return T(y + std::log1p(std::exp2(x - y)) / std::log(2.0f));
    }
    // Neither equal nor ordered: at least one operand is NaN.
    return T(std::numeric_limits<float>::quiet_NaN());
  }
};
// Splits the argument with std::modf (in float).
// Returns {fractional part, integral part}, both converted to T.
template <typename T>
struct Modf {
  std::pair<T, T> operator()(T a) {
    float whole;
    const float fraction = std::modf(static_cast<float>(a), &whole);
    return std::pair<T, T>(T(fraction), T(whole));
  }
};
1272
// 1 / a, evaluated in float and converted back to T.
template <typename T>
struct Reciprocal {
  T operator()(T a) {
    const float denominator = static_cast<float>(a);
    return T(1.f / denominator);
  }
};
// Rounds to an integral value with std::rint (current rounding mode),
// evaluated in float.
template <typename T>
struct Rint {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::rint(widened));
  }
};
// Sign of the argument: -1 for negative values, 1 for positive values.
// Any other input (zero or NaN) is returned unchanged.
template <typename T>
struct Sign {
  T operator()(T a) {
    const float value = static_cast<float>(a);
    if (value < 0) {
      return T(-1);
    } else if (value > 0) {
      return T(1);
    } else {
      return a;
    }
  }
};
// True if the sign bit of the argument (converted to float) is set.
template <typename T>
struct SignBit {
  bool operator()(T a) {
    const float widened = static_cast<float>(a);
    return std::signbit(widened);
  }
};

// Square root, evaluated in float.
template <typename T>
struct Sqrt {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::sqrt(widened));
  }
};
// a * a, with the product formed in float and converted back to T.
template <typename T>
struct Square {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    const float squared = widened * widened;
    return T(squared);
  }
};
// Rounds toward zero (std::trunc), evaluated in float.
template <typename T>
struct Trunc {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::trunc(widened));
  }
};
1313
1314 // Trigonometric functions
// Sine, evaluated in float.
template <typename T>
struct Sin {
  T operator()(T a) {
    const float angle = static_cast<float>(a);
    return T(std::sin(angle));
  }
};

// Cosine, evaluated in float.
template <typename T>
struct Cos {
  T operator()(T a) {
    const float angle = static_cast<float>(a);
    return T(std::cos(angle));
  }
};

// Tangent, evaluated in float.
template <typename T>
struct Tan {
  T operator()(T a) {
    const float angle = static_cast<float>(a);
    return T(std::tan(angle));
  }
};

// Inverse sine, evaluated in float.
template <typename T>
struct Arcsin {
  T operator()(T a) {
    const float ratio = static_cast<float>(a);
    return T(std::asin(ratio));
  }
};

// Inverse cosine, evaluated in float.
template <typename T>
struct Arccos {
  T operator()(T a) {
    const float ratio = static_cast<float>(a);
    return T(std::acos(ratio));
  }
};

// Inverse tangent, evaluated in float.
template <typename T>
struct Arctan {
  T operator()(T a) {
    const float ratio = static_cast<float>(a);
    return T(std::atan(ratio));
  }
};
// Two-argument inverse tangent, std::atan2(a, b), evaluated in float.
template <typename T>
struct Arctan2 {
  T operator()(T a, T b) {
    const float y = static_cast<float>(a);
    const float x = static_cast<float>(b);
    return T(std::atan2(y, x));
  }
};

// Euclidean norm sqrt(a^2 + b^2) via std::hypot, evaluated in float.
template <typename T>
struct Hypot {
  T operator()(T a, T b) {
    const float first = static_cast<float>(a);
    const float second = static_cast<float>(b);
    return T(std::hypot(first, second));
  }
};
// Hyperbolic sine, evaluated in float.
template <typename T>
struct Sinh {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::sinh(widened));
  }
};

// Hyperbolic cosine, evaluated in float.
template <typename T>
struct Cosh {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::cosh(widened));
  }
};

// Hyperbolic tangent, evaluated in float.
template <typename T>
struct Tanh {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::tanh(widened));
  }
};

// Inverse hyperbolic sine, evaluated in float.
template <typename T>
struct Arcsinh {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::asinh(widened));
  }
};

// Inverse hyperbolic cosine, evaluated in float.
template <typename T>
struct Arccosh {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::acosh(widened));
  }
};

// Inverse hyperbolic tangent, evaluated in float.
template <typename T>
struct Arctanh {
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(std::atanh(widened));
  }
};
// Converts degrees to radians. The multiplication is done in float.
template <typename T>
struct Deg2rad {
  T operator()(T a) {
    // Pi is spelled out rather than taken from M_PI: that macro is a
    // POSIX/vendor extension, not part of standard C++, and is absent on
    // some toolchains unless extra feature macros are defined. The quotient
    // is formed in double and then narrowed, exactly as M_PI / 180.0f was.
    static constexpr float radians_per_degree =
        static_cast<float>(3.14159265358979323846 / 180.0);
    return T(static_cast<float>(a) * radians_per_degree);
  }
};

// Converts radians to degrees. The multiplication is done in float.
template <typename T>
struct Rad2deg {
  T operator()(T a) {
    // See Deg2rad for why pi is a literal here instead of M_PI.
    static constexpr float degrees_per_radian =
        static_cast<float>(180.0 / 3.14159265358979323846);
    return T(static_cast<float>(a) * degrees_per_radian);
  }
};
1389
1390 template <typename T>
1391 struct Eq {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Eq1392 npy_bool operator()(T a, T b) { return a == b; }
1393 };
1394 template <typename T>
1395 struct Ne {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Ne1396 npy_bool operator()(T a, T b) { return a != b; }
1397 };
1398 template <typename T>
1399 struct Lt {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Lt1400 npy_bool operator()(T a, T b) { return a < b; }
1401 };
1402 template <typename T>
1403 struct Gt {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Gt1404 npy_bool operator()(T a, T b) { return a > b; }
1405 };
1406 template <typename T>
1407 struct Le {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Le1408 npy_bool operator()(T a, T b) { return a <= b; }
1409 };
1410 template <typename T>
1411 struct Ge {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Ge1412 npy_bool operator()(T a, T b) { return a >= b; }
1413 };
1414 template <typename T>
1415 struct Maximum {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Maximum1416 T operator()(T a, T b) {
1417 float fa(a), fb(b);
1418 return Eigen::numext::isnan(fa) || fa > fb ? a : b;
1419 }
1420 };
1421 template <typename T>
1422 struct Minimum {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Minimum1423 T operator()(T a, T b) {
1424 float fa(a), fb(b);
1425 return Eigen::numext::isnan(fa) || fa < fb ? a : b;
1426 }
1427 };
1428 template <typename T>
1429 struct Fmax {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Fmax1430 T operator()(T a, T b) {
1431 float fa(a), fb(b);
1432 return Eigen::numext::isnan(fb) || fa > fb ? a : b;
1433 }
1434 };
1435 template <typename T>
1436 struct Fmin {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Fmin1437 T operator()(T a, T b) {
1438 float fa(a), fb(b);
1439 return Eigen::numext::isnan(fb) || fa < fb ? a : b;
1440 }
1441 };
1442
1443 template <typename T>
1444 struct LogicalNot {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogicalNot1445 npy_bool operator()(T a) { return !a; }
1446 };
1447 template <typename T>
1448 struct LogicalAnd {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogicalAnd1449 npy_bool operator()(T a, T b) { return a && b; }
1450 };
1451 template <typename T>
1452 struct LogicalOr {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogicalOr1453 npy_bool operator()(T a, T b) { return a || b; }
1454 };
1455 template <typename T>
1456 struct LogicalXor {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogicalXor1457 npy_bool operator()(T a, T b) {
1458 return static_cast<bool>(a) ^ static_cast<bool>(b);
1459 }
1460 };
1461
1462 template <typename T>
1463 struct NextAfter;
1464
1465 template <typename T>
1466 struct Spacing {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Spacing1467 T operator()(T x) {
1468 // Compute the distance between the input and the next number with greater
1469 // magnitude. The result should have the sign of the input.
1470 T away(std::copysign(std::numeric_limits<float>::infinity(),
1471 static_cast<float>(x)));
1472 return NextAfter<T>()(x, away) - x;
1473 }
1474 };
1475
1476 template <typename T>
RegisterUFuncs(PyObject * numpy)1477 bool RegisterUFuncs(PyObject* numpy) {
1478 bool ok =
1479 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
1480 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(numpy,
1481 "subtract") &&
1482 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(numpy,
1483 "multiply") &&
1484 RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(numpy,
1485 "divide") &&
1486 RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp<T>>, T>(numpy,
1487 "logaddexp") &&
1488 RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp2<T>>, T>(
1489 numpy, "logaddexp2") &&
1490 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Negative<T>>, T>(numpy,
1491 "negative") &&
1492 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Positive<T>>, T>(numpy,
1493 "positive") &&
1494 RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(
1495 numpy, "true_divide") &&
1496 RegisterUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(
1497 numpy, "floor_divide") &&
1498 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Power<T>>, T>(numpy, "power") &&
1499 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy,
1500 "remainder") &&
1501 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "mod") &&
1502 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmod<T>>, T>(numpy, "fmod") &&
1503 RegisterUFunc<ufuncs::DivmodUFunc<T>, T>(numpy, "divmod") &&
1504 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "absolute") &&
1505 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "fabs") &&
1506 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rint<T>>, T>(numpy, "rint") &&
1507 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sign<T>>, T>(numpy, "sign") &&
1508 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Heaviside<T>>, T>(numpy,
1509 "heaviside") &&
1510 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Conjugate<T>>, T>(numpy,
1511 "conjugate") &&
1512 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp<T>>, T>(numpy, "exp") &&
1513 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp2<T>>, T>(numpy, "exp2") &&
1514 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Expm1<T>>, T>(numpy, "expm1") &&
1515 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log<T>>, T>(numpy, "log") &&
1516 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log2<T>>, T>(numpy, "log2") &&
1517 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log10<T>>, T>(numpy, "log10") &&
1518 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log1p<T>>, T>(numpy, "log1p") &&
1519 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sqrt<T>>, T>(numpy, "sqrt") &&
1520 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Square<T>>, T>(numpy, "square") &&
1521 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cbrt<T>>, T>(numpy, "cbrt") &&
1522 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Reciprocal<T>>, T>(numpy,
1523 "reciprocal") &&
1524
1525 // Trigonometric functions
1526 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sin<T>>, T>(numpy, "sin") &&
1527 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cos<T>>, T>(numpy, "cos") &&
1528 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tan<T>>, T>(numpy, "tan") &&
1529 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsin<T>>, T>(numpy, "arcsin") &&
1530 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccos<T>>, T>(numpy, "arccos") &&
1531 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctan<T>>, T>(numpy, "arctan") &&
1532 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Arctan2<T>>, T>(numpy,
1533 "arctan2") &&
1534 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Hypot<T>>, T>(numpy, "hypot") &&
1535 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sinh<T>>, T>(numpy, "sinh") &&
1536 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cosh<T>>, T>(numpy, "cosh") &&
1537 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tanh<T>>, T>(numpy, "tanh") &&
1538 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsinh<T>>, T>(numpy,
1539 "arcsinh") &&
1540 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccosh<T>>, T>(numpy,
1541 "arccosh") &&
1542 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctanh<T>>, T>(numpy,
1543 "arctanh") &&
1544 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Deg2rad<T>>, T>(numpy,
1545 "deg2rad") &&
1546 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rad2deg<T>>, T>(numpy,
1547 "rad2deg") &&
1548
1549 // Comparison functions
1550 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Eq<T>>, T>(numpy, "equal") &&
1551 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ne<T>>, T>(numpy,
1552 "not_equal") &&
1553 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Lt<T>>, T>(numpy, "less") &&
1554 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Gt<T>>, T>(numpy, "greater") &&
1555 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Le<T>>, T>(numpy,
1556 "less_equal") &&
1557 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ge<T>>, T>(numpy,
1558 "greater_equal") &&
1559 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Maximum<T>>, T>(numpy,
1560 "maximum") &&
1561 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Minimum<T>>, T>(numpy,
1562 "minimum") &&
1563 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmax<T>>, T>(numpy, "fmax") &&
1564 RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmin<T>>, T>(numpy, "fmin") &&
1565 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalAnd<T>>, T>(
1566 numpy, "logical_and") &&
1567 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalOr<T>>, T>(
1568 numpy, "logical_or") &&
1569 RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalXor<T>>, T>(
1570 numpy, "logical_xor") &&
1571 RegisterUFunc<UnaryUFunc<T, bool, ufuncs::LogicalNot<T>>, T>(
1572 numpy, "logical_not") &&
1573
1574 // Floating point functions
1575 RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsFinite<T>>, T>(numpy,
1576 "isfinite") &&
1577 RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsInf<T>>, T>(numpy, "isinf") &&
1578 RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsNan<T>>, T>(numpy, "isnan") &&
1579 RegisterUFunc<UnaryUFunc<T, bool, ufuncs::SignBit<T>>, T>(numpy,
1580 "signbit") &&
1581 RegisterUFunc<BinaryUFunc<T, T, ufuncs::CopySign<T>>, T>(numpy,
1582 "copysign") &&
1583 RegisterUFunc<UnaryUFunc2<T, T, T, ufuncs::Modf<T>>, T>(numpy, "modf") &&
1584 RegisterUFunc<BinaryUFunc2<T, int, T, ufuncs::Ldexp<T>>, T>(numpy,
1585 "ldexp") &&
1586 RegisterUFunc<UnaryUFunc2<T, T, int, ufuncs::Frexp<T>>, T>(numpy,
1587 "frexp") &&
1588 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Floor<T>>, T>(numpy, "floor") &&
1589 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Ceil<T>>, T>(numpy, "ceil") &&
1590 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Trunc<T>>, T>(numpy, "trunc") &&
1591 RegisterUFunc<BinaryUFunc<T, T, ufuncs::NextAfter<T>>, T>(numpy,
1592 "nextafter") &&
1593 RegisterUFunc<UnaryUFunc<T, T, ufuncs::Spacing<T>>, T>(numpy, "spacing");
1594
1595 return ok;
1596 }
1597
1598 } // namespace ufuncs
1599
1600 template <typename T>
RegisterNumpyDtype(PyObject * numpy)1601 bool RegisterNumpyDtype(PyObject* numpy) {
1602 // If another module (presumably either TF or JAX) has registered a bfloat16
1603 // type, use it. We don't want two bfloat16 types if we can avoid it since it
1604 // leads to confusion if we have two different types with the same name. This
1605 // assumes that the other module has a sufficiently complete bfloat16
1606 // implementation. The only known NumPy bfloat16 extension at the time of
1607 // writing is this one (distributed in TF and JAX).
1608 // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
1609 // so we can unambiguously refer to a single canonical definition of bfloat16.
1610 int typenum =
1611 PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
1612 if (typenum != NPY_NOTYPE) {
1613 PyArray_Descr* descr = PyArray_DescrFromType(typenum);
1614 // The test for an argmax function here is to verify that the
1615 // bfloat16 implementation is sufficiently new, and, say, not from
1616 // an older version of TF or JAX.
1617 if (descr && descr->f && descr->f->argmax) {
1618 TypeDescriptor<T>::npy_type = typenum;
1619 TypeDescriptor<T>::type_ptr = descr->typeobj;
1620 return true;
1621 }
1622 }
1623
1624 TypeDescriptor<T>::type.tp_base = &PyGenericArrType_Type;
1625
1626 if (PyType_Ready(&TypeDescriptor<T>::type) < 0) {
1627 return false;
1628 }
1629
1630 // Initializes the NumPy descriptor.
1631 PyArray_ArrFuncs& arr_funcs = CustomFloatTypeDescriptor<T>::arr_funcs;
1632 PyArray_InitArrFuncs(&arr_funcs);
1633 arr_funcs.getitem = NPyCustomFloat_GetItem<T>;
1634 arr_funcs.setitem = NPyCustomFloat_SetItem<T>;
1635 arr_funcs.compare = NPyCustomFloat_Compare<T>;
1636 arr_funcs.copyswapn = NPyCustomFloat_CopySwapN<T>;
1637 arr_funcs.copyswap = NPyCustomFloat_CopySwap<T>;
1638 arr_funcs.nonzero = NPyCustomFloat_NonZero<T>;
1639 arr_funcs.fill = NPyCustomFloat_Fill<T>;
1640 arr_funcs.dotfunc = NPyCustomFloat_DotFunc<T>;
1641 arr_funcs.compare = NPyCustomFloat_CompareFunc<T>;
1642 arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc<T>;
1643 arr_funcs.argmin = NPyCustomFloat_ArgMinFunc<T>;
1644
1645 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
1646 Py_TYPE(&CustomFloatTypeDescriptor<T>::npy_descr) = &PyArrayDescr_Type;
1647 #else
1648 Py_SET_TYPE(&CustomFloatTypeDescriptor<T>::npy_descr, &PyArrayDescr_Type);
1649 #endif
1650 TypeDescriptor<T>::npy_type =
1651 PyArray_RegisterDataType(&CustomFloatTypeDescriptor<T>::npy_descr);
1652 TypeDescriptor<T>::type_ptr = &TypeDescriptor<T>::type;
1653 if (TypeDescriptor<T>::Dtype() < 0) {
1654 return false;
1655 }
1656
1657 Safe_PyObjectPtr typeDict_obj =
1658 make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
1659 if (!typeDict_obj) return false;
1660 // Add the type object to `numpy.typeDict`: that makes
1661 // `numpy.dtype(type_name)` work.
1662 if (PyDict_SetItemString(
1663 typeDict_obj.get(), TypeDescriptor<T>::kTypeName,
1664 reinterpret_cast<PyObject*>(&TypeDescriptor<T>::type)) < 0) {
1665 return false;
1666 }
1667
1668 // Support dtype(type_name)
1669 if (PyDict_SetItemString(TypeDescriptor<T>::type.tp_dict, "dtype",
1670 reinterpret_cast<PyObject*>(
1671 &CustomFloatTypeDescriptor<T>::npy_descr)) < 0) {
1672 return false;
1673 }
1674
1675 return RegisterCasts<T>() && ufuncs::RegisterUFuncs<T>(numpy);
1676 }
1677
1678 namespace ufuncs {
1679
1680 template <>
1681 struct CopySign<bfloat16> {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::CopySign1682 bfloat16 operator()(bfloat16 a, bfloat16 b) {
1683 // LLVM is smart enough to turn this into (a & 0x7fff) | (b & 0x8000).
1684 bfloat16 abs_a = Eigen::numext::abs(a);
1685 return std::signbit(static_cast<float>(b)) ? -abs_a : abs_a;
1686 }
1687 };
1688
1689 template <>
1690 struct NextAfter<bfloat16> {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::NextAfter1691 bfloat16 operator()(bfloat16 from, bfloat16 to) {
1692 uint16_t from_as_int, to_as_int;
1693 const uint16_t sign_mask = 1 << 15;
1694 float from_as_float(from), to_as_float(to);
1695 memcpy(&from_as_int, &from, sizeof(bfloat16));
1696 memcpy(&to_as_int, &to, sizeof(bfloat16));
1697 if (Eigen::numext::isnan(from_as_float) ||
1698 Eigen::numext::isnan(to_as_float)) {
1699 return bfloat16(std::numeric_limits<float>::quiet_NaN());
1700 }
1701 if (from_as_int == to_as_int) {
1702 return to;
1703 }
1704 if (from_as_float == 0) {
1705 if (to_as_float == 0) {
1706 return to;
1707 } else {
1708 // Smallest subnormal signed like `to`.
1709 uint16_t out_int = (to_as_int & sign_mask) | 1;
1710 bfloat16 out;
1711 memcpy(&out, &out_int, sizeof(bfloat16));
1712 return out;
1713 }
1714 }
1715 uint16_t from_sign = from_as_int & sign_mask;
1716 uint16_t to_sign = to_as_int & sign_mask;
1717 uint16_t from_abs = from_as_int & ~sign_mask;
1718 uint16_t to_abs = to_as_int & ~sign_mask;
1719 uint16_t magnitude_adjustment =
1720 (from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
1721 uint16_t out_int = from_as_int + magnitude_adjustment;
1722 bfloat16 out;
1723 memcpy(&out, &out_int, sizeof(bfloat16));
1724 return out;
1725 }
1726 };
1727
1728 } // namespace ufuncs
1729
1730 using bfloat16 = Eigen::bfloat16;
1731
1732 template <>
1733 struct TypeDescriptor<bfloat16> : CustomFloatTypeDescriptor<bfloat16> {
1734 typedef bfloat16 T;
1735 static constexpr const char* kTypeName = "bfloat16";
1736 static constexpr const char* kTpDoc = "bfloat16 floating-point values";
1737 // We must register bfloat16 with a kind other than "f", because numpy
1738 // considers two types with the same kind and size to be equal, but
1739 // float16 != bfloat16.
1740 // The downside of this is that NumPy scalar promotion does not work with
1741 // bfloat16 values.
1742 static constexpr char kNpyDescrKind = 'V';
1743 // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
1744 // character is unique.
1745 static constexpr char kNpyDescrType = 'E';
1746 static constexpr char kNpyDescrByteorder = '=';
1747 };
1748
1749 template <>
1750 struct TypeDescriptor<float8_e4m3b11>
1751 : CustomFloatTypeDescriptor<float8_e4m3b11> {
1752 typedef float8_e4m3b11 T;
1753 static constexpr const char* kTypeName = "float8_e4m3b11";
1754 static constexpr const char* kTpDoc = "float8_e4m3b11 floating-point values";
1755 // We must register float8_e4m3b11 with a kind other than "f", because numpy
1756 // considers two types with the same kind and size to be equal, and we
1757 // expect multiple 1 byte floating point types.
1758 // The downside of this is that NumPy scalar promotion does not work with
1759 // float8_e4m3b11 values.
1760 static constexpr char kNpyDescrKind = 'V';
1761 // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
1762 // character is unique.
1763 static constexpr char kNpyDescrType = 'L';
1764 static constexpr char kNpyDescrByteorder = '=';
1765 };
1766
1767 namespace ufuncs {
1768
1769 template <>
1770 struct CopySign<float8_e4m3b11> {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::CopySign1771 float8_e4m3b11 operator()(float8_e4m3b11 a, float8_e4m3b11 b) {
1772 return float8_e4m3b11::FromRep((a.rep() & 0x7f) | (b.rep() & 0x80));
1773 }
1774 };
1775
1776 template <>
1777 struct NextAfter<float8_e4m3b11> {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::NextAfter1778 float8_e4m3b11 operator()(float8_e4m3b11 from, float8_e4m3b11 to) {
1779 uint8_t from_rep = from.rep();
1780 uint8_t to_rep = to.rep();
1781 if (from_rep == 0x80 || to_rep == 0x80) {
1782 return float8_e4m3b11::FromRep(0x80);
1783 }
1784 if (from_rep == to_rep) {
1785 return to;
1786 }
1787 if (from_rep == 0) {
1788 return float8_e4m3b11::FromRep(0x01 | (to_rep & 0x80));
1789 }
1790 const uint16_t sign_mask = 0x80;
1791 uint8_t from_sign = from_rep & sign_mask;
1792 uint8_t to_sign = to_rep & sign_mask;
1793 uint8_t from_abs = from_rep & ~sign_mask;
1794 uint8_t to_abs = to_rep & ~sign_mask;
1795 uint8_t magnitude_adjustment =
1796 (from_abs > to_abs || from_sign != to_sign) ? 0xFF : 0x0001;
1797 uint8_t out_int = from_rep + magnitude_adjustment;
1798 if (out_int == 0x80) {
1799 out_int = 0x0;
1800 }
1801 return float8_e4m3b11::FromRep(out_int);
1802 }
1803 };
1804
1805 } // namespace ufuncs
1806
1807 } // namespace
1808
1809 // Initializes the module.
Initialize()1810 bool Initialize() {
1811 ImportNumpy();
1812 import_umath1(false);
1813
1814 Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
1815 if (!numpy_str) {
1816 return false;
1817 }
1818 Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
1819 if (!numpy) {
1820 return false;
1821 }
1822
1823 if (!RegisterNumpyDtype<bfloat16>(numpy.get())) {
1824 return false;
1825 }
1826 if (!RegisterNumpyDtype<float8_e4m3b11>(numpy.get())) {
1827 return false;
1828 }
1829 // TODO(parkers): Enable CanCast to-from fp8 and bf16 and f16.
1830 return true;
1831 }
1832
RegisterNumpyBfloat16()1833 bool RegisterNumpyBfloat16() {
1834 if (TypeDescriptor<bfloat16>::Dtype() != NPY_NOTYPE) {
1835 // Already initialized.
1836 return true;
1837 }
1838 if (!Initialize()) {
1839 if (!PyErr_Occurred()) {
1840 PyErr_SetString(PyExc_RuntimeError, "cannot load bfloat16 module.");
1841 }
1842 PyErr_Print();
1843 return false;
1844 }
1845 return true;
1846 }
1847
Bfloat16Dtype()1848 PyObject* Bfloat16Dtype() {
1849 return reinterpret_cast<PyObject*>(TypeDescriptor<bfloat16>::type_ptr);
1850 }
1851
Bfloat16NumpyType()1852 int Bfloat16NumpyType() { return TypeDescriptor<bfloat16>::Dtype(); }
1853
Float8_E4M3B11Dtype()1854 PyObject* Float8_E4M3B11Dtype() {
1855 return reinterpret_cast<PyObject*>(TypeDescriptor<float8_e4m3b11>::type_ptr);
1856 }
1857
1858 } // namespace tensorflow
1859