1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Dtypes and dtype utilities.""" 16 17import numpy as np 18 19from tensorflow.python.framework import dtypes 20from tensorflow.python.ops.numpy_ops import np_export 21 22 23# We use numpy's dtypes instead of TF's, because the user expects to use them 24# with numpy facilities such as `np.dtype(np.int64)` and 25# `if x.dtype.type is np.int64`. 
# NOTE: `np.complex_`, `np.float_`, `np.string_` and `np.unicode_` were removed
# in NumPy 2.0. On NumPy 1.x they are plain aliases of `np.complex128`,
# `np.float64`, `np.bytes_` and `np.str_` respectively, so the canonical names
# are used below; the exported objects are unchanged on NumPy 1.x and the
# module keeps importing on NumPy 2.x.
bool_ = np_export.np_export_constant(__name__, 'bool_', np.bool_)
complex_ = np_export.np_export_constant(__name__, 'complex_', np.complex128)
complex128 = np_export.np_export_constant(__name__, 'complex128', np.complex128)
complex64 = np_export.np_export_constant(__name__, 'complex64', np.complex64)
float_ = np_export.np_export_constant(__name__, 'float_', np.float64)
float16 = np_export.np_export_constant(__name__, 'float16', np.float16)
float32 = np_export.np_export_constant(__name__, 'float32', np.float32)
float64 = np_export.np_export_constant(__name__, 'float64', np.float64)
inexact = np_export.np_export_constant(__name__, 'inexact', np.inexact)
int_ = np_export.np_export_constant(__name__, 'int_', np.int_)
int16 = np_export.np_export_constant(__name__, 'int16', np.int16)
int32 = np_export.np_export_constant(__name__, 'int32', np.int32)
int64 = np_export.np_export_constant(__name__, 'int64', np.int64)
int8 = np_export.np_export_constant(__name__, 'int8', np.int8)
object_ = np_export.np_export_constant(__name__, 'object_', np.object_)
string_ = np_export.np_export_constant(__name__, 'string_', np.bytes_)
uint16 = np_export.np_export_constant(__name__, 'uint16', np.uint16)
uint32 = np_export.np_export_constant(__name__, 'uint32', np.uint32)
uint64 = np_export.np_export_constant(__name__, 'uint64', np.uint64)
uint8 = np_export.np_export_constant(__name__, 'uint8', np.uint8)
unicode_ = np_export.np_export_constant(__name__, 'unicode_', np.str_)


iinfo = np_export.np_export_constant(__name__, 'iinfo', np.iinfo)


issubdtype = np_export.np_export('issubdtype')(np.issubdtype)


# Narrowing table consulted by `canonicalize_dtype` when float64 is disallowed:
# each 64-bit-based floating/complex dtype maps to its 32-bit-based counterpart.
_to_float32 = {
    np.dtype('float64'): np.dtype('float32'),
    np.dtype('complex128'): np.dtype('complex64'),
}


# Memo filled in by `_get_cached_dtype`; keys are the dtypes passed to it and
# values are the corresponding `np.dtype` objects.
_cached_np_dtypes = {}


# Difference between is_prefer_float32 and is_allow_float64: is_prefer_float32
# only decides which dtype to use for Python floats; is_allow_float64 decides
# whether float64
# dtypes can ever appear in programs. The latter is more
# restrictive than the former.
_prefer_float32 = False


# TODO(b/178862061): Consider removing this knob
_allow_float64 = True


def is_prefer_float32():
  """Returns the current value of the prefer-float32 flag."""
  return _prefer_float32


def set_prefer_float32(b):
  """Sets the prefer-float32 flag.

  Args:
    b: New value of the flag. It is stored as given, without validation.
  """
  global _prefer_float32
  _prefer_float32 = b


def is_allow_float64():
  """Returns the current value of the allow-float64 flag."""
  return _allow_float64


def set_allow_float64(b):
  """Sets the allow-float64 flag.

  Args:
    b: New value of the flag. It is stored as given, without validation.
  """
  global _allow_float64
  _allow_float64 = b


def canonicalize_dtype(dtype):
  """Narrows `dtype` through `_to_float32` when float64 is not allowed.

  Args:
    dtype: The dtype to canonicalize; used as a key into `_to_float32`.

  Returns:
    `dtype` unchanged if the allow-float64 flag is truthy or `dtype` has no
    entry in `_to_float32`; otherwise the mapped dtype.
  """
  if _allow_float64:
    return dtype
  # Dtypes without a narrower counterpart fall through unchanged.
  return _to_float32.get(dtype, dtype)


def _result_type(*arrays_and_dtypes):
  """Returns the resulting type given a set of arrays.

  Args:
    *arrays_and_dtypes: Values forwarded to `np.result_type`.

  Returns:
    `dtypes.as_dtype` applied to the canonicalized `np.result_type` of the
    (possibly narrowed) arguments.
  """

  def preprocess_float(x):
    # With prefer-float32 on, Python float/complex scalars are narrowed before
    # type promotion so they do not pull the result up to 64 bits.
    # NOTE(review): `np.float64` and `np.complex128` scalars subclass the
    # Python `float`/`complex` types, so they are narrowed here as well --
    # confirm that this is intended.
    if is_prefer_float32():
      if isinstance(x, float):
        return np.float32(x)
      if isinstance(x, complex):
        return np.complex64(x)
    return x

  dtype = np.result_type(*(preprocess_float(x) for x in arrays_and_dtypes))
  return dtypes.as_dtype(canonicalize_dtype(dtype))


def _get_cached_dtype(dtype):
  """Returns an np.dtype for the TensorFlow DType.

  Results are memoized in `_cached_np_dtypes`, keyed by `dtype`.

  Args:
    dtype: Object exposing an `as_numpy_dtype` attribute.

  Returns:
    `np.dtype(dtype.as_numpy_dtype)`, built on first use and reused afterwards.
  """
  # No `global` statement needed: the dict is mutated in place, never rebound.
  cached_dtype = _cached_np_dtypes.get(dtype)
  if cached_dtype is None:
    cached_dtype = np.dtype(dtype.as_numpy_dtype)
    _cached_np_dtypes[dtype] = cached_dtype
  return cached_dtype


def default_float_type():
  """Gets the default float type.

  Returns:
    If `is_prefer_float32()` is false and `is_allow_float64()` is true, returns
    float64; otherwise returns float32.
  """
  if is_prefer_float32() or not is_allow_float64():
    return float32
  return float64