xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/dtypes.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library of dtypes (Tensor element types)."""
16import abc
17import builtins
18from typing import Type, Sequence, Optional
19
20import numpy as np
21
22from tensorflow.core.framework import types_pb2
23# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
24# protobuf errors where a file is defined twice on MacOS.
25# pylint: disable=invalid-import-order,g-bad-import-order
26from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
27from tensorflow.python.framework import _dtypes
28from tensorflow.python.types import doc_typealias
29from tensorflow.python.lib.core import _pywrap_bfloat16
30from tensorflow.python.util.tf_export import tf_export
31from tensorflow.python.types import trace
32from tensorflow.core.function import trace_type
33
34_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
35
36
class DTypeMeta(type(_dtypes.DType), abc.ABCMeta):
  """Metaclass for `DType`.

  Combines the metaclass of the C++-backed `_dtypes.DType` base with
  `abc.ABCMeta`, giving `DType` a single metaclass compatible with all of its
  bases (`_dtypes.DType`, `trace.TraceType`, `trace_type.Serializable`).
  """
39
40
@tf_export("dtypes.DType", "DType")
class DType(
    _dtypes.DType,
    trace.TraceType,
    trace_type.Serializable,
    metaclass=DTypeMeta):
  """Represents the type of the elements in a `Tensor`.

  `DType`'s are used to specify the output data type for operations which
  require it, or to inspect the data type of existing `Tensor`'s.

  Examples:

  >>> tf.constant(1, dtype=tf.int64)
  <tf.Tensor: shape=(), dtype=int64, numpy=1>
  >>> tf.constant(1.0).dtype
  tf.float32

  See `tf.dtypes` for a complete list of `DType`'s defined.
  """
  __slots__ = ()

  @property
  def _is_ref_dtype(self):
    """Returns `True` if this `DType` represents a reference type."""
    # Reference enums sit exactly 100 above their base enums (see `_as_ref`
    # and `base_dtype`).
    return self._type_enum > 100

  @property
  def _as_ref(self):
    """Returns a reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return self
    else:
      return _INTERN_TABLE[self._type_enum + 100]

  @property
  def base_dtype(self):
    """Returns a non-reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return _INTERN_TABLE[self._type_enum - 100]
    else:
      return self

  @property
  def real_dtype(self):
    """Returns the `DType` corresponding to this `DType`'s real part.

    `complex64` (or its reference type) maps to `float32` and `complex128`
    (or its reference type) maps to `float64`. Every other dtype, including
    reference types, is returned unchanged.
    """
    base = self.base_dtype
    if base == complex64:
      return float32
    elif base == complex128:
      return float64
    else:
      return self

  @property
  def as_numpy_dtype(self):
    """Returns a Python `type` object based on this `DType`.

    For quantized dtypes the result is a single-field structured `np.dtype`
    rather than a type. Dtypes without a numpy mapping (e.g. `resource`,
    `variant`) raise `KeyError`.
    """
    return _TF_TO_NP[self._type_enum]

  def _extremum(self, attr, label, bfloat16_hex):
    """Shared implementation of the `min` and `max` properties.

    Args:
      attr: Name of the `np.finfo` / `np.iinfo` attribute to read, i.e.
        `"min"` or `"max"`.
      label: `"minimum"` or `"maximum"`; used in error messages.
      bfloat16_hex: Hex float literal returned (as a bfloat16 scalar) for
        `bfloat16` when neither `np.finfo` nor `np.iinfo` produced a value.

    Returns:
      The requested extreme representable value.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type, or if
        no limit could be determined.
    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError(f"Cannot find {label} value of {self} with "
                      f"{'quantized type' if self.is_quantized else 'type'} "
                      f"{self.base_dtype}.")

    # There is no single numpy API that covers every dtype, so try the float
    # info first and fall back to the integer info. The exceptions raised by
    # `np.finfo`/`np.iinfo` are not documented, and `as_numpy_dtype` itself
    # raises `KeyError` for dtypes missing from `_TF_TO_NP`, so catch
    # `Exception` (but not `KeyboardInterrupt`/`SystemExit`, which the old
    # bare `except:` used to swallow).
    try:
      return getattr(np.finfo(self.as_numpy_dtype), attr)
    except Exception:  # pylint: disable=broad-except
      pass
    try:
      return getattr(np.iinfo(self.as_numpy_dtype), attr)
    except Exception:  # pylint: disable=broad-except
      pass
    if self.base_dtype == bfloat16:
      return _np_bfloat16(float.fromhex(bfloat16_hex))
    raise TypeError(f"Cannot find {label} value of {self}.")

  @property
  def min(self):
    """Returns the minimum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.

    """
    return self._extremum("min", "minimum", "-0x1.FEp127")

  @property
  def max(self):
    """Returns the maximum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.

    """
    return self._extremum("max", "maximum", "0x1.FEp127")

  @property
  def limits(self, clip_negative=True):
    """Returns the intensity limits, i.e. the `(min, max)` tuple, of the dtype.

    Limits come from the module-level `dtype_range` table (image-intensity
    ranges: integers span their full range, floats span `[-1, 1]`).

    Note: because this is a property, callers cannot pass `clip_negative`; it
    always takes its default of `True`, so the returned minimum is always 0.

    Args:
      clip_negative: bool, optional. If True, clip the negative range (i.e.
        return 0 for min intensity) even if the image dtype allows negative
        values.

    Returns:
      A `(min, max)` tuple of the lower and upper intensity limits.

    Raises:
      ValueError: if this dtype's numpy type has no entry in `dtype_range`.
      KeyError: propagated from `as_numpy_dtype` for dtypes with no numpy
        mapping.
    """
    if self.as_numpy_dtype not in dtype_range:
      raise ValueError(str(self) + " does not have defined limits.")
    low, high = dtype_range[self.as_numpy_dtype]

    if clip_negative:
      low = 0
    return low, high

  def is_compatible_with(self, other):
    """Returns True if the `other` DType will be converted to this DType.

    The conversion rules are as follows:

    ```python
    DType(T)       .is_compatible_with(DType(T))        == True
    ```

    In addition, a `DType` is compatible with the reference version of itself:
    `other` matches when either it or its `base_dtype` equals this dtype.

    Args:
      other: A `DType` (or object that may be converted to a `DType`).

    Returns:
      True if a Tensor of the `other` `DType` will be implicitly converted to
      this `DType`.

    Raises:
      TypeError: if `other` cannot be converted to a `DType`.
    """
    other = as_dtype(other)
    return self._type_enum in (other.as_datatype_enum,
                               other.base_dtype.as_datatype_enum)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """See tf.types.experimental.TraceType base class."""
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["DType"]:
    """See tf.types.experimental.TraceType base class."""
    return self if all(self == other for other in types) else None

  @classmethod
  def experimental_type_proto(cls) -> Type[types_pb2.SerializedDType]:
    """Returns the type of proto associated with DType serialization."""
    return types_pb2.SerializedDType

  @classmethod
  def experimental_from_proto(cls, proto: types_pb2.SerializedDType) -> "DType":
    """Returns a DType instance based on the serialized proto."""
    return DType(proto.datatype)

  def experimental_as_proto(self) -> types_pb2.SerializedDType:
    """Returns a proto representation of the DType instance."""
    return types_pb2.SerializedDType(datatype=self._type_enum)

  def __eq__(self, other):
    """Returns True iff this DType refers to the same type as `other`."""
    if other is None:
      return False

    # Exact type check on purpose: anything that is not literally a `DType`
    # (strings, numpy dtypes, enums, C++ DTypes, ...) goes through `as_dtype`.
    if type(other) != DType:  # pylint: disable=unidiomatic-typecheck
      try:
        other = as_dtype(other)
      except TypeError:
        return False

    return self._type_enum == other._type_enum  # pylint: disable=protected-access

  def __ne__(self, other):
    """Returns True iff self != other."""
    return not self.__eq__(other)

  # "If a class that overrides __eq__() needs to retain the implementation
  #  of __hash__() from a parent class, the interpreter must be told this
  #  explicitly by setting __hash__ = <ParentClass>.__hash__."
  # TODO(slebedev): Remove once __eq__ and __ne__ are implemented in C++.
  __hash__ = _dtypes.DType.__hash__

  def __reduce__(self):
    # Pickle by name; unpickling calls `as_dtype`, which returns the interned
    # instance.
    return as_dtype, (self.name,)
240
# Register `DType` as a `trace_type.Serializable` implementation; the class
# provides `experimental_type_proto`, `experimental_from_proto` and
# `experimental_as_proto`.
trace_type.register_serializable(DType)
242
# Intensity range of each numpy scalar type, used by `DType.limits`.
# Integer types span their full representable range; floating-point types use
# the conventional normalized image range [-1, 1].
#
# NOTE: `np.bool8` used to be listed here as well. It is merely an alias of
# `np.bool_` (so it never added a distinct key), was deprecated in NumPy 1.24
# and removed in NumPy 2.0, where accessing it raises AttributeError at import.
dtype_range = {
    np.bool_: (False, True),
    np.uint8: (0, 255),
    np.uint16: (0, 65535),
    np.int8: (-128, 127),
    np.int16: (-32768, 32767),
    np.int64: (-2**63, 2**63 - 1),
    np.uint64: (0, 2**64 - 1),
    np.int32: (-2**31, 2**31 - 1),
    np.uint32: (0, 2**32 - 1),
    np.float16: (-1, 1),
    np.float32: (-1, 1),
    np.float64: (-1, 1)
}
258
# Standard wrappers for the types_pb2.DataType enum. Each dtype is created,
# documented via `doc_typealias`, and exported under the `dtypes.<name>` and
# `<name>` API names.

# Opaque handle / container types.
resource = DType(types_pb2.DT_RESOURCE)
doc_typealias.document(
    obj=resource, doc="Handle to a mutable, dynamically allocated resource.")
tf_export("dtypes.resource",
          "resource").export_constant(__name__, "resource")

variant = DType(types_pb2.DT_VARIANT)
doc_typealias.document(
    obj=variant, doc="Data of arbitrary type (known at runtime).")
tf_export("dtypes.variant",
          "variant").export_constant(__name__, "variant")
271
# Unsigned integer types.
uint8 = DType(types_pb2.DT_UINT8)
doc_typealias.document(obj=uint8, doc="Unsigned 8-bit (byte) integer.")
tf_export("dtypes.uint8", "uint8").export_constant(__name__, "uint8")

uint16 = DType(types_pb2.DT_UINT16)
doc_typealias.document(obj=uint16, doc="Unsigned 16-bit (word) integer.")
tf_export("dtypes.uint16", "uint16").export_constant(__name__, "uint16")

uint32 = DType(types_pb2.DT_UINT32)
doc_typealias.document(obj=uint32, doc="Unsigned 32-bit (dword) integer.")
tf_export("dtypes.uint32", "uint32").export_constant(__name__, "uint32")

uint64 = DType(types_pb2.DT_UINT64)
doc_typealias.document(obj=uint64, doc="Unsigned 64-bit (qword) integer.")
tf_export("dtypes.uint64", "uint64").export_constant(__name__, "uint64")

# Signed integer types.
int8 = DType(types_pb2.DT_INT8)
doc_typealias.document(obj=int8, doc="Signed 8-bit integer.")
tf_export("dtypes.int8", "int8").export_constant(__name__, "int8")

int16 = DType(types_pb2.DT_INT16)
doc_typealias.document(obj=int16, doc="Signed 16-bit integer.")
tf_export("dtypes.int16", "int16").export_constant(__name__, "int16")

int32 = DType(types_pb2.DT_INT32)
doc_typealias.document(obj=int32, doc="Signed 32-bit integer.")
tf_export("dtypes.int32", "int32").export_constant(__name__, "int32")

int64 = DType(types_pb2.DT_INT64)
doc_typealias.document(obj=int64, doc="Signed 64-bit integer.")
tf_export("dtypes.int64", "int64").export_constant(__name__, "int64")
319
# Floating-point types. `half` and `double` are aliases bound to the very same
# DType objects as `float16` and `float64`.
float16 = half = DType(types_pb2.DT_HALF)
doc_typealias.document(
    obj=float16, doc="16-bit (half precision) floating-point.")
tf_export("dtypes.float16", "float16").export_constant(__name__, "float16")
tf_export("dtypes.half", "half").export_constant(__name__, "half")

float32 = DType(types_pb2.DT_FLOAT)
doc_typealias.document(
    obj=float32, doc="32-bit (single precision) floating-point.")
tf_export("dtypes.float32", "float32").export_constant(__name__, "float32")

float64 = DType(types_pb2.DT_DOUBLE)
doc_typealias.document(
    obj=float64, doc="64-bit (double precision) floating-point.")
tf_export("dtypes.float64", "float64").export_constant(__name__, "float64")
double = float64
tf_export("dtypes.double", "double").export_constant(__name__, "double")

# Complex types.
complex64 = DType(types_pb2.DT_COMPLEX64)
doc_typealias.document(obj=complex64, doc="64-bit complex.")
tf_export(
    "dtypes.complex64", "complex64").export_constant(__name__, "complex64")

complex128 = DType(types_pb2.DT_COMPLEX128)
doc_typealias.document(obj=complex128, doc="128-bit complex.")
tf_export(
    "dtypes.complex128", "complex128").export_constant(__name__, "complex128")
355
# Non-numeric value types.
string = DType(types_pb2.DT_STRING)
doc_typealias.document(
    obj=string, doc="Variable-length string, represented as byte array.")
tf_export("dtypes.string", "string").export_constant(__name__, "string")

# Shadows the builtin within this module; `builtins.bool` is used where the
# Python type is needed.
bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
doc_typealias.document(obj=bool, doc="Boolean.")
tf_export("dtypes.bool", "bool").export_constant(__name__, "bool")
367
# Quantized integer types.
qint8 = DType(types_pb2.DT_QINT8)
doc_typealias.document(
    obj=qint8,
    doc="Signed quantized 8-bit integer.")
tf_export("dtypes.qint8", "qint8").export_constant(__name__, "qint8")

qint16 = DType(types_pb2.DT_QINT16)
doc_typealias.document(
    obj=qint16,
    doc="Signed quantized 16-bit integer.")
tf_export("dtypes.qint16", "qint16").export_constant(__name__, "qint16")

qint32 = DType(types_pb2.DT_QINT32)
doc_typealias.document(
    obj=qint32,
    # Capitalized to match every other dtype's documentation string.
    doc="Signed quantized 32-bit integer.")
tf_export("dtypes.qint32", "qint32").export_constant(__name__, "qint32")

quint8 = DType(types_pb2.DT_QUINT8)
doc_typealias.document(
    obj=quint8,
    doc="Unsigned quantized 8-bit integer.")
tf_export("dtypes.quint8", "quint8").export_constant(__name__, "quint8")

quint16 = DType(types_pb2.DT_QUINT16)
doc_typealias.document(
    obj=quint16,
    doc="Unsigned quantized 16-bit integer.")
tf_export("dtypes.quint16", "quint16").export_constant(__name__, "quint16")
397
# Brain floating point.
bfloat16 = DType(types_pb2.DT_BFLOAT16)
doc_typealias.document(
    obj=bfloat16, doc="16-bit bfloat (brain floating point).")
tf_export("dtypes.bfloat16", "bfloat16").export_constant(__name__, "bfloat16")

# Reference dtypes. They are not documented or exported here; `half_ref` and
# `double_ref` are bound to the same objects as `float16_ref`/`float64_ref`.
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
float16_ref = half_ref = DType(types_pb2.DT_HALF_REF)
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = double_ref = DType(types_pb2.DT_DOUBLE_REF)
int32_ref = DType(types_pb2.DT_INT32_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
429
# Intern table: maps every types_pb2.DataType enum value to the single
# module-level `DType` object for it, so lookups (`as_dtype`, `_as_ref`,
# `base_dtype`) reuse shared instances instead of creating many small objects.
# Each dtype is keyed by its own enum; the listing order below is the table's
# insertion order.
_INTERN_TABLE = {
    dt.as_datatype_enum: dt for dt in (
        # Non-reference types.
        float16, float32, float64, int32, uint8, uint16, uint32, uint64,
        int16, int8, string, complex64, complex128, int64, bool,
        qint8, quint8, qint16, quint16, qint32, bfloat16, resource, variant,
        # Reference types.
        float16_ref, float32_ref, float64_ref, int32_ref, uint32_ref,
        uint8_ref, uint16_ref, int16_ref, int8_ref, string_ref,
        complex64_ref, complex128_ref, int64_ref, uint64_ref, bool_ref,
        qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref,
        bfloat16_ref, resource_ref, variant_ref,
    )
}
480
# Standard mappings between types_pb2.DataType values and their canonical
# string names. `_STRING_TO_TF` is derived from this table, so its insertion
# order also determines the order of the canonical names there.
_TYPE_TO_STRING = {
    # Non-reference types.
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
    # Reference types: the base name with a "_ref" suffix.
    types_pb2.DT_HALF_REF: "float16_ref",
    types_pb2.DT_FLOAT_REF: "float32_ref",
    types_pb2.DT_DOUBLE_REF: "float64_ref",
    types_pb2.DT_INT32_REF: "int32_ref",
    types_pb2.DT_UINT32_REF: "uint32_ref",
    types_pb2.DT_UINT8_REF: "uint8_ref",
    types_pb2.DT_UINT16_REF: "uint16_ref",
    types_pb2.DT_INT16_REF: "int16_ref",
    types_pb2.DT_INT8_REF: "int8_ref",
    types_pb2.DT_STRING_REF: "string_ref",
    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    types_pb2.DT_INT64_REF: "int64_ref",
    types_pb2.DT_UINT64_REF: "uint64_ref",
    types_pb2.DT_BOOL_REF: "bool_ref",
    types_pb2.DT_QINT8_REF: "qint8_ref",
    types_pb2.DT_QUINT8_REF: "quint8_ref",
    types_pb2.DT_QINT16_REF: "qint16_ref",
    types_pb2.DT_QUINT16_REF: "quint16_ref",
    types_pb2.DT_QINT32_REF: "qint32_ref",
    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    types_pb2.DT_RESOURCE_REF: "resource_ref",
    types_pb2.DT_VARIANT_REF: "variant_ref",
}
# Reverse of `_TYPE_TO_STRING`: canonical name -> interned `DType`.
_STRING_TO_TF = {
    name: _INTERN_TABLE[enum] for enum, name in _TYPE_TO_STRING.items()
}
# Non-canonical aliases, also accepted by `as_dtype` via `_ANY_TO_TF`.
_STRING_TO_TF.update({
    "half": float16,
    "half_ref": float16_ref,
    "float": float32,
    "float_ref": float32_ref,
    "double": float64,
    "double_ref": float64_ref,
})
540
# Numpy representation for quantized dtypes: each is a single-field structured
# dtype whose field name is the quantized type's name and whose field type is
# the underlying storage integer.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
_np_qint8, _np_quint8, _np_qint16, _np_quint16, _np_qint32 = (
    np.dtype([(field_name, storage_type)])
    for field_name, storage_type in (
        ("qint8", np.int8),
        ("quint8", np.uint8),
        ("qint16", np.int16),
        ("quint16", np.uint16),
        ("qint32", np.int32),
    ))

# _np_bfloat16 is defined near the top of this module from `_pywrap_bfloat16`.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])
557
# Standard mappings from numpy scalar types (and the structured / custom
# numpy dtypes defined above) to TF dtypes. Entry order matters: the
# platform-type loop below picks the first matching entry.
_NP_TO_TF = {
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.int16: int16,
    np.int8: int8,
    np.complex64: complex64,
    np.complex128: complex128,
    # Object, bytes and unicode numpy data all map to TF strings.
    np.object_: string,
    np.bytes_: string,
    np.str_: string,
    np.bool_: bool,
    # Quantized (structured) dtypes and bfloat16.
    _np_qint8: qint8,
    _np_quint8: quint8,
    _np_qint16: qint16,
    _np_quint16: quint16,
    _np_qint32: qint32,
    _np_bfloat16: bfloat16,
}

# Map (some) NumPy platform dtypes to TF ones using their fixed-width
# synonyms. Note that platform dtypes are not always simple aliases,
# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799.
for pdt in [np.intc, np.uintc, np.int_, np.uint, np.longlong, np.ulonglong]:
  if pdt in _NP_TO_TF:
    continue  # Already the same object as one of the fixed-width keys.
  # Reuse the TF dtype of the first key whose dtype equals this one's.
  _NP_TO_TF[pdt] = next(
      tf_dtype for np_key, tf_dtype in _NP_TO_TF.items()
      if np_key == np.dtype(pdt))

# The TF dtypes that some numpy type maps to.
TF_VALUE_DTYPES = set(_NP_TO_TF.values())
601
# Mapping from types_pb2.DataType values to the numpy type used by
# `DType.as_numpy_dtype`. Reference types map like their base types, except
# that DT_STRING uses the builtin `object` while DT_STRING_REF uses
# `np.object_`. DT_RESOURCE / DT_VARIANT (and their refs) have no entry.
_TF_TO_NP = {
    types_pb2.DT_HALF: np.float16,
    types_pb2.DT_FLOAT: np.float32,
    types_pb2.DT_DOUBLE: np.float64,
    types_pb2.DT_INT32: np.int32,
    types_pb2.DT_UINT8: np.uint8,
    types_pb2.DT_UINT16: np.uint16,
    types_pb2.DT_UINT32: np.uint32,
    types_pb2.DT_UINT64: np.uint64,
    types_pb2.DT_INT16: np.int16,
    types_pb2.DT_INT8: np.int8,
    # NOTE(touts): For strings we use object as it supports variable length
    # strings.
    types_pb2.DT_STRING: object,
    types_pb2.DT_COMPLEX64: np.complex64,
    types_pb2.DT_COMPLEX128: np.complex128,
    types_pb2.DT_INT64: np.int64,
    types_pb2.DT_BOOL: np.bool_,
    types_pb2.DT_QINT8: _np_qint8,
    types_pb2.DT_QUINT8: _np_quint8,
    types_pb2.DT_QINT16: _np_qint16,
    types_pb2.DT_QUINT16: _np_quint16,
    types_pb2.DT_QINT32: _np_qint32,
    types_pb2.DT_BFLOAT16: _np_bfloat16,
    # Ref types.
    types_pb2.DT_HALF_REF: np.float16,
    types_pb2.DT_FLOAT_REF: np.float32,
    types_pb2.DT_DOUBLE_REF: np.float64,
    types_pb2.DT_INT32_REF: np.int32,
    types_pb2.DT_UINT32_REF: np.uint32,
    types_pb2.DT_UINT8_REF: np.uint8,
    types_pb2.DT_UINT16_REF: np.uint16,
    types_pb2.DT_INT16_REF: np.int16,
    types_pb2.DT_INT8_REF: np.int8,
    types_pb2.DT_STRING_REF: np.object_,
    types_pb2.DT_COMPLEX64_REF: np.complex64,
    types_pb2.DT_COMPLEX128_REF: np.complex128,
    types_pb2.DT_INT64_REF: np.int64,
    types_pb2.DT_UINT64_REF: np.uint64,
    types_pb2.DT_BOOL_REF: np.bool_,
    types_pb2.DT_QINT8_REF: _np_qint8,
    types_pb2.DT_QUINT8_REF: _np_quint8,
    types_pb2.DT_QINT16_REF: _np_qint16,
    types_pb2.DT_QUINT16_REF: _np_quint16,
    types_pb2.DT_QINT32_REF: _np_qint32,
    types_pb2.DT_BFLOAT16_REF: _np_bfloat16,
}
692
# All quantized dtypes, in both non-reference and reference form.
_QUANTIZED_DTYPES_NO_REF = frozenset({qint8, quint8, qint16, quint16, qint32})
_QUANTIZED_DTYPES_REF = frozenset(
    {qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref})
QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF | _QUANTIZED_DTYPES_NO_REF
tf_export(
    "dtypes.QUANTIZED_DTYPES",
    v1=["dtypes.QUANTIZED_DTYPES", "QUANTIZED_DTYPES"],
).export_constant(__name__, "QUANTIZED_DTYPES")
701
# Python builtin types accepted by `as_dtype`.
_PYTHON_TO_TF = {
    builtins.float: float32,
    builtins.bool: bool,
    builtins.object: string,
}

# Combined lookup table used by `as_dtype`: DataType enums, string names,
# Python builtin types and numpy types, merged in that order.
_ANY_TO_TF = {**_INTERN_TABLE, **_STRING_TO_TF, **_PYTHON_TO_TF, **_NP_TO_TF}

# Ensure no collisions (checked only when assertions are enabled).
assert len(_ANY_TO_TF) == sum(
    map(len, (_INTERN_TABLE, _STRING_TO_TF, _PYTHON_TO_TF, _NP_TO_TF)))
717
718
@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
  """Converts the given `type_value` to a `DType`.

  Note: `DType` values are interned. When passed a new `DType` object,
  `as_dtype` always returns the interned value.

  Args:
    type_value: A value that can be converted to a `tf.DType` object. This may
      currently be a `tf.DType` object, a [`DataType`
      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
      a string type name, or a
      [`numpy.dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html).
      The Python types `float`, `bool` and `object` are also accepted, as is
      any object whose `dtype` attribute numpy can interpret.

  Returns:
    A `DType` corresponding to `type_value`.

  Raises:
    TypeError: If `type_value` cannot be converted to a `DType`.
  """
  # Python DType objects: return the interned instance for their enum.
  if isinstance(type_value, DType):
    return _INTERN_TABLE[type_value.as_datatype_enum]

  # numpy dtypes are looked up by scalar type; structured (quantized) dtypes
  # miss here and are found by the general table below.
  if isinstance(type_value, np.dtype):
    by_scalar_type = _NP_TO_TF.get(type_value.type)
    if by_scalar_type is not None:
      return by_scalar_type

  # Enums, canonical/alias names, Python builtin types and numpy types.
  try:
    converted = _ANY_TO_TF[type_value]
  except (KeyError, TypeError):
    # TypeError indicates that type_value is not hashable.
    converted = None
  if converted is not None:
    return converted

  # Objects exposing a numpy-interpretable `dtype` attribute.
  if hasattr(type_value, "dtype"):
    try:
      converted = _NP_TO_TF[np.dtype(type_value.dtype).type]
    except (KeyError, TypeError):
      pass
    else:
      return converted

  # Instances of the C++ base DType that are not Python `DType`s.
  if isinstance(type_value, _dtypes.DType):
    return _INTERN_TABLE[type_value.as_datatype_enum]

  raise TypeError(f"Cannot convert the argument `type_value`: {type_value!r} "
                  "to a TensorFlow DType.")
764