xref: /aosp_15_r20/external/tensorflow/tensorflow/python/lib/core/bfloat16.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/lib/core/bfloat16.h"
17 
18 #include <array>
19 #include <cmath>
20 #include <limits>
21 #include <locale>
22 
23 #include "tensorflow/python/lib/core/float8_e4m3b11.h"
24 // Place `<locale>` before <Python.h> to avoid a build failure in macOS.
25 #include <Python.h>
26 
27 #include "absl/strings/str_cat.h"
28 #include "third_party/eigen3/Eigen/Core"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/python/lib/core/numpy.h"
31 
32 namespace tensorflow {
33 namespace {
34 
35 struct PyDecrefDeleter {
operator ()tensorflow::__anon7a1ed7ad0111::PyDecrefDeleter36   void operator()(PyObject* p) const { Py_DECREF(p); }
37 };
38 
39 // Safe container for an owned PyObject. On destruction, the reference count of
40 // the contained object will be decremented.
41 using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
make_safe(PyObject * object)42 Safe_PyObjectPtr make_safe(PyObject* object) {
43   return Safe_PyObjectPtr(object);
44 }
45 
PyLong_CheckNoOverflow(PyObject * object)46 bool PyLong_CheckNoOverflow(PyObject* object) {
47   if (!PyLong_Check(object)) {
48     return false;
49   }
50   int overflow = 0;
51   PyLong_AsLongAndOverflow(object, &overflow);
52   return (overflow == 0);
53 }
54 
// Maps a C++ type to its NumPy metadata. The primary template is intentionally
// empty. Each supported type provides a specialization with:
//   - a `T` typedef: the in-memory representation for NumPy values of the type;
//   - a static `Dtype()`: the NumPy type number.
// `Enable` is unused here; presumably it exists so that enable_if-style
// partial specializations can be written (none are visible in this view).
template <typename T, typename Enable = void>
struct TypeDescriptor {
  // typedef ... T;  // Representation type in memory for NumPy values of type
  // static int Dtype() { return NPY_...; }  // Numpy type number for T.
};
60 
// Static per-type state shared by a custom (reduced-precision) float type T:
// its Python type object, number methods, and NumPy descriptor/array
// functions. Code in this file accesses these members through
// TypeDescriptor<T> (e.g. TypeDescriptor<T>::type), so presumably the custom
// float TypeDescriptor specializations derive from this struct and add
// kTypeName / kTpDoc / kNpyDescr* — those specializations are not visible here.
template <typename T>
struct CustomFloatTypeDescriptor {
  // NumPy type number for T; NPY_NOTYPE until registration assigns one.
  static int Dtype() { return npy_type; }

  // Registered numpy type ID. Global variable populated by the registration
  // code. Protected by the GIL.
  static int npy_type;

  // Python type object describing PyCustomFloat<T> instances.
  static PyTypeObject type;
  // Pointer to the python type object we are using. This is either a pointer
  // to type, if we choose to register it, or to the python type
  // registered by another system into NumPy.
  static PyTypeObject* type_ptr;

  // Python number-protocol slot table for PyCustomFloat<T>.
  static PyNumberMethods number_methods;

  // NumPy per-dtype array functions. Zero-initialized at its definition;
  // presumably filled in by the registration code (not visible here).
  static PyArray_ArrFuncs arr_funcs;

  // NumPy dtype descriptor for T.
  static PyArray_Descr npy_descr;
};
template <typename T>
int CustomFloatTypeDescriptor<T>::npy_type = NPY_NOTYPE;
template <typename T>
PyTypeObject* CustomFloatTypeDescriptor<T>::type_ptr = nullptr;
85 
// Representation of a Python custom float object: the standard Python object
// header followed directly by the raw T payload.
template <typename T>
struct PyCustomFloat {
  PyObject_HEAD;  // Python object header
  T value;        // The wrapped custom float value.
};
92 
93 // Returns true if 'object' is a PyCustomFloat.
94 template <typename T>
PyCustomFloat_Check(PyObject * object)95 bool PyCustomFloat_Check(PyObject* object) {
96   return PyObject_IsInstance(
97       object, reinterpret_cast<PyObject*>(&TypeDescriptor<T>::type));
98 }
99 
100 // Extracts the value of a PyCustomFloat object.
101 template <typename T>
PyCustomFloat_CustomFloat(PyObject * object)102 T PyCustomFloat_CustomFloat(PyObject* object) {
103   return reinterpret_cast<PyCustomFloat<T>*>(object)->value;
104 }
105 
106 // Constructs a PyCustomFloat object from PyCustomFloat<T>::T.
107 template <typename T>
PyCustomFloat_FromT(T x)108 Safe_PyObjectPtr PyCustomFloat_FromT(T x) {
109   Safe_PyObjectPtr ref =
110       make_safe(TypeDescriptor<T>::type.tp_alloc(&TypeDescriptor<T>::type, 0));
111   PyCustomFloat<T>* p = reinterpret_cast<PyCustomFloat<T>*>(ref.get());
112   if (p) {
113     p->value = x;
114   }
115   return ref;
116 }
117 
// Converts a Python object to a reduced float value. Returns true on success,
// returns false and reports a Python error on failure.
//
// Accepted inputs, tried in this order (the first match wins):
//   1. a PyCustomFloat<T> instance (value copied as-is);
//   2. a Python float;
//   3. a Python int whose value fits in a C `long` (larger ints fall through
//      and, matching nothing below, are rejected);
//   4. a NumPy half / float32 / float64 / longdouble scalar;
//   5. a zero-dimensional NumPy array, first cast to T's dtype if needed.
// A Python error is set only when a conversion step itself fails
// (PyFloat_AsDouble, PyLong_AsLong, PyArray_Cast). For any other input type
// this function returns false without raising an error itself, so callers
// report their own TypeError.
template <typename T>
bool CastToCustomFloat(PyObject* arg, T* output) {
  if (PyCustomFloat_Check<T>(arg)) {
    *output = PyCustomFloat_CustomFloat<T>(arg);
    return true;
  }
  if (PyFloat_Check(arg)) {
    double d = PyFloat_AsDouble(arg);
    if (PyErr_Occurred()) {
      return false;
    }
    // TODO(phawkins): check for overflow
    *output = T(d);
    return true;
  }
  if (PyLong_CheckNoOverflow(arg)) {
    long l = PyLong_AsLong(arg);  // NOLINT
    if (PyErr_Occurred()) {
      return false;
    }
    // TODO(phawkins): check for overflow
    // Goes through float: T is constructed from a float here, not a long.
    *output = T(static_cast<float>(l));
    return true;
  }
  if (PyArray_IsScalar(arg, Half)) {
    Eigen::half f;
    PyArray_ScalarAsCtype(arg, &f);
    *output = T(f);
    return true;
  }
  if (PyArray_IsScalar(arg, Float)) {
    float f;
    PyArray_ScalarAsCtype(arg, &f);
    *output = T(f);
    return true;
  }
  if (PyArray_IsScalar(arg, Double)) {
    double f;
    PyArray_ScalarAsCtype(arg, &f);
    *output = T(f);
    return true;
  }
  if (PyArray_IsScalar(arg, LongDouble)) {
    long double f;
    PyArray_ScalarAsCtype(arg, &f);
    *output = T(f);
    return true;
  }
  if (PyArray_IsZeroDim(arg)) {
    // Owns the temporary cast array, if one is made, so its data stays alive
    // until the element is read below.
    Safe_PyObjectPtr ref;
    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
    if (PyArray_TYPE(arr) != TypeDescriptor<T>::Dtype()) {
      ref = make_safe(PyArray_Cast(arr, TypeDescriptor<T>::Dtype()));
      if (PyErr_Occurred()) {
        return false;
      }
      arg = ref.get();
      arr = reinterpret_cast<PyArrayObject*>(arg);
    }
    // The array now has T's dtype, so its single element is read as a T.
    *output = *reinterpret_cast<T*>(PyArray_DATA(arr));
    return true;
  }
  return false;
}
184 
185 template <typename T>
SafeCastToCustomFloat(PyObject * arg,T * output)186 bool SafeCastToCustomFloat(PyObject* arg, T* output) {
187   if (PyCustomFloat_Check<T>(arg)) {
188     *output = PyCustomFloat_CustomFloat<T>(arg);
189     return true;
190   }
191   return false;
192 }
193 
194 // Converts a PyReduceFloat into a PyFloat.
195 template <typename T>
PyCustomFloat_Float(PyObject * self)196 PyObject* PyCustomFloat_Float(PyObject* self) {
197   T x = PyCustomFloat_CustomFloat<T>(self);
198   return PyFloat_FromDouble(static_cast<double>(static_cast<float>(x)));
199 }
200 
201 // Converts a PyReduceFloat into a PyInt.
202 template <typename T>
PyCustomFloat_Int(PyObject * self)203 PyObject* PyCustomFloat_Int(PyObject* self) {
204   T x = PyCustomFloat_CustomFloat<T>(self);
205   long y = static_cast<long>(static_cast<float>(x));  // NOLINT
206   return PyLong_FromLong(y);
207 }
208 
209 // Negates a PyCustomFloat.
210 template <typename T>
PyCustomFloat_Negative(PyObject * self)211 PyObject* PyCustomFloat_Negative(PyObject* self) {
212   T x = PyCustomFloat_CustomFloat<T>(self);
213   return PyCustomFloat_FromT<T>(-x).release();
214 }
215 
216 template <typename T>
PyCustomFloat_Add(PyObject * a,PyObject * b)217 PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) {
218   T x, y;
219   if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
220     return PyCustomFloat_FromT<T>(x + y).release();
221   }
222   return PyArray_Type.tp_as_number->nb_add(a, b);
223 }
224 
225 template <typename T>
PyCustomFloat_Subtract(PyObject * a,PyObject * b)226 PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) {
227   T x, y;
228   if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
229     return PyCustomFloat_FromT<T>(x - y).release();
230   }
231   return PyArray_Type.tp_as_number->nb_subtract(a, b);
232 }
233 
234 template <typename T>
PyCustomFloat_Multiply(PyObject * a,PyObject * b)235 PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) {
236   T x, y;
237   if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
238     return PyCustomFloat_FromT<T>(x * y).release();
239   }
240   return PyArray_Type.tp_as_number->nb_multiply(a, b);
241 }
242 
243 template <typename T>
PyCustomFloat_TrueDivide(PyObject * a,PyObject * b)244 PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) {
245   T x, y;
246   if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
247     return PyCustomFloat_FromT<T>(x / y).release();
248   }
249   return PyArray_Type.tp_as_number->nb_true_divide(a, b);
250 }
251 
// Python number methods for PyCustomFloat objects.
// Only +, -, *, /, unary minus, int() and float() are provided; all other
// slots are nullptr. Slots are listed in PyNumberMethods declaration order.
// Any trailing slots not listed here are value-initialized (null) by
// aggregate initialization.
template <typename T>
PyNumberMethods CustomFloatTypeDescriptor<T>::number_methods = {
    PyCustomFloat_Add<T>,       // nb_add
    PyCustomFloat_Subtract<T>,  // nb_subtract
    PyCustomFloat_Multiply<T>,  // nb_multiply
    nullptr,                    // nb_remainder
    nullptr,                    // nb_divmod
    nullptr,                    // nb_power
    PyCustomFloat_Negative<T>,  // nb_negative
    nullptr,                    // nb_positive
    nullptr,                    // nb_absolute
    nullptr,                    // nb_bool (named nb_nonzero in Python 2)
    nullptr,                    // nb_invert
    nullptr,                    // nb_lshift
    nullptr,                    // nb_rshift
    nullptr,                    // nb_and
    nullptr,                    // nb_xor
    nullptr,                    // nb_or
    PyCustomFloat_Int<T>,       // nb_int
    nullptr,                    // nb_reserved
    PyCustomFloat_Float<T>,     // nb_float

    nullptr,  // nb_inplace_add
    nullptr,  // nb_inplace_subtract
    nullptr,  // nb_inplace_multiply
    nullptr,  // nb_inplace_remainder
    nullptr,  // nb_inplace_power
    nullptr,  // nb_inplace_lshift
    nullptr,  // nb_inplace_rshift
    nullptr,  // nb_inplace_and
    nullptr,  // nb_inplace_xor
    nullptr,  // nb_inplace_or

    nullptr,                      // nb_floor_divide
    PyCustomFloat_TrueDivide<T>,  // nb_true_divide
    nullptr,                      // nb_inplace_floor_divide
    nullptr,                      // nb_inplace_true_divide
    nullptr,                      // nb_index
};
292 
293 // Constructs a new PyCustomFloat.
294 template <typename T>
PyCustomFloat_New(PyTypeObject * type,PyObject * args,PyObject * kwds)295 PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args,
296                             PyObject* kwds) {
297   if (kwds && PyDict_Size(kwds)) {
298     PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
299     return nullptr;
300   }
301   Py_ssize_t size = PyTuple_Size(args);
302   if (size != 1) {
303     PyErr_Format(PyExc_TypeError,
304                  "expected number as argument to %s constructor",
305                  TypeDescriptor<T>::kTypeName);
306     return nullptr;
307   }
308   PyObject* arg = PyTuple_GetItem(args, 0);
309 
310   T value;
311   if (PyCustomFloat_Check<T>(arg)) {
312     Py_INCREF(arg);
313     return arg;
314   } else if (CastToCustomFloat<T>(arg, &value)) {
315     return PyCustomFloat_FromT<T>(value).release();
316   } else if (PyArray_Check(arg)) {
317     PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
318     if (PyArray_TYPE(arr) != TypeDescriptor<T>::Dtype()) {
319       return PyArray_Cast(arr, TypeDescriptor<T>::Dtype());
320     } else {
321       Py_INCREF(arg);
322       return arg;
323     }
324   }
325   PyErr_Format(PyExc_TypeError, "expected number, got %s",
326                Py_TYPE(arg)->tp_name);
327   return nullptr;
328 }
329 
330 // Comparisons on PyCustomFloats.
331 template <typename T>
PyCustomFloat_RichCompare(PyObject * a,PyObject * b,int op)332 PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) {
333   T x, y;
334   if (!SafeCastToCustomFloat<T>(a, &x) || !SafeCastToCustomFloat<T>(b, &y)) {
335     return PyGenericArrType_Type.tp_richcompare(a, b, op);
336   }
337   bool result;
338   switch (op) {
339     case Py_LT:
340       result = x < y;
341       break;
342     case Py_LE:
343       result = x <= y;
344       break;
345     case Py_EQ:
346       result = x == y;
347       break;
348     case Py_NE:
349       result = x != y;
350       break;
351     case Py_GT:
352       result = x > y;
353       break;
354     case Py_GE:
355       result = x >= y;
356       break;
357     default:
358       LOG(FATAL) << "Invalid op type " << op;
359   }
360   return PyBool_FromLong(result);
361 }
362 
363 // Implementation of repr() for PyCustomFloat.
364 template <typename T>
PyCustomFloat_Repr(PyObject * self)365 PyObject* PyCustomFloat_Repr(PyObject* self) {
366   T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
367   std::string v = absl::StrCat(static_cast<float>(x));
368   return PyUnicode_FromString(v.c_str());
369 }
370 
371 // Implementation of str() for PyCustomFloat.
372 template <typename T>
PyCustomFloat_Str(PyObject * self)373 PyObject* PyCustomFloat_Str(PyObject* self) {
374   T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
375   std::string v = absl::StrCat(static_cast<float>(x));
376   return PyUnicode_FromString(v.c_str());
377 }
378 
379 // _Py_HashDouble changed its prototype for Python 3.10 so we use an overload to
380 // handle the two possibilities.
381 // NOLINTNEXTLINE(clang-diagnostic-unused-function)
HashImpl(Py_hash_t (* hash_double)(PyObject *,double),PyObject * self,double value)382 Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double), PyObject* self,
383                    double value) {
384   return hash_double(self, value);
385 }
386 
387 // NOLINTNEXTLINE(clang-diagnostic-unused-function)
HashImpl(Py_hash_t (* hash_double)(double),PyObject * self,double value)388 Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self,
389                    double value) {
390   return hash_double(value);
391 }
392 
393 // Hash function for PyCustomFloat.
394 template <typename T>
PyCustomFloat_Hash(PyObject * self)395 Py_hash_t PyCustomFloat_Hash(PyObject* self) {
396   T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
397   return HashImpl(&_Py_HashDouble, self, static_cast<double>(x));
398 }
399 
// Python type for PyCustomFloat objects.
// Slots are listed positionally in PyTypeObject declaration order.
// - tp_name and tp_doc come from TypeDescriptor<T>::kTypeName / kTpDoc, which
//   are defined outside this view.
// - Py_TPFLAGS_BASETYPE allows Python subclasses.
// - Slots left null (e.g. tp_dealloc, tp_alloc, tp_free) are presumably filled
//   in or inherited when the registration code readies the type (not visible
//   here).
template <typename T>
PyTypeObject CustomFloatTypeDescriptor<T>::type = {
    PyVarObject_HEAD_INIT(nullptr, 0) TypeDescriptor<T>::kTypeName,  // tp_name
    sizeof(PyCustomFloat<T>),  // tp_basicsize
    0,                         // tp_itemsize
    nullptr,                   // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
    nullptr,  // tp_print
#else
    0,  // tp_vectorcall_offset
#endif
    nullptr,                                        // tp_getattr
    nullptr,                                        // tp_setattr
    nullptr,                                        // tp_as_async (formerly tp_compare / tp_reserved)
    PyCustomFloat_Repr<T>,                          // tp_repr
    &CustomFloatTypeDescriptor<T>::number_methods,  // tp_as_number
    nullptr,                                        // tp_as_sequence
    nullptr,                                        // tp_as_mapping
    PyCustomFloat_Hash<T>,                          // tp_hash
    nullptr,                                        // tp_call
    PyCustomFloat_Str<T>,                           // tp_str
    nullptr,                                        // tp_getattro
    nullptr,                                        // tp_setattro
    nullptr,                                        // tp_as_buffer
                                                    // tp_flags
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
    TypeDescriptor<T>::kTpDoc,     // tp_doc
    nullptr,                       // tp_traverse
    nullptr,                       // tp_clear
    PyCustomFloat_RichCompare<T>,  // tp_richcompare
    0,                             // tp_weaklistoffset
    nullptr,                       // tp_iter
    nullptr,                       // tp_iternext
    nullptr,                       // tp_methods
    nullptr,                       // tp_members
    nullptr,                       // tp_getset
    nullptr,                       // tp_base
    nullptr,                       // tp_dict
    nullptr,                       // tp_descr_get
    nullptr,                       // tp_descr_set
    0,                             // tp_dictoffset
    nullptr,                       // tp_init
    nullptr,                       // tp_alloc
    PyCustomFloat_New<T>,          // tp_new
    nullptr,                       // tp_free
    nullptr,                       // tp_is_gc
    nullptr,                       // tp_bases
    nullptr,                       // tp_mro
    nullptr,                       // tp_cache
    nullptr,                       // tp_subclasses
    nullptr,                       // tp_weaklist
    nullptr,                       // tp_del
    0,                             // tp_version_tag
};
455 
// Numpy support
// NumPy array-function table for T. It is defined here with static
// (zero) initialization; presumably the registration code fills in its
// entries (not visible here).
template <typename T>
PyArray_ArrFuncs CustomFloatTypeDescriptor<T>::arr_funcs;

// NumPy dtype descriptor for T.
// - kind / type char / byteorder come from TypeDescriptor<T> constants
//   defined outside this view.
// - NPY_NEEDS_PYAPI | NPY_USE_SETITEM marks the dtype as requiring the
//   Python API and using its setitem function.
// - type_num starts as 0, presumably to be replaced when the dtype is
//   registered — TODO confirm against the registration code.
template <typename T>
PyArray_Descr CustomFloatTypeDescriptor<T>::npy_descr = {
    PyObject_HEAD_INIT(nullptr)  //
                                 /*typeobj=*/
    (&TypeDescriptor<T>::type),
    /*kind=*/TypeDescriptor<T>::kNpyDescrKind,
    /*type=*/TypeDescriptor<T>::kNpyDescrType,
    /*byteorder=*/TypeDescriptor<T>::kNpyDescrByteorder,
    /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM,
    /*type_num=*/0,
    /*elsize=*/sizeof(T),
    /*alignment=*/alignof(T),
    /*subarray=*/nullptr,
    /*fields=*/nullptr,
    /*names=*/nullptr,
    /*f=*/&CustomFloatTypeDescriptor<T>::arr_funcs,
    /*metadata=*/nullptr,
    /*c_metadata=*/nullptr,
    /*hash=*/-1,  // -1 means "not computed yet".
};
480 
481 // Implementations of NumPy array methods.
482 
483 template <typename T>
NPyCustomFloat_GetItem(void * data,void * arr)484 PyObject* NPyCustomFloat_GetItem(void* data, void* arr) {
485   T x;
486   memcpy(&x, data, sizeof(T));
487   return PyFloat_FromDouble(static_cast<float>(x));
488 }
489 
490 template <typename T>
NPyCustomFloat_SetItem(PyObject * item,void * data,void * arr)491 int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) {
492   T x;
493   if (!CastToCustomFloat<T>(item, &x)) {
494     PyErr_Format(PyExc_TypeError, "expected number, got %s",
495                  Py_TYPE(item)->tp_name);
496     return -1;
497   }
498   memcpy(data, &x, sizeof(T));
499   return 0;
500 }
501 
// Exchanges the two bytes starting at `value`, in place (16-bit endianness
// flip). Bytes beyond the first two are untouched.
void ByteSwap16(void* value) {
  char* bytes = static_cast<char*>(value);
  const char first = bytes[0];
  bytes[0] = bytes[1];
  bytes[1] = first;
}
506 
// NumPy compare: three-way comparison of the T values at `a` and `b`, done in
// float. Returns -1, 0 or 1. NaNs compare greater than every non-NaN value
// (so they sort to the end), and two NaNs compare equal.
template <typename T>
int NPyCustomFloat_Compare(const void* a, const void* b, void* arr) {
  T lhs_raw;
  memcpy(&lhs_raw, a, sizeof(T));
  T rhs_raw;
  memcpy(&rhs_raw, b, sizeof(T));
  const float lhs = static_cast<float>(lhs_raw);
  const float rhs = static_cast<float>(rhs_raw);

  if (lhs < rhs) return -1;
  if (rhs < lhs) return 1;

  // Neither is strictly smaller: either equal, or at least one is NaN.
  const bool lhs_nan = std::isnan(lhs);
  const bool rhs_nan = std::isnan(rhs);
  if (lhs_nan != rhs_nan) {
    return lhs_nan ? 1 : -1;
  }
  return 0;
}
532 
533 template <typename T>
NPyCustomFloat_CopySwapN(void * dstv,npy_intp dstride,void * srcv,npy_intp sstride,npy_intp n,int swap,void * arr)534 void NPyCustomFloat_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
535                               npy_intp sstride, npy_intp n, int swap,
536                               void* arr) {
537   static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t),
538                 "Not supported");
539   char* dst = reinterpret_cast<char*>(dstv);
540   char* src = reinterpret_cast<char*>(srcv);
541   if (!src) {
542     return;
543   }
544   if (swap && sizeof(T) == sizeof(int16_t)) {
545     for (npy_intp i = 0; i < n; i++) {
546       char* r = dst + dstride * i;
547       memcpy(r, src + sstride * i, sizeof(T));
548       ByteSwap16(r);
549     }
550   } else if (dstride == sizeof(T) && sstride == sizeof(T)) {
551     memcpy(dst, src, n * sizeof(T));
552   } else {
553     for (npy_intp i = 0; i < n; i++) {
554       memcpy(dst + dstride * i, src + sstride * i, sizeof(T));
555     }
556   }
557 }
558 
559 template <typename T>
NPyCustomFloat_CopySwap(void * dst,void * src,int swap,void * arr)560 void NPyCustomFloat_CopySwap(void* dst, void* src, int swap, void* arr) {
561   if (!src) {
562     return;
563   }
564   memcpy(dst, src, sizeof(T));
565   static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t),
566                 "Not supported");
567   if (swap && sizeof(T) == sizeof(int16_t)) {
568     ByteSwap16(dst);
569   }
570 }
571 
572 template <typename T>
NPyCustomFloat_NonZero(void * data,void * arr)573 npy_bool NPyCustomFloat_NonZero(void* data, void* arr) {
574   T x;
575   memcpy(&x, data, sizeof(x));
576   return x != static_cast<T>(0);
577 }
578 
579 template <typename T>
NPyCustomFloat_Fill(void * buffer_raw,npy_intp length,void * ignored)580 int NPyCustomFloat_Fill(void* buffer_raw, npy_intp length, void* ignored) {
581   T* const buffer = reinterpret_cast<T*>(buffer_raw);
582   const float start(buffer[0]);
583   const float delta = static_cast<float>(buffer[1]) - start;
584   for (npy_intp i = 2; i < length; ++i) {
585     buffer[i] = static_cast<T>(start + i * delta);
586   }
587   return 0;
588 }
589 
590 template <typename T>
NPyCustomFloat_DotFunc(void * ip1,npy_intp is1,void * ip2,npy_intp is2,void * op,npy_intp n,void * arr)591 void NPyCustomFloat_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
592                             void* op, npy_intp n, void* arr) {
593   char* c1 = reinterpret_cast<char*>(ip1);
594   char* c2 = reinterpret_cast<char*>(ip2);
595   float acc = 0.0f;
596   for (npy_intp i = 0; i < n; ++i) {
597     T* const b1 = reinterpret_cast<T*>(c1);
598     T* const b2 = reinterpret_cast<T*>(c2);
599     acc += static_cast<float>(*b1) * static_cast<float>(*b2);
600     c1 += is1;
601     c2 += is2;
602   }
603   T* out = reinterpret_cast<T*>(op);
604   *out = static_cast<T>(acc);
605 }
606 
// Three-way comparison of the T values at `v1` and `v2` using T's own
// operators. Returns -1, 1 or 0. Unlike NPyCustomFloat_Compare, NaN is not
// ordered: any comparison involving NaN yields 0.
template <typename T>
int NPyCustomFloat_CompareFunc(const void* v1, const void* v2, void* arr) {
  const T lhs = *reinterpret_cast<const T*>(v1);
  const T rhs = *reinterpret_cast<const T*>(v2);
  if (lhs < rhs) return -1;
  return (lhs > rhs) ? 1 : 0;
}
619 
620 template <typename T>
NPyCustomFloat_ArgMaxFunc(void * data,npy_intp n,npy_intp * max_ind,void * arr)621 int NPyCustomFloat_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
622                               void* arr) {
623   const T* bdata = reinterpret_cast<const T*>(data);
624   // Start with a max_val of NaN, this results in the first iteration preferring
625   // bdata[0].
626   float max_val = std::numeric_limits<float>::quiet_NaN();
627   for (npy_intp i = 0; i < n; ++i) {
628     // This condition is chosen so that NaNs are always considered "max".
629     if (!(static_cast<float>(bdata[i]) <= max_val)) {
630       max_val = static_cast<float>(bdata[i]);
631       *max_ind = i;
632       // NumPy stops at the first NaN.
633       if (Eigen::numext::isnan(max_val)) {
634         break;
635       }
636     }
637   }
638   return 0;
639 }
640 
641 template <typename T>
NPyCustomFloat_ArgMinFunc(void * data,npy_intp n,npy_intp * min_ind,void * arr)642 int NPyCustomFloat_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
643                               void* arr) {
644   const T* bdata = reinterpret_cast<const T*>(data);
645   float min_val = std::numeric_limits<float>::quiet_NaN();
646   // Start with a min_val of NaN, this results in the first iteration preferring
647   // bdata[0].
648   for (npy_intp i = 0; i < n; ++i) {
649     // This condition is chosen so that NaNs are always considered "min".
650     if (!(static_cast<float>(bdata[i]) >= min_val)) {
651       min_val = static_cast<float>(bdata[i]);
652       *min_ind = i;
653       // NumPy stops at the first NaN.
654       if (Eigen::numext::isnan(min_val)) {
655         break;
656       }
657     }
658   }
659   return 0;
660 }
661 
// TypeDescriptor specializations for NumPy's built-in dtypes. Each maps a C++
// type to:
//   - `T`: the C++ type used to read/write elements of that dtype in memory;
//   - `Dtype()`: its NumPy type number.

// Unsigned integers.
template <>
struct TypeDescriptor<unsigned char> {
  typedef unsigned char T;
  static int Dtype() { return NPY_UBYTE; }
};

template <>
struct TypeDescriptor<unsigned short> {  // NOLINT
  typedef unsigned short T;              // NOLINT
  static int Dtype() { return NPY_USHORT; }
};

// We register "int", "long", and "long long" types for portability. They are
// distinct C++ types with distinct NumPy type numbers even where two of them
// have the same width:
//   - LP64 platforms (e.g. 64-bit Linux): "long" and "long long" are both
//     64 bits.
//   - LLP64 platforms (e.g. 64-bit Windows): "int" and "long" are both 32 bits.
template <>
struct TypeDescriptor<unsigned int> {
  typedef unsigned int T;
  static int Dtype() { return NPY_UINT; }
};

template <>
struct TypeDescriptor<unsigned long> {  // NOLINT
  typedef unsigned long T;              // NOLINT
  static int Dtype() { return NPY_ULONG; }
};

template <>
struct TypeDescriptor<unsigned long long> {  // NOLINT
  typedef unsigned long long T;              // NOLINT
  static int Dtype() { return NPY_ULONGLONG; }
};

// Signed integers.
template <>
struct TypeDescriptor<signed char> {
  typedef signed char T;
  static int Dtype() { return NPY_BYTE; }
};

template <>
struct TypeDescriptor<short> {  // NOLINT
  typedef short T;              // NOLINT
  static int Dtype() { return NPY_SHORT; }
};

template <>
struct TypeDescriptor<int> {
  typedef int T;
  static int Dtype() { return NPY_INT; }
};

template <>
struct TypeDescriptor<long> {  // NOLINT
  typedef long T;              // NOLINT
  static int Dtype() { return NPY_LONG; }
};

template <>
struct TypeDescriptor<long long> {  // NOLINT
  typedef long long T;              // NOLINT
  static int Dtype() { return NPY_LONGLONG; }
};

// NPY_BOOL elements are accessed as `unsigned char` (NumPy's npy_bool is an
// unsigned char), not as C++ `bool`.
template <>
struct TypeDescriptor<bool> {
  typedef unsigned char T;
  static int Dtype() { return NPY_BOOL; }
};

// Real floating-point types.
template <>
struct TypeDescriptor<Eigen::half> {
  typedef Eigen::half T;
  static int Dtype() { return NPY_HALF; }
};

template <>
struct TypeDescriptor<float> {
  typedef float T;
  static int Dtype() { return NPY_FLOAT; }
};

template <>
struct TypeDescriptor<double> {
  typedef double T;
  static int Dtype() { return NPY_DOUBLE; }
};

template <>
struct TypeDescriptor<long double> {
  typedef long double T;
  static int Dtype() { return NPY_LONGDOUBLE; }
};

// Complex types.
template <>
struct TypeDescriptor<std::complex<float>> {
  typedef std::complex<float> T;
  static int Dtype() { return NPY_CFLOAT; }
};

template <>
struct TypeDescriptor<std::complex<double>> {
  typedef std::complex<double> T;
  static int Dtype() { return NPY_CDOUBLE; }
};

template <>
struct TypeDescriptor<std::complex<long double>> {
  typedef std::complex<long double> T;
  static int Dtype() { return NPY_CLONGDOUBLE; }
};
772 
// Converts a real value to float with static_cast.
template <typename T>
float CastToFloat(T value) {
  const float narrowed = static_cast<float>(value);
  return narrowed;
}

// Complex overload (preferred for std::complex as the more specialized
// template): following the NumPy convention, the imaginary part is dropped
// and only the real part is converted.
template <typename T>
float CastToFloat(std::complex<T> value) {
  const T real_part = value.real();
  return CastToFloat(real_part);
}
782 
783 // Performs a NumPy array cast from type 'From' to 'To'.
784 template <typename From, typename To>
NPyCast(void * from_void,void * to_void,npy_intp n,void * fromarr,void * toarr)785 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
786              void* toarr) {
787   const auto* from =
788       reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
789   auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
790   for (npy_intp i = 0; i < n; ++i) {
791     to[i] = static_cast<typename TypeDescriptor<To>::T>(
792         static_cast<To>(CastToFloat(from[i])));
793   }
794 }
795 
796 // Registers a cast between T (a reduced float) and type 'OtherT'. 'numpy_type'
797 // is the NumPy type corresponding to 'OtherT'.
798 template <typename T, typename OtherT>
RegisterCustomFloatCast(int numpy_type=TypeDescriptor<OtherT>::Dtype ())799 bool RegisterCustomFloatCast(int numpy_type = TypeDescriptor<OtherT>::Dtype()) {
800   PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
801   if (PyArray_RegisterCastFunc(descr, TypeDescriptor<T>::Dtype(),
802                                NPyCast<OtherT, T>) < 0) {
803     return false;
804   }
805   if (PyArray_RegisterCastFunc(&CustomFloatTypeDescriptor<T>::npy_descr,
806                                numpy_type, NPyCast<T, OtherT>) < 0) {
807     return false;
808   }
809   return true;
810 }
811 
812 template <typename T>
RegisterCasts()813 bool RegisterCasts() {
814   if (!RegisterCustomFloatCast<T, Eigen::half>(NPY_HALF)) {
815     return false;
816   }
817 
818   if (!RegisterCustomFloatCast<T, float>(NPY_FLOAT)) {
819     return false;
820   }
821   if (!RegisterCustomFloatCast<T, double>(NPY_DOUBLE)) {
822     return false;
823   }
824   if (!RegisterCustomFloatCast<T, long double>(NPY_LONGDOUBLE)) {
825     return false;
826   }
827   if (!RegisterCustomFloatCast<T, bool>(NPY_BOOL)) {
828     return false;
829   }
830   if (!RegisterCustomFloatCast<T, unsigned char>(NPY_UBYTE)) {
831     return false;
832   }
833   if (!RegisterCustomFloatCast<T, unsigned short>(NPY_USHORT)) {  // NOLINT
834     return false;
835   }
836   if (!RegisterCustomFloatCast<T, unsigned int>(NPY_UINT)) {
837     return false;
838   }
839   if (!RegisterCustomFloatCast<T, unsigned long>(NPY_ULONG)) {  // NOLINT
840     return false;
841   }
842   if (!RegisterCustomFloatCast<T, unsigned long long>(  // NOLINT
843           NPY_ULONGLONG)) {
844     return false;
845   }
846   if (!RegisterCustomFloatCast<T, signed char>(NPY_BYTE)) {
847     return false;
848   }
849   if (!RegisterCustomFloatCast<T, short>(NPY_SHORT)) {  // NOLINT
850     return false;
851   }
852   if (!RegisterCustomFloatCast<T, int>(NPY_INT)) {
853     return false;
854   }
855   if (!RegisterCustomFloatCast<T, long>(NPY_LONG)) {  // NOLINT
856     return false;
857   }
858   if (!RegisterCustomFloatCast<T, long long>(NPY_LONGLONG)) {  // NOLINT
859     return false;
860   }
861   // Following the numpy convention. imag part is dropped when converting to
862   // float.
863   if (!RegisterCustomFloatCast<T, std::complex<float>>(NPY_CFLOAT)) {
864     return false;
865   }
866   if (!RegisterCustomFloatCast<T, std::complex<double>>(NPY_CDOUBLE)) {
867     return false;
868   }
869   if (!RegisterCustomFloatCast<T, std::complex<long double>>(NPY_CLONGDOUBLE)) {
870     return false;
871   }
872 
873   // Safe casts from T to other types
874   if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
875                               NPY_NOSCALAR) < 0) {
876     return false;
877   }
878   if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_DOUBLE,
879                               NPY_NOSCALAR) < 0) {
880     return false;
881   }
882   if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_LONGDOUBLE,
883                               NPY_NOSCALAR) < 0) {
884     return false;
885   }
886   if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CFLOAT,
887                               NPY_NOSCALAR) < 0) {
888     return false;
889   }
890   if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CDOUBLE,
891                               NPY_NOSCALAR) < 0) {
892     return false;
893   }
894   if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CLONGDOUBLE,
895                               NPY_NOSCALAR) < 0) {
896     return false;
897   }
898 
899   // Safe casts to T from other types
900   if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL),
901                               TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
902     return false;
903   }
904   if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE),
905                               TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
906     return false;
907   }
908   if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE),
909                               TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
910     return false;
911   }
912 
913   return true;
914 }
915 
916 template <typename InType, typename OutType, typename Functor>
917 struct UnaryUFunc {
Typestensorflow::__anon7a1ed7ad0111::UnaryUFunc918   static std::vector<int> Types() {
919     return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype()};
920   }
Calltensorflow::__anon7a1ed7ad0111::UnaryUFunc921   static void Call(char** args, const npy_intp* dimensions,
922                    const npy_intp* steps, void* data) {
923     const char* i0 = args[0];
924     char* o = args[1];
925     for (npy_intp k = 0; k < *dimensions; k++) {
926       auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
927       *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) = Functor()(x);
928       i0 += steps[0];
929       o += steps[1];
930     }
931   }
932 };
933 
934 template <typename InType, typename OutType, typename OutType2,
935           typename Functor>
936 struct UnaryUFunc2 {
Typestensorflow::__anon7a1ed7ad0111::UnaryUFunc2937   static std::vector<int> Types() {
938     return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype(),
939             TypeDescriptor<OutType2>::Dtype()};
940   }
Calltensorflow::__anon7a1ed7ad0111::UnaryUFunc2941   static void Call(char** args, const npy_intp* dimensions,
942                    const npy_intp* steps, void* data) {
943     const char* i0 = args[0];
944     char* o0 = args[1];
945     char* o1 = args[2];
946     for (npy_intp k = 0; k < *dimensions; k++) {
947       auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
948       std::tie(*reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o0),
949                *reinterpret_cast<typename TypeDescriptor<OutType2>::T*>(o1)) =
950           Functor()(x);
951       i0 += steps[0];
952       o0 += steps[1];
953       o1 += steps[2];
954     }
955   }
956 };
957 
958 template <typename InType, typename OutType, typename Functor>
959 struct BinaryUFunc {
Typestensorflow::__anon7a1ed7ad0111::BinaryUFunc960   static std::vector<int> Types() {
961     return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType>::Dtype(),
962             TypeDescriptor<OutType>::Dtype()};
963   }
Calltensorflow::__anon7a1ed7ad0111::BinaryUFunc964   static void Call(char** args, const npy_intp* dimensions,
965                    const npy_intp* steps, void* data) {
966     const char* i0 = args[0];
967     const char* i1 = args[1];
968     char* o = args[2];
969     for (npy_intp k = 0; k < *dimensions; k++) {
970       auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
971       auto y = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i1);
972       *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
973           Functor()(x, y);
974       i0 += steps[0];
975       i1 += steps[1];
976       o += steps[2];
977     }
978   }
979 };
980 
981 template <typename InType, typename InType2, typename OutType, typename Functor>
982 struct BinaryUFunc2 {
Typestensorflow::__anon7a1ed7ad0111::BinaryUFunc2983   static std::vector<int> Types() {
984     return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType2>::Dtype(),
985             TypeDescriptor<OutType>::Dtype()};
986   }
Calltensorflow::__anon7a1ed7ad0111::BinaryUFunc2987   static void Call(char** args, const npy_intp* dimensions,
988                    const npy_intp* steps, void* data) {
989     const char* i0 = args[0];
990     const char* i1 = args[1];
991     char* o = args[2];
992     for (npy_intp k = 0; k < *dimensions; k++) {
993       auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
994       auto y =
995           *reinterpret_cast<const typename TypeDescriptor<InType2>::T*>(i1);
996       *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
997           Functor()(x, y);
998       i0 += steps[0];
999       i1 += steps[1];
1000       o += steps[2];
1001     }
1002   }
1003 };
1004 
1005 template <typename UFunc, typename CustomFloatT>
RegisterUFunc(PyObject * numpy,const char * name)1006 bool RegisterUFunc(PyObject* numpy, const char* name) {
1007   std::vector<int> types = UFunc::Types();
1008   PyUFuncGenericFunction fn =
1009       reinterpret_cast<PyUFuncGenericFunction>(UFunc::Call);
1010   Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name));
1011   if (!ufunc_obj) {
1012     return false;
1013   }
1014   PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
1015   if (static_cast<int>(types.size()) != ufunc->nargs) {
1016     PyErr_Format(PyExc_AssertionError,
1017                  "ufunc %s takes %d arguments, loop takes %lu", name,
1018                  ufunc->nargs, types.size());
1019     return false;
1020   }
1021   if (PyUFunc_RegisterLoopForType(ufunc, TypeDescriptor<CustomFloatT>::Dtype(),
1022                                   fn, const_cast<int*>(types.data()),
1023                                   nullptr) < 0) {
1024     return false;
1025   }
1026   return true;
1027 }
1028 
1029 namespace ufuncs {
1030 
1031 template <typename T>
1032 struct Add {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Add1033   T operator()(T a, T b) { return a + b; }
1034 };
1035 template <typename T>
1036 struct Subtract {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Subtract1037   T operator()(T a, T b) { return a - b; }
1038 };
1039 template <typename T>
1040 struct Multiply {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Multiply1041   T operator()(T a, T b) { return a * b; }
1042 };
1043 template <typename T>
1044 struct TrueDivide {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::TrueDivide1045   T operator()(T a, T b) { return a / b; }
1046 };
1047 
divmod(float a,float b)1048 inline std::pair<float, float> divmod(float a, float b) {
1049   if (b == 0.0f) {
1050     float nan = std::numeric_limits<float>::quiet_NaN();
1051     return {nan, nan};
1052   }
1053   float mod = std::fmod(a, b);
1054   float div = (a - mod) / b;
1055   if (mod != 0.0f) {
1056     if ((b < 0.0f) != (mod < 0.0f)) {
1057       mod += b;
1058       div -= 1.0f;
1059     }
1060   } else {
1061     mod = std::copysign(0.0f, b);
1062   }
1063 
1064   float floordiv;
1065   if (div != 0.0f) {
1066     floordiv = std::floor(div);
1067     if (div - floordiv > 0.5f) {
1068       floordiv += 1.0f;
1069     }
1070   } else {
1071     floordiv = std::copysign(0.0f, a / b);
1072   }
1073   return {floordiv, mod};
1074 }
1075 
1076 template <typename T>
1077 struct FloorDivide {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::FloorDivide1078   T operator()(T a, T b) {
1079     return T(divmod(static_cast<float>(a), static_cast<float>(b)).first);
1080   }
1081 };
1082 template <typename T>
1083 struct Remainder {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Remainder1084   T operator()(T a, T b) {
1085     return T(divmod(static_cast<float>(a), static_cast<float>(b)).second);
1086   }
1087 };
1088 template <typename T>
1089 struct DivmodUFunc {
Typestensorflow::__anon7a1ed7ad0111::ufuncs::DivmodUFunc1090   static std::vector<int> Types() {
1091     return {TypeDescriptor<T>::Dtype(), TypeDescriptor<T>::Dtype(),
1092             TypeDescriptor<T>::Dtype(), TypeDescriptor<T>::Dtype()};
1093   }
Calltensorflow::__anon7a1ed7ad0111::ufuncs::DivmodUFunc1094   static void Call(char** args, npy_intp* dimensions, npy_intp* steps,
1095                    void* data) {
1096     const char* i0 = args[0];
1097     const char* i1 = args[1];
1098     char* o0 = args[2];
1099     char* o1 = args[3];
1100     for (npy_intp k = 0; k < *dimensions; k++) {
1101       T x = *reinterpret_cast<const T*>(i0);
1102       T y = *reinterpret_cast<const T*>(i1);
1103       float floordiv, mod;
1104       std::tie(floordiv, mod) =
1105           divmod(static_cast<float>(x), static_cast<float>(y));
1106       *reinterpret_cast<T*>(o0) = T(floordiv);
1107       *reinterpret_cast<T*>(o1) = T(mod);
1108       i0 += steps[0];
1109       i1 += steps[1];
1110       o0 += steps[2];
1111       o1 += steps[3];
1112     }
1113   }
1114 };
1115 template <typename T>
1116 struct Fmod {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Fmod1117   T operator()(T a, T b) {
1118     return T(std::fmod(static_cast<float>(a), static_cast<float>(b)));
1119   }
1120 };
1121 template <typename T>
1122 struct Negative {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Negative1123   T operator()(T a) { return -a; }
1124 };
1125 template <typename T>
1126 struct Positive {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Positive1127   T operator()(T a) { return a; }
1128 };
1129 template <typename T>
1130 struct Power {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Power1131   T operator()(T a, T b) {
1132     return T(std::pow(static_cast<float>(a), static_cast<float>(b)));
1133   }
1134 };
1135 template <typename T>
1136 struct Abs {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Abs1137   T operator()(T a) { return T(std::abs(static_cast<float>(a))); }
1138 };
1139 template <typename T>
1140 struct Cbrt {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Cbrt1141   T operator()(T a) { return T(std::cbrt(static_cast<float>(a))); }
1142 };
1143 template <typename T>
1144 struct Ceil {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Ceil1145   T operator()(T a) { return T(std::ceil(static_cast<float>(a))); }
1146 };
1147 template <typename T>
1148 struct CopySign;
1149 
1150 template <typename T>
1151 struct Exp {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Exp1152   T operator()(T a) { return T(std::exp(static_cast<float>(a))); }
1153 };
1154 template <typename T>
1155 struct Exp2 {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Exp21156   T operator()(T a) { return T(std::exp2(static_cast<float>(a))); }
1157 };
1158 template <typename T>
1159 struct Expm1 {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Expm11160   T operator()(T a) { return T(std::expm1(static_cast<float>(a))); }
1161 };
1162 template <typename T>
1163 struct Floor {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Floor1164   T operator()(T a) { return T(std::floor(static_cast<float>(a))); }
1165 };
1166 template <typename T>
1167 struct Frexp {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Frexp1168   std::pair<T, int> operator()(T a) {
1169     int exp;
1170     float f = std::frexp(static_cast<float>(a), &exp);
1171     return {T(f), exp};
1172   }
1173 };
1174 template <typename T>
1175 struct Heaviside {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Heaviside1176   T operator()(T bx, T h0) {
1177     float x = static_cast<float>(bx);
1178     if (Eigen::numext::isnan(x)) {
1179       return bx;
1180     }
1181     if (x < 0) {
1182       return T(0.0f);
1183     }
1184     if (x > 0) {
1185       return T(1.0f);
1186     }
1187     return h0;  // x == 0
1188   }
1189 };
1190 template <typename T>
1191 struct Conjugate {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Conjugate1192   T operator()(T a) { return a; }
1193 };
1194 template <typename T>
1195 struct IsFinite {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::IsFinite1196   bool operator()(T a) { return std::isfinite(static_cast<float>(a)); }
1197 };
1198 template <typename T>
1199 struct IsInf {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::IsInf1200   bool operator()(T a) { return std::isinf(static_cast<float>(a)); }
1201 };
1202 template <typename T>
1203 struct IsNan {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::IsNan1204   bool operator()(T a) { return Eigen::numext::isnan(static_cast<float>(a)); }
1205 };
1206 template <typename T>
1207 struct Ldexp {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Ldexp1208   T operator()(T a, int exp) {
1209     return T(std::ldexp(static_cast<float>(a), exp));
1210   }
1211 };
1212 template <typename T>
1213 struct Log {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Log1214   T operator()(T a) { return T(std::log(static_cast<float>(a))); }
1215 };
1216 template <typename T>
1217 struct Log2 {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Log21218   T operator()(T a) { return T(std::log2(static_cast<float>(a))); }
1219 };
1220 template <typename T>
1221 struct Log10 {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Log101222   T operator()(T a) { return T(std::log10(static_cast<float>(a))); }
1223 };
1224 template <typename T>
1225 struct Log1p {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Log1p1226   T operator()(T a) { return T(std::log1p(static_cast<float>(a))); }
1227 };
1228 template <typename T>
1229 struct LogAddExp {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogAddExp1230   T operator()(T bx, T by) {
1231     float x = static_cast<float>(bx);
1232     float y = static_cast<float>(by);
1233     if (x == y) {
1234       // Handles infinities of the same sign.
1235       return T(x + std::log(2.0f));
1236     }
1237     float out = std::numeric_limits<float>::quiet_NaN();
1238     if (x > y) {
1239       out = x + std::log1p(std::exp(y - x));
1240     } else if (x < y) {
1241       out = y + std::log1p(std::exp(x - y));
1242     }
1243     return T(out);
1244   }
1245 };
1246 template <typename T>
1247 struct LogAddExp2 {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogAddExp21248   T operator()(T bx, T by) {
1249     float x = static_cast<float>(bx);
1250     float y = static_cast<float>(by);
1251     if (x == y) {
1252       // Handles infinities of the same sign.
1253       return T(x + 1.0f);
1254     }
1255     float out = std::numeric_limits<float>::quiet_NaN();
1256     if (x > y) {
1257       out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f);
1258     } else if (x < y) {
1259       out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
1260     }
1261     return T(out);
1262   }
1263 };
1264 template <typename T>
1265 struct Modf {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Modf1266   std::pair<T, T> operator()(T a) {
1267     float integral;
1268     float f = std::modf(static_cast<float>(a), &integral);
1269     return {T(f), T(integral)};
1270   }
1271 };
1272 
1273 template <typename T>
1274 struct Reciprocal {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Reciprocal1275   T operator()(T a) { return T(1.f / static_cast<float>(a)); }
1276 };
1277 template <typename T>
1278 struct Rint {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Rint1279   T operator()(T a) { return T(std::rint(static_cast<float>(a))); }
1280 };
1281 template <typename T>
1282 struct Sign {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Sign1283   T operator()(T a) {
1284     float f(a);
1285     if (f < 0) {
1286       return T(-1);
1287     }
1288     if (f > 0) {
1289       return T(1);
1290     }
1291     return a;
1292   }
1293 };
1294 template <typename T>
1295 struct SignBit {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::SignBit1296   bool operator()(T a) { return std::signbit(static_cast<float>(a)); }
1297 };
1298 template <typename T>
1299 struct Sqrt {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Sqrt1300   T operator()(T a) { return T(std::sqrt(static_cast<float>(a))); }
1301 };
1302 template <typename T>
1303 struct Square {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Square1304   T operator()(T a) {
1305     float f(a);
1306     return T(f * f);
1307   }
1308 };
1309 template <typename T>
1310 struct Trunc {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Trunc1311   T operator()(T a) { return T(std::trunc(static_cast<float>(a))); }
1312 };
1313 
1314 // Trigonometric functions
1315 template <typename T>
1316 struct Sin {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Sin1317   T operator()(T a) { return T(std::sin(static_cast<float>(a))); }
1318 };
1319 template <typename T>
1320 struct Cos {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Cos1321   T operator()(T a) { return T(std::cos(static_cast<float>(a))); }
1322 };
1323 template <typename T>
1324 struct Tan {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Tan1325   T operator()(T a) { return T(std::tan(static_cast<float>(a))); }
1326 };
1327 template <typename T>
1328 struct Arcsin {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Arcsin1329   T operator()(T a) { return T(std::asin(static_cast<float>(a))); }
1330 };
1331 template <typename T>
1332 struct Arccos {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Arccos1333   T operator()(T a) { return T(std::acos(static_cast<float>(a))); }
1334 };
1335 template <typename T>
1336 struct Arctan {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Arctan1337   T operator()(T a) { return T(std::atan(static_cast<float>(a))); }
1338 };
1339 template <typename T>
1340 struct Arctan2 {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Arctan21341   T operator()(T a, T b) {
1342     return T(std::atan2(static_cast<float>(a), static_cast<float>(b)));
1343   }
1344 };
1345 template <typename T>
1346 struct Hypot {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Hypot1347   T operator()(T a, T b) {
1348     return T(std::hypot(static_cast<float>(a), static_cast<float>(b)));
1349   }
1350 };
1351 template <typename T>
1352 struct Sinh {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Sinh1353   T operator()(T a) { return T(std::sinh(static_cast<float>(a))); }
1354 };
1355 template <typename T>
1356 struct Cosh {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Cosh1357   T operator()(T a) { return T(std::cosh(static_cast<float>(a))); }
1358 };
1359 template <typename T>
1360 struct Tanh {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Tanh1361   T operator()(T a) { return T(std::tanh(static_cast<float>(a))); }
1362 };
1363 template <typename T>
1364 struct Arcsinh {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Arcsinh1365   T operator()(T a) { return T(std::asinh(static_cast<float>(a))); }
1366 };
1367 template <typename T>
1368 struct Arccosh {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Arccosh1369   T operator()(T a) { return T(std::acosh(static_cast<float>(a))); }
1370 };
1371 template <typename T>
1372 struct Arctanh {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Arctanh1373   T operator()(T a) { return T(std::atanh(static_cast<float>(a))); }
1374 };
1375 template <typename T>
1376 struct Deg2rad {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Deg2rad1377   T operator()(T a) {
1378     static constexpr float radians_per_degree = M_PI / 180.0f;
1379     return T(static_cast<float>(a) * radians_per_degree);
1380   }
1381 };
1382 template <typename T>
1383 struct Rad2deg {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Rad2deg1384   T operator()(T a) {
1385     static constexpr float degrees_per_radian = 180.0f / M_PI;
1386     return T(static_cast<float>(a) * degrees_per_radian);
1387   }
1388 };
1389 
1390 template <typename T>
1391 struct Eq {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Eq1392   npy_bool operator()(T a, T b) { return a == b; }
1393 };
1394 template <typename T>
1395 struct Ne {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Ne1396   npy_bool operator()(T a, T b) { return a != b; }
1397 };
1398 template <typename T>
1399 struct Lt {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Lt1400   npy_bool operator()(T a, T b) { return a < b; }
1401 };
1402 template <typename T>
1403 struct Gt {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Gt1404   npy_bool operator()(T a, T b) { return a > b; }
1405 };
1406 template <typename T>
1407 struct Le {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Le1408   npy_bool operator()(T a, T b) { return a <= b; }
1409 };
1410 template <typename T>
1411 struct Ge {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Ge1412   npy_bool operator()(T a, T b) { return a >= b; }
1413 };
1414 template <typename T>
1415 struct Maximum {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Maximum1416   T operator()(T a, T b) {
1417     float fa(a), fb(b);
1418     return Eigen::numext::isnan(fa) || fa > fb ? a : b;
1419   }
1420 };
1421 template <typename T>
1422 struct Minimum {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Minimum1423   T operator()(T a, T b) {
1424     float fa(a), fb(b);
1425     return Eigen::numext::isnan(fa) || fa < fb ? a : b;
1426   }
1427 };
1428 template <typename T>
1429 struct Fmax {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Fmax1430   T operator()(T a, T b) {
1431     float fa(a), fb(b);
1432     return Eigen::numext::isnan(fb) || fa > fb ? a : b;
1433   }
1434 };
1435 template <typename T>
1436 struct Fmin {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Fmin1437   T operator()(T a, T b) {
1438     float fa(a), fb(b);
1439     return Eigen::numext::isnan(fb) || fa < fb ? a : b;
1440   }
1441 };
1442 
1443 template <typename T>
1444 struct LogicalNot {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogicalNot1445   npy_bool operator()(T a) { return !a; }
1446 };
1447 template <typename T>
1448 struct LogicalAnd {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogicalAnd1449   npy_bool operator()(T a, T b) { return a && b; }
1450 };
1451 template <typename T>
1452 struct LogicalOr {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogicalOr1453   npy_bool operator()(T a, T b) { return a || b; }
1454 };
1455 template <typename T>
1456 struct LogicalXor {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::LogicalXor1457   npy_bool operator()(T a, T b) {
1458     return static_cast<bool>(a) ^ static_cast<bool>(b);
1459   }
1460 };
1461 
1462 template <typename T>
1463 struct NextAfter;
1464 
1465 template <typename T>
1466 struct Spacing {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::Spacing1467   T operator()(T x) {
1468     // Compute the distance between the input and the next number with greater
1469     // magnitude. The result should have the sign of the input.
1470     T away(std::copysign(std::numeric_limits<float>::infinity(),
1471                          static_cast<float>(x)));
1472     return NextAfter<T>()(x, away) - x;
1473   }
1474 };
1475 
1476 template <typename T>
RegisterUFuncs(PyObject * numpy)1477 bool RegisterUFuncs(PyObject* numpy) {
1478   bool ok =
1479       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
1480       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(numpy,
1481                                                                "subtract") &&
1482       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(numpy,
1483                                                                "multiply") &&
1484       RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(numpy,
1485                                                                  "divide") &&
1486       RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp<T>>, T>(numpy,
1487                                                                 "logaddexp") &&
1488       RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp2<T>>, T>(
1489           numpy, "logaddexp2") &&
1490       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Negative<T>>, T>(numpy,
1491                                                               "negative") &&
1492       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Positive<T>>, T>(numpy,
1493                                                               "positive") &&
1494       RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(
1495           numpy, "true_divide") &&
1496       RegisterUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(
1497           numpy, "floor_divide") &&
1498       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Power<T>>, T>(numpy, "power") &&
1499       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy,
1500                                                                 "remainder") &&
1501       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "mod") &&
1502       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmod<T>>, T>(numpy, "fmod") &&
1503       RegisterUFunc<ufuncs::DivmodUFunc<T>, T>(numpy, "divmod") &&
1504       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "absolute") &&
1505       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "fabs") &&
1506       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rint<T>>, T>(numpy, "rint") &&
1507       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sign<T>>, T>(numpy, "sign") &&
1508       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Heaviside<T>>, T>(numpy,
1509                                                                 "heaviside") &&
1510       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Conjugate<T>>, T>(numpy,
1511                                                                "conjugate") &&
1512       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp<T>>, T>(numpy, "exp") &&
1513       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp2<T>>, T>(numpy, "exp2") &&
1514       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Expm1<T>>, T>(numpy, "expm1") &&
1515       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log<T>>, T>(numpy, "log") &&
1516       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log2<T>>, T>(numpy, "log2") &&
1517       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log10<T>>, T>(numpy, "log10") &&
1518       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log1p<T>>, T>(numpy, "log1p") &&
1519       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sqrt<T>>, T>(numpy, "sqrt") &&
1520       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Square<T>>, T>(numpy, "square") &&
1521       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cbrt<T>>, T>(numpy, "cbrt") &&
1522       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Reciprocal<T>>, T>(numpy,
1523                                                                 "reciprocal") &&
1524 
1525       // Trigonometric functions
1526       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sin<T>>, T>(numpy, "sin") &&
1527       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cos<T>>, T>(numpy, "cos") &&
1528       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tan<T>>, T>(numpy, "tan") &&
1529       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsin<T>>, T>(numpy, "arcsin") &&
1530       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccos<T>>, T>(numpy, "arccos") &&
1531       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctan<T>>, T>(numpy, "arctan") &&
1532       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Arctan2<T>>, T>(numpy,
1533                                                               "arctan2") &&
1534       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Hypot<T>>, T>(numpy, "hypot") &&
1535       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sinh<T>>, T>(numpy, "sinh") &&
1536       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cosh<T>>, T>(numpy, "cosh") &&
1537       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tanh<T>>, T>(numpy, "tanh") &&
1538       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsinh<T>>, T>(numpy,
1539                                                              "arcsinh") &&
1540       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccosh<T>>, T>(numpy,
1541                                                              "arccosh") &&
1542       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctanh<T>>, T>(numpy,
1543                                                              "arctanh") &&
1544       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Deg2rad<T>>, T>(numpy,
1545                                                              "deg2rad") &&
1546       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rad2deg<T>>, T>(numpy,
1547                                                              "rad2deg") &&
1548 
1549       // Comparison functions
1550       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Eq<T>>, T>(numpy, "equal") &&
1551       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ne<T>>, T>(numpy,
1552                                                             "not_equal") &&
1553       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Lt<T>>, T>(numpy, "less") &&
1554       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Gt<T>>, T>(numpy, "greater") &&
1555       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Le<T>>, T>(numpy,
1556                                                             "less_equal") &&
1557       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ge<T>>, T>(numpy,
1558                                                             "greater_equal") &&
1559       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Maximum<T>>, T>(numpy,
1560                                                               "maximum") &&
1561       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Minimum<T>>, T>(numpy,
1562                                                               "minimum") &&
1563       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmax<T>>, T>(numpy, "fmax") &&
1564       RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmin<T>>, T>(numpy, "fmin") &&
1565       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalAnd<T>>, T>(
1566           numpy, "logical_and") &&
1567       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalOr<T>>, T>(
1568           numpy, "logical_or") &&
1569       RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalXor<T>>, T>(
1570           numpy, "logical_xor") &&
1571       RegisterUFunc<UnaryUFunc<T, bool, ufuncs::LogicalNot<T>>, T>(
1572           numpy, "logical_not") &&
1573 
1574       // Floating point functions
1575       RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsFinite<T>>, T>(numpy,
1576                                                                  "isfinite") &&
1577       RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsInf<T>>, T>(numpy, "isinf") &&
1578       RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsNan<T>>, T>(numpy, "isnan") &&
1579       RegisterUFunc<UnaryUFunc<T, bool, ufuncs::SignBit<T>>, T>(numpy,
1580                                                                 "signbit") &&
1581       RegisterUFunc<BinaryUFunc<T, T, ufuncs::CopySign<T>>, T>(numpy,
1582                                                                "copysign") &&
1583       RegisterUFunc<UnaryUFunc2<T, T, T, ufuncs::Modf<T>>, T>(numpy, "modf") &&
1584       RegisterUFunc<BinaryUFunc2<T, int, T, ufuncs::Ldexp<T>>, T>(numpy,
1585                                                                   "ldexp") &&
1586       RegisterUFunc<UnaryUFunc2<T, T, int, ufuncs::Frexp<T>>, T>(numpy,
1587                                                                  "frexp") &&
1588       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Floor<T>>, T>(numpy, "floor") &&
1589       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Ceil<T>>, T>(numpy, "ceil") &&
1590       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Trunc<T>>, T>(numpy, "trunc") &&
1591       RegisterUFunc<BinaryUFunc<T, T, ufuncs::NextAfter<T>>, T>(numpy,
1592                                                                 "nextafter") &&
1593       RegisterUFunc<UnaryUFunc<T, T, ufuncs::Spacing<T>>, T>(numpy, "spacing");
1594 
1595   return ok;
1596 }
1597 
1598 }  // namespace ufuncs
1599 
1600 template <typename T>
RegisterNumpyDtype(PyObject * numpy)1601 bool RegisterNumpyDtype(PyObject* numpy) {
1602   // If another module (presumably either TF or JAX) has registered a bfloat16
1603   // type, use it. We don't want two bfloat16 types if we can avoid it since it
1604   // leads to confusion if we have two different types with the same name. This
1605   // assumes that the other module has a sufficiently complete bfloat16
1606   // implementation. The only known NumPy bfloat16 extension at the time of
1607   // writing is this one (distributed in TF and JAX).
1608   // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
1609   // so we can unambiguously refer to a single canonical definition of bfloat16.
1610   int typenum =
1611       PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
1612   if (typenum != NPY_NOTYPE) {
1613     PyArray_Descr* descr = PyArray_DescrFromType(typenum);
1614     // The test for an argmax function here is to verify that the
1615     // bfloat16 implementation is sufficiently new, and, say, not from
1616     // an older version of TF or JAX.
1617     if (descr && descr->f && descr->f->argmax) {
1618       TypeDescriptor<T>::npy_type = typenum;
1619       TypeDescriptor<T>::type_ptr = descr->typeobj;
1620       return true;
1621     }
1622   }
1623 
1624   TypeDescriptor<T>::type.tp_base = &PyGenericArrType_Type;
1625 
1626   if (PyType_Ready(&TypeDescriptor<T>::type) < 0) {
1627     return false;
1628   }
1629 
1630   // Initializes the NumPy descriptor.
1631   PyArray_ArrFuncs& arr_funcs = CustomFloatTypeDescriptor<T>::arr_funcs;
1632   PyArray_InitArrFuncs(&arr_funcs);
1633   arr_funcs.getitem = NPyCustomFloat_GetItem<T>;
1634   arr_funcs.setitem = NPyCustomFloat_SetItem<T>;
1635   arr_funcs.compare = NPyCustomFloat_Compare<T>;
1636   arr_funcs.copyswapn = NPyCustomFloat_CopySwapN<T>;
1637   arr_funcs.copyswap = NPyCustomFloat_CopySwap<T>;
1638   arr_funcs.nonzero = NPyCustomFloat_NonZero<T>;
1639   arr_funcs.fill = NPyCustomFloat_Fill<T>;
1640   arr_funcs.dotfunc = NPyCustomFloat_DotFunc<T>;
1641   arr_funcs.compare = NPyCustomFloat_CompareFunc<T>;
1642   arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc<T>;
1643   arr_funcs.argmin = NPyCustomFloat_ArgMinFunc<T>;
1644 
1645 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
1646   Py_TYPE(&CustomFloatTypeDescriptor<T>::npy_descr) = &PyArrayDescr_Type;
1647 #else
1648   Py_SET_TYPE(&CustomFloatTypeDescriptor<T>::npy_descr, &PyArrayDescr_Type);
1649 #endif
1650   TypeDescriptor<T>::npy_type =
1651       PyArray_RegisterDataType(&CustomFloatTypeDescriptor<T>::npy_descr);
1652   TypeDescriptor<T>::type_ptr = &TypeDescriptor<T>::type;
1653   if (TypeDescriptor<T>::Dtype() < 0) {
1654     return false;
1655   }
1656 
1657   Safe_PyObjectPtr typeDict_obj =
1658       make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
1659   if (!typeDict_obj) return false;
1660   // Add the type object to `numpy.typeDict`: that makes
1661   // `numpy.dtype(type_name)` work.
1662   if (PyDict_SetItemString(
1663           typeDict_obj.get(), TypeDescriptor<T>::kTypeName,
1664           reinterpret_cast<PyObject*>(&TypeDescriptor<T>::type)) < 0) {
1665     return false;
1666   }
1667 
1668   // Support dtype(type_name)
1669   if (PyDict_SetItemString(TypeDescriptor<T>::type.tp_dict, "dtype",
1670                            reinterpret_cast<PyObject*>(
1671                                &CustomFloatTypeDescriptor<T>::npy_descr)) < 0) {
1672     return false;
1673   }
1674 
1675   return RegisterCasts<T>() && ufuncs::RegisterUFuncs<T>(numpy);
1676 }
1677 
1678 namespace ufuncs {
1679 
1680 template <>
1681 struct CopySign<bfloat16> {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::CopySign1682   bfloat16 operator()(bfloat16 a, bfloat16 b) {
1683     // LLVM is smart enough to turn this into (a & 0x7fff) | (b & 0x8000).
1684     bfloat16 abs_a = Eigen::numext::abs(a);
1685     return std::signbit(static_cast<float>(b)) ? -abs_a : abs_a;
1686   }
1687 };
1688 
1689 template <>
1690 struct NextAfter<bfloat16> {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::NextAfter1691   bfloat16 operator()(bfloat16 from, bfloat16 to) {
1692     uint16_t from_as_int, to_as_int;
1693     const uint16_t sign_mask = 1 << 15;
1694     float from_as_float(from), to_as_float(to);
1695     memcpy(&from_as_int, &from, sizeof(bfloat16));
1696     memcpy(&to_as_int, &to, sizeof(bfloat16));
1697     if (Eigen::numext::isnan(from_as_float) ||
1698         Eigen::numext::isnan(to_as_float)) {
1699       return bfloat16(std::numeric_limits<float>::quiet_NaN());
1700     }
1701     if (from_as_int == to_as_int) {
1702       return to;
1703     }
1704     if (from_as_float == 0) {
1705       if (to_as_float == 0) {
1706         return to;
1707       } else {
1708         // Smallest subnormal signed like `to`.
1709         uint16_t out_int = (to_as_int & sign_mask) | 1;
1710         bfloat16 out;
1711         memcpy(&out, &out_int, sizeof(bfloat16));
1712         return out;
1713       }
1714     }
1715     uint16_t from_sign = from_as_int & sign_mask;
1716     uint16_t to_sign = to_as_int & sign_mask;
1717     uint16_t from_abs = from_as_int & ~sign_mask;
1718     uint16_t to_abs = to_as_int & ~sign_mask;
1719     uint16_t magnitude_adjustment =
1720         (from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
1721     uint16_t out_int = from_as_int + magnitude_adjustment;
1722     bfloat16 out;
1723     memcpy(&out, &out_int, sizeof(bfloat16));
1724     return out;
1725   }
1726 };
1727 
1728 }  // namespace ufuncs
1729 
1730 using bfloat16 = Eigen::bfloat16;
1731 
1732 template <>
1733 struct TypeDescriptor<bfloat16> : CustomFloatTypeDescriptor<bfloat16> {
1734   typedef bfloat16 T;
1735   static constexpr const char* kTypeName = "bfloat16";
1736   static constexpr const char* kTpDoc = "bfloat16 floating-point values";
1737   // We must register bfloat16 with a kind other than "f", because numpy
1738   // considers two types with the same kind and size to be equal, but
1739   // float16 != bfloat16.
1740   // The downside of this is that NumPy scalar promotion does not work with
1741   // bfloat16 values.
1742   static constexpr char kNpyDescrKind = 'V';
1743   // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
1744   // character is unique.
1745   static constexpr char kNpyDescrType = 'E';
1746   static constexpr char kNpyDescrByteorder = '=';
1747 };
1748 
1749 template <>
1750 struct TypeDescriptor<float8_e4m3b11>
1751     : CustomFloatTypeDescriptor<float8_e4m3b11> {
1752   typedef float8_e4m3b11 T;
1753   static constexpr const char* kTypeName = "float8_e4m3b11";
1754   static constexpr const char* kTpDoc = "float8_e4m3b11 floating-point values";
1755   // We must register float8_e4m3b11 with a kind other than "f", because numpy
1756   // considers two types with the same kind and size to be equal, and we
1757   // expect multiple 1 byte floating point types.
1758   // The downside of this is that NumPy scalar promotion does not work with
1759   // float8_e4m3b11 values.
1760   static constexpr char kNpyDescrKind = 'V';
1761   // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
1762   // character is unique.
1763   static constexpr char kNpyDescrType = 'L';
1764   static constexpr char kNpyDescrByteorder = '=';
1765 };
1766 
1767 namespace ufuncs {
1768 
1769 template <>
1770 struct CopySign<float8_e4m3b11> {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::CopySign1771   float8_e4m3b11 operator()(float8_e4m3b11 a, float8_e4m3b11 b) {
1772     return float8_e4m3b11::FromRep((a.rep() & 0x7f) | (b.rep() & 0x80));
1773   }
1774 };
1775 
1776 template <>
1777 struct NextAfter<float8_e4m3b11> {
operator ()tensorflow::__anon7a1ed7ad0111::ufuncs::NextAfter1778   float8_e4m3b11 operator()(float8_e4m3b11 from, float8_e4m3b11 to) {
1779     uint8_t from_rep = from.rep();
1780     uint8_t to_rep = to.rep();
1781     if (from_rep == 0x80 || to_rep == 0x80) {
1782       return float8_e4m3b11::FromRep(0x80);
1783     }
1784     if (from_rep == to_rep) {
1785       return to;
1786     }
1787     if (from_rep == 0) {
1788       return float8_e4m3b11::FromRep(0x01 | (to_rep & 0x80));
1789     }
1790     const uint16_t sign_mask = 0x80;
1791     uint8_t from_sign = from_rep & sign_mask;
1792     uint8_t to_sign = to_rep & sign_mask;
1793     uint8_t from_abs = from_rep & ~sign_mask;
1794     uint8_t to_abs = to_rep & ~sign_mask;
1795     uint8_t magnitude_adjustment =
1796         (from_abs > to_abs || from_sign != to_sign) ? 0xFF : 0x0001;
1797     uint8_t out_int = from_rep + magnitude_adjustment;
1798     if (out_int == 0x80) {
1799       out_int = 0x0;
1800     }
1801     return float8_e4m3b11::FromRep(out_int);
1802   }
1803 };
1804 
1805 }  // namespace ufuncs
1806 
1807 }  // namespace
1808 
1809 // Initializes the module.
Initialize()1810 bool Initialize() {
1811   ImportNumpy();
1812   import_umath1(false);
1813 
1814   Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
1815   if (!numpy_str) {
1816     return false;
1817   }
1818   Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
1819   if (!numpy) {
1820     return false;
1821   }
1822 
1823   if (!RegisterNumpyDtype<bfloat16>(numpy.get())) {
1824     return false;
1825   }
1826   if (!RegisterNumpyDtype<float8_e4m3b11>(numpy.get())) {
1827     return false;
1828   }
1829   // TODO(parkers): Enable CanCast to-from fp8 and bf16 and f16.
1830   return true;
1831 }
1832 
RegisterNumpyBfloat16()1833 bool RegisterNumpyBfloat16() {
1834   if (TypeDescriptor<bfloat16>::Dtype() != NPY_NOTYPE) {
1835     // Already initialized.
1836     return true;
1837   }
1838   if (!Initialize()) {
1839     if (!PyErr_Occurred()) {
1840       PyErr_SetString(PyExc_RuntimeError, "cannot load bfloat16 module.");
1841     }
1842     PyErr_Print();
1843     return false;
1844   }
1845   return true;
1846 }
1847 
Bfloat16Dtype()1848 PyObject* Bfloat16Dtype() {
1849   return reinterpret_cast<PyObject*>(TypeDescriptor<bfloat16>::type_ptr);
1850 }
1851 
Bfloat16NumpyType()1852 int Bfloat16NumpyType() { return TypeDescriptor<bfloat16>::Dtype(); }
1853 
Float8_E4M3B11Dtype()1854 PyObject* Float8_E4M3B11Dtype() {
1855   return reinterpret_cast<PyObject*>(TypeDescriptor<float8_e4m3b11>::type_ptr);
1856 }
1857 
1858 }  // namespace tensorflow
1859