# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of dtypes (Tensor element types)."""
import abc
import builtins
from typing import Type, Sequence, Optional

import numpy as np

from tensorflow.core.framework import types_pb2
# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
# protobuf errors where a file is defined twice on MacOS.
# pylint: disable=invalid-import-order,g-bad-import-order
from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python.framework import _dtypes
from tensorflow.python.types import doc_typealias
from tensorflow.python.lib.core import _pywrap_bfloat16
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.types import trace
from tensorflow.core.function import trace_type

_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()


class DTypeMeta(type(_dtypes.DType), abc.ABCMeta):
  """Metaclass combining the C++ `DType` base's metaclass with `ABCMeta`."""
  pass


@tf_export("dtypes.DType", "DType")
class DType(
    _dtypes.DType,
    trace.TraceType,
    trace_type.Serializable,
    metaclass=DTypeMeta):
  """Represents the type of the elements in a `Tensor`.

  `DType`'s are used to specify the output data type for operations which
  require it, or to inspect the data type of existing `Tensor`'s.

  Examples:

  >>> tf.constant(1, dtype=tf.int64)
  <tf.Tensor: shape=(), dtype=int64, numpy=1>
  >>> tf.constant(1.0).dtype
  tf.float32

  See `tf.dtypes` for a complete list of `DType`'s defined.
  """
  __slots__ = ()

  @property
  def _is_ref_dtype(self):
    """Returns `True` if this `DType` represents a reference type."""
    # Reference enum values are the base enum value offset by 100.
    return self._type_enum > 100

  @property
  def _as_ref(self):
    """Returns a reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return self
    else:
      return _INTERN_TABLE[self._type_enum + 100]

  @property
  def base_dtype(self):
    """Returns a non-reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return _INTERN_TABLE[self._type_enum - 100]
    else:
      return self

  @property
  def real_dtype(self):
    """Returns the `DType` corresponding to this `DType`'s real part."""
    base = self.base_dtype
    if base == complex64:
      return float32
    elif base == complex128:
      return float64
    else:
      return self

  @property
  def as_numpy_dtype(self):
    """Returns a Python `type` object based on this `DType`."""
    return _TF_TO_NP[self._type_enum]

  @property
  def min(self):
    """Returns the minimum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.

    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError(f"Cannot find minimum value of {self} with "
                      f"{'quantized type' if self.is_quantized else 'type'} "
                      f"{self.base_dtype}.")

    # There is no simple way to get the min value of a dtype, we have to check
    # float and int types separately.
    try:
      return np.finfo(self.as_numpy_dtype).min
    except Exception:  # The exceptions raised by finfo are not documented.
      try:
        return np.iinfo(self.as_numpy_dtype).min
      except Exception:  # Likewise undocumented for iinfo.
        if self.base_dtype == bfloat16:
          # Largest-magnitude finite bfloat16, negated.
          return _np_bfloat16(float.fromhex("-0x1.FEp127"))
        raise TypeError(f"Cannot find minimum value of {self}.")

  @property
  def max(self):
    """Returns the maximum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.

    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError(f"Cannot find maximum value of {self} with "
                      f"{'quantized type' if self.is_quantized else 'type'} "
                      f"{self.base_dtype}.")

    # There is no simple way to get the max value of a dtype, we have to check
    # float and int types separately.
    try:
      return np.finfo(self.as_numpy_dtype).max
    except Exception:  # The exceptions raised by finfo are not documented.
      try:
        return np.iinfo(self.as_numpy_dtype).max
      except Exception:  # Likewise undocumented for iinfo.
        if self.base_dtype == bfloat16:
          # Largest finite bfloat16 value.
          return _np_bfloat16(float.fromhex("0x1.FEp127"))
        raise TypeError(f"Cannot find maximum value of {self}.")

  @property
  def limits(self, clip_negative=True):
    """Returns the intensity limits, i.e. the `(min, max)` tuple, of the dtype.

    Note: because this is a property, `clip_negative` cannot be supplied via
    attribute access and therefore always takes its default of `True`; the
    returned minimum is consequently always 0.

    Args:
      clip_negative: bool, optional. If True, clip the negative range (i.e.
        return 0 for min intensity) even if the image dtype allows negative
        values.

    Returns:
      A `(min, max)` tuple holding the lower and upper intensity limits, taken
      from `dtype_range`.

    Raises:
      ValueError: if this dtype's NumPy type has no entry in `dtype_range`.
    """
    if self.as_numpy_dtype not in dtype_range:
      raise ValueError(str(self) + " does not have defined limits.")
    low, high = dtype_range[self.as_numpy_dtype]

    if clip_negative:
      low = 0
    return low, high

  def is_compatible_with(self, other):
    """Returns True if the `other` DType will be converted to this DType.

    The conversion rules are as follows:

    ```python
    DType(T).is_compatible_with(DType(T)) == True
    ```

    Args:
      other: A `DType` (or object that may be converted to a `DType`).

    Returns:
      True if a Tensor of the `other` `DType` will be implicitly converted to
      this `DType`.
    """
    other = as_dtype(other)
    # A reference dtype is compatible with its base as well as with itself.
    return self._type_enum in (other.as_datatype_enum,
                               other.base_dtype.as_datatype_enum)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """See tf.types.experimental.TraceType base class."""
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["DType"]:
    """See tf.types.experimental.TraceType base class."""
    return self if all(self == other for other in types) else None

  @classmethod
  def experimental_type_proto(cls) -> Type[types_pb2.SerializedDType]:
    """Returns the type of proto associated with DType serialization."""
    return types_pb2.SerializedDType

  @classmethod
  def experimental_from_proto(cls, proto: types_pb2.SerializedDType) -> "DType":
    """Returns a Dtype instance based on the serialized proto."""
    return DType(proto.datatype)

  def experimental_as_proto(self) -> types_pb2.SerializedDType:
    """Returns a proto representation of the Dtype instance."""
    return types_pb2.SerializedDType(datatype=self._type_enum)

  def __eq__(self, other):
    """Returns True iff this DType refers to the same type as `other`."""
    if other is None:
      return False

    if type(other) != DType:  # pylint: disable=unidiomatic-typecheck
      try:
        other = as_dtype(other)
      except TypeError:
        return False

    return self._type_enum == other._type_enum  # pylint: disable=protected-access

  def __ne__(self, other):
    """Returns True iff self != other."""
    return not self.__eq__(other)

  # "If a class that overrides __eq__() needs to retain the implementation
  #  of __hash__() from a parent class, the interpreter must be told this
  #  explicitly by setting __hash__ = <ParentClass>.__hash__."
  # TODO(slebedev): Remove once __eq__ and __ne__ are implemented in C++.
  __hash__ = _dtypes.DType.__hash__

  def __reduce__(self):
    # Pickle by name so unpickling resolves to the interned instance.
    return as_dtype, (self.name,)

trace_type.register_serializable(DType)

# Define data type range of numpy dtype.
# Note: `np.bool8` was only an alias of `np.bool_` (the same dict key); it was
# deprecated in NumPy 1.24 and removed in NumPy 2.0, so it is not listed.
dtype_range = {
    np.bool_: (False, True),
    np.uint8: (0, 255),
    np.uint16: (0, 65535),
    np.int8: (-128, 127),
    np.int16: (-32768, 32767),
    np.int64: (-2**63, 2**63 - 1),
    np.uint64: (0, 2**64 - 1),
    np.int32: (-2**31, 2**31 - 1),
    np.uint32: (0, 2**32 - 1),
    np.float32: (-1, 1),
    np.float64: (-1, 1)
}

# Define standard wrappers for the types_pb2.DataType enum.
resource = DType(types_pb2.DT_RESOURCE)
doc_typealias.document(
    obj=resource,
    doc="Handle to a mutable, dynamically allocated resource.")
tf_export("dtypes.resource", "resource").export_constant(__name__, "resource")

variant = DType(types_pb2.DT_VARIANT)
doc_typealias.document(
    obj=variant,
    doc="Data of arbitrary type (known at runtime).")
tf_export("dtypes.variant", "variant").export_constant(__name__, "variant")

uint8 = DType(types_pb2.DT_UINT8)
doc_typealias.document(
    obj=uint8,
    doc="Unsigned 8-bit (byte) integer.")
tf_export("dtypes.uint8", "uint8").export_constant(__name__, "uint8")

uint16 = DType(types_pb2.DT_UINT16)
doc_typealias.document(
    obj=uint16,
    doc="Unsigned 16-bit (word) integer.")
tf_export("dtypes.uint16", "uint16").export_constant(__name__, "uint16")

uint32 = DType(types_pb2.DT_UINT32)
doc_typealias.document(
    obj=uint32,
    doc="Unsigned 32-bit (dword) integer.")
tf_export("dtypes.uint32", "uint32").export_constant(__name__, "uint32")

uint64 = DType(types_pb2.DT_UINT64)
doc_typealias.document(
    obj=uint64,
    doc="Unsigned 64-bit (qword) integer.")
tf_export("dtypes.uint64", "uint64").export_constant(__name__, "uint64")

int8 = DType(types_pb2.DT_INT8)
doc_typealias.document(
    obj=int8,
    doc="Signed 8-bit integer.")
tf_export("dtypes.int8", "int8").export_constant(__name__, "int8")

int16 = DType(types_pb2.DT_INT16)
doc_typealias.document(
    obj=int16,
    doc="Signed 16-bit integer.")
tf_export("dtypes.int16", "int16").export_constant(__name__, "int16")

int32 = DType(types_pb2.DT_INT32)
doc_typealias.document(
    obj=int32,
    doc="Signed 32-bit integer.")
tf_export("dtypes.int32", "int32").export_constant(__name__, "int32")

int64 = DType(types_pb2.DT_INT64)
doc_typealias.document(
    obj=int64,
    doc="Signed 64-bit integer.")
tf_export("dtypes.int64", "int64").export_constant(__name__, "int64")

float16 = DType(types_pb2.DT_HALF)
half = float16
doc_typealias.document(
    obj=float16,
    doc="16-bit (half precision) floating-point.")
tf_export("dtypes.float16", "float16").export_constant(__name__, "float16")
tf_export("dtypes.half", "half").export_constant(__name__, "half")

float32 = DType(types_pb2.DT_FLOAT)
doc_typealias.document(
    obj=float32,
    doc="32-bit (single precision) floating-point.")
tf_export("dtypes.float32", "float32").export_constant(__name__, "float32")

float64 = DType(types_pb2.DT_DOUBLE)
doc_typealias.document(
    obj=float64,
    doc="64-bit (double precision) floating-point.")
tf_export("dtypes.float64", "float64").export_constant(__name__, "float64")
double = float64
tf_export("dtypes.double", "double").export_constant(__name__, "double")

complex64 = DType(types_pb2.DT_COMPLEX64)
doc_typealias.document(
    obj=complex64,
    doc="64-bit complex.")
tf_export("dtypes.complex64",
          "complex64").export_constant(__name__, "complex64")

complex128 = DType(types_pb2.DT_COMPLEX128)
doc_typealias.document(
    obj=complex128,
    doc="128-bit complex.")
tf_export("dtypes.complex128",
          "complex128").export_constant(__name__, "complex128")

string = DType(types_pb2.DT_STRING)
doc_typealias.document(
    obj=string,
    doc="Variable-length string, represented as byte array.")
tf_export("dtypes.string", "string").export_constant(__name__, "string")

bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
doc_typealias.document(
    obj=bool,
    doc="Boolean.")
tf_export("dtypes.bool", "bool").export_constant(__name__, "bool")

qint8 = DType(types_pb2.DT_QINT8)
doc_typealias.document(
    obj=qint8,
    doc="Signed quantized 8-bit integer.")
tf_export("dtypes.qint8", "qint8").export_constant(__name__, "qint8")

qint16 = DType(types_pb2.DT_QINT16)
doc_typealias.document(
    obj=qint16,
    doc="Signed quantized 16-bit integer.")
tf_export("dtypes.qint16", "qint16").export_constant(__name__, "qint16")

qint32 = DType(types_pb2.DT_QINT32)
doc_typealias.document(
    obj=qint32,
    doc="Signed quantized 32-bit integer.")
tf_export("dtypes.qint32", "qint32").export_constant(__name__, "qint32")

quint8 = DType(types_pb2.DT_QUINT8)
doc_typealias.document(
    obj=quint8,
    doc="Unsigned quantized 8-bit integer.")
tf_export("dtypes.quint8", "quint8").export_constant(__name__, "quint8")

quint16 = DType(types_pb2.DT_QUINT16)
doc_typealias.document(
    obj=quint16,
    doc="Unsigned quantized 16-bit integer.")
tf_export("dtypes.quint16", "quint16").export_constant(__name__, "quint16")

bfloat16 = DType(types_pb2.DT_BFLOAT16)
doc_typealias.document(
    obj=bfloat16,
    doc="16-bit bfloat (brain floating point).")
tf_export("dtypes.bfloat16", "bfloat16").export_constant(__name__, "bfloat16")

# Reference variants of the types above (not exported or documented here).
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = DType(types_pb2.DT_DOUBLE_REF)
double_ref = float64_ref
int32_ref = DType(types_pb2.DT_INT32_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)

# Maintain an intern table so that we don't have to create a large
# number of small objects.
_INTERN_TABLE = {
    types_pb2.DT_HALF: float16,
    types_pb2.DT_FLOAT: float32,
    types_pb2.DT_DOUBLE: float64,
    types_pb2.DT_INT32: int32,
    types_pb2.DT_UINT8: uint8,
    types_pb2.DT_UINT16: uint16,
    types_pb2.DT_UINT32: uint32,
    types_pb2.DT_UINT64: uint64,
    types_pb2.DT_INT16: int16,
    types_pb2.DT_INT8: int8,
    types_pb2.DT_STRING: string,
    types_pb2.DT_COMPLEX64: complex64,
    types_pb2.DT_COMPLEX128: complex128,
    types_pb2.DT_INT64: int64,
    types_pb2.DT_BOOL: bool,
    types_pb2.DT_QINT8: qint8,
    types_pb2.DT_QUINT8: quint8,
    types_pb2.DT_QINT16: qint16,
    types_pb2.DT_QUINT16: quint16,
    types_pb2.DT_QINT32: qint32,
    types_pb2.DT_BFLOAT16: bfloat16,
    types_pb2.DT_RESOURCE: resource,
    types_pb2.DT_VARIANT: variant,
    types_pb2.DT_HALF_REF: float16_ref,
    types_pb2.DT_FLOAT_REF: float32_ref,
    types_pb2.DT_DOUBLE_REF: float64_ref,
    types_pb2.DT_INT32_REF: int32_ref,
    types_pb2.DT_UINT32_REF: uint32_ref,
    types_pb2.DT_UINT8_REF: uint8_ref,
    types_pb2.DT_UINT16_REF: uint16_ref,
    types_pb2.DT_INT16_REF: int16_ref,
    types_pb2.DT_INT8_REF: int8_ref,
    types_pb2.DT_STRING_REF: string_ref,
    types_pb2.DT_COMPLEX64_REF: complex64_ref,
    types_pb2.DT_COMPLEX128_REF: complex128_ref,
    types_pb2.DT_INT64_REF: int64_ref,
    types_pb2.DT_UINT64_REF: uint64_ref,
    types_pb2.DT_BOOL_REF: bool_ref,
    types_pb2.DT_QINT8_REF: qint8_ref,
    types_pb2.DT_QUINT8_REF: quint8_ref,
    types_pb2.DT_QINT16_REF: qint16_ref,
    types_pb2.DT_QUINT16_REF: quint16_ref,
    types_pb2.DT_QINT32_REF: qint32_ref,
    types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
    types_pb2.DT_RESOURCE_REF: resource_ref,
    types_pb2.DT_VARIANT_REF: variant_ref,
}

# Standard mappings between types_pb2.DataType values and string names.
_TYPE_TO_STRING = {
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
    types_pb2.DT_HALF_REF: "float16_ref",
    types_pb2.DT_FLOAT_REF: "float32_ref",
    types_pb2.DT_DOUBLE_REF: "float64_ref",
    types_pb2.DT_INT32_REF: "int32_ref",
    types_pb2.DT_UINT32_REF: "uint32_ref",
    types_pb2.DT_UINT8_REF: "uint8_ref",
    types_pb2.DT_UINT16_REF: "uint16_ref",
    types_pb2.DT_INT16_REF: "int16_ref",
    types_pb2.DT_INT8_REF: "int8_ref",
    types_pb2.DT_STRING_REF: "string_ref",
    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    types_pb2.DT_INT64_REF: "int64_ref",
    types_pb2.DT_UINT64_REF: "uint64_ref",
    types_pb2.DT_BOOL_REF: "bool_ref",
    types_pb2.DT_QINT8_REF: "qint8_ref",
    types_pb2.DT_QUINT8_REF: "quint8_ref",
    types_pb2.DT_QINT16_REF: "qint16_ref",
    types_pb2.DT_QUINT16_REF: "quint16_ref",
    types_pb2.DT_QINT32_REF: "qint32_ref",
    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    types_pb2.DT_RESOURCE_REF: "resource_ref",
    types_pb2.DT_VARIANT_REF: "variant_ref",
}
# Inverse of _TYPE_TO_STRING, resolved to the interned DType objects.
_STRING_TO_TF = {
    value: _INTERN_TABLE[key] for key, value in _TYPE_TO_STRING.items()
}
# Add non-canonical aliases.
_STRING_TO_TF["half"] = float16
_STRING_TO_TF["half_ref"] = float16_ref
_STRING_TO_TF["float"] = float32
_STRING_TO_TF["float_ref"] = float32_ref
_STRING_TO_TF["double"] = float64
_STRING_TO_TF["double_ref"] = float64_ref

# Numpy representation for quantized dtypes.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
_np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)])
_np_qint16 = np.dtype([("qint16", np.int16)])
_np_quint16 = np.dtype([("quint16", np.uint16)])
_np_qint32 = np.dtype([("qint32", np.int32)])

# _np_bfloat16 is defined near the top of this module, from the
# `_pywrap_bfloat16` extension.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])

# Standard mappings between types_pb2.DataType values and numpy.dtypes.
_NP_TO_TF = {
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.int16: int16,
    np.int8: int8,
    np.complex64: complex64,
    np.complex128: complex128,
    np.object_: string,
    np.bytes_: string,
    np.str_: string,
    np.bool_: bool,
    _np_qint8: qint8,
    _np_quint8: quint8,
    _np_qint16: qint16,
    _np_quint16: quint16,
    _np_qint32: qint32,
    _np_bfloat16: bfloat16,
}

# Map (some) NumPy platform dtypes to TF ones using their fixed-width
# synonyms. Note that platform dtypes are not always simple aliases,
# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799.
def _np_to_tf_by_dtype(platform_type):
  """Returns the `DType` registered for a fixed-width synonym of a NumPy type.

  Args:
    platform_type: A NumPy scalar type such as `np.intc` or `np.longlong`.

  Returns:
    The `DType` mapped in `_NP_TO_TF` by the first key whose `np.dtype`
    compares equal to `np.dtype(platform_type)`.

  Raises:
    TypeError: if no key of `_NP_TO_TF` has a matching dtype.
  """
  # Compute the target once rather than instantiating the scalar per key.
  target = np.dtype(platform_type)
  for np_type, tf_type in _NP_TO_TF.items():
    if np_type == target:
      return tf_type
  raise TypeError(
      f"No NumPy type registered in _NP_TO_TF has the same dtype ({target}) "
      f"as the platform type {platform_type!r}.")


for pdt in [
    np.intc,
    np.uintc,
    np.int_,
    np.uint,
    np.longlong,
    np.ulonglong,
]:
  if pdt not in _NP_TO_TF:
    _NP_TO_TF[pdt] = _np_to_tf_by_dtype(pdt)

TF_VALUE_DTYPES = set(_NP_TO_TF.values())

# Maps each types_pb2.DataType value to the NumPy type used to represent it.
_TF_TO_NP = {
    types_pb2.DT_HALF: np.float16,
    types_pb2.DT_FLOAT: np.float32,
    types_pb2.DT_DOUBLE: np.float64,
    types_pb2.DT_INT32: np.int32,
    types_pb2.DT_UINT8: np.uint8,
    types_pb2.DT_UINT16: np.uint16,
    types_pb2.DT_UINT32: np.uint32,
    types_pb2.DT_UINT64: np.uint64,
    types_pb2.DT_INT16: np.int16,
    types_pb2.DT_INT8: np.int8,
    # NOTE(touts): For strings we use object as it supports variable length
    # strings.
    types_pb2.DT_STRING: object,
    types_pb2.DT_COMPLEX64: np.complex64,
    types_pb2.DT_COMPLEX128: np.complex128,
    types_pb2.DT_INT64: np.int64,
    types_pb2.DT_BOOL: np.bool_,
    types_pb2.DT_QINT8: _np_qint8,
    types_pb2.DT_QUINT8: _np_quint8,
    types_pb2.DT_QINT16: _np_qint16,
    types_pb2.DT_QUINT16: _np_quint16,
    types_pb2.DT_QINT32: _np_qint32,
    types_pb2.DT_BFLOAT16: _np_bfloat16,

    # Ref types
    types_pb2.DT_HALF_REF: np.float16,
    types_pb2.DT_FLOAT_REF: np.float32,
    types_pb2.DT_DOUBLE_REF: np.float64,
    types_pb2.DT_INT32_REF: np.int32,
    types_pb2.DT_UINT32_REF: np.uint32,
    types_pb2.DT_UINT8_REF: np.uint8,
    types_pb2.DT_UINT16_REF: np.uint16,
    types_pb2.DT_INT16_REF: np.int16,
    types_pb2.DT_INT8_REF: np.int8,
    types_pb2.DT_STRING_REF: np.object_,
    types_pb2.DT_COMPLEX64_REF: np.complex64,
    types_pb2.DT_COMPLEX128_REF: np.complex128,
    types_pb2.DT_INT64_REF: np.int64,
    types_pb2.DT_UINT64_REF: np.uint64,
    types_pb2.DT_BOOL_REF: np.bool_,
    types_pb2.DT_QINT8_REF: _np_qint8,
    types_pb2.DT_QUINT8_REF: _np_quint8,
    types_pb2.DT_QINT16_REF: _np_qint16,
    types_pb2.DT_QUINT16_REF: _np_quint16,
    types_pb2.DT_QINT32_REF: _np_qint32,
    types_pb2.DT_BFLOAT16_REF: _np_bfloat16,
}

_QUANTIZED_DTYPES_NO_REF = frozenset([qint8, quint8, qint16, quint16, qint32])
_QUANTIZED_DTYPES_REF = frozenset(
    [qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref])
QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF.union(_QUANTIZED_DTYPES_NO_REF)
tf_export(
    "dtypes.QUANTIZED_DTYPES",
    v1=["dtypes.QUANTIZED_DTYPES",
        "QUANTIZED_DTYPES"]).export_constant(__name__, "QUANTIZED_DTYPES")

# `bool` is shadowed by the DType constant in this module, hence `builtins`.
_PYTHON_TO_TF = {
    builtins.float: float32,
    builtins.bool: bool,
    builtins.object: string
}

# Single lookup table combining every supported key kind: DataType enum
# values, string names, Python types and NumPy types/dtypes.
_ANY_TO_TF = {}
_ANY_TO_TF.update(_INTERN_TABLE)
_ANY_TO_TF.update(_STRING_TO_TF)
_ANY_TO_TF.update(_PYTHON_TO_TF)
_ANY_TO_TF.update(_NP_TO_TF)

# Ensure no collisions.
assert len(_ANY_TO_TF) == sum(
    len(d) for d in [_INTERN_TABLE, _STRING_TO_TF, _PYTHON_TO_TF, _NP_TO_TF])


@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
  """Converts the given `type_value` to a `DType`.

  Note: `DType` values are interned. When passed a new `DType` object,
  `as_dtype` always returns the interned value.

  Args:
    type_value: A value that can be converted to a `tf.DType` object. This may
      currently be a `tf.DType` object, a [`DataType`
      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
      a string type name, or a [`numpy.dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html).
      NumPy scalar types (e.g. `np.float32`), the Python types `float`, `bool`
      and `object`, and objects with a NumPy-compatible `dtype` attribute are
      also accepted.

  Returns:
    A `DType` corresponding to `type_value`.

  Raises:
    TypeError: If `type_value` cannot be converted to a `DType`.
  """
  if isinstance(type_value, DType):
    return _INTERN_TABLE[type_value.as_datatype_enum]

  if isinstance(type_value, np.dtype):
    # Fast path on the dtype's scalar type. Structured (quantized) dtypes
    # miss here and are matched by the general lookup below.
    try:
      return _NP_TO_TF[type_value.type]
    except KeyError:
      pass

  # NOTE(review): this lookup is by hash/equality, so any value equal to an
  # enum key (e.g. `True == 1`) also resolves here -- confirm that is intended.
  try:
    return _ANY_TO_TF[type_value]
  except (KeyError, TypeError):
    # TypeError indicates that type_value is not hashable.
    pass

  # Array-like objects: convert via their `dtype` attribute.
  if hasattr(type_value, "dtype"):
    try:
      return _NP_TO_TF[np.dtype(type_value.dtype).type]
    except (KeyError, TypeError):
      pass

  # Plain C++ DType instances that are not the Python subclass above.
  if isinstance(type_value, _dtypes.DType):
    return _INTERN_TABLE[type_value.as_datatype_enum]

  raise TypeError(f"Cannot convert the argument `type_value`: {type_value!r} "
                  "to a TensorFlow DType.")