xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/numpy_ops/np_dtypes.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Dtypes and dtype utilities."""
16
17import numpy as np
18
19from tensorflow.python.framework import dtypes
20from tensorflow.python.ops.numpy_ops import np_export
21
22
# We use numpy's dtypes instead of TF's, because the user expects to use them
# with numpy facilities such as `np.dtype(np.int64)` and
# `if x.dtype.type is np.int64`.
#
# NumPy 2.0 removed the `np.complex_`, `np.float_`, `np.string_` and
# `np.unicode_` aliases. The exported names below are kept for backward
# compatibility but are bound to the canonical scalar types those aliases
# pointed to (`np.complex128`, `np.float64`, `np.bytes_`, `np.str_`), which are
# the very same objects on NumPy 1.x and still exist on NumPy 2.x.
bool_ = np_export.np_export_constant(__name__, 'bool_', np.bool_)
complex_ = np_export.np_export_constant(__name__, 'complex_', np.complex128)
complex128 = np_export.np_export_constant(__name__, 'complex128', np.complex128)
complex64 = np_export.np_export_constant(__name__, 'complex64', np.complex64)
float_ = np_export.np_export_constant(__name__, 'float_', np.float64)
float16 = np_export.np_export_constant(__name__, 'float16', np.float16)
float32 = np_export.np_export_constant(__name__, 'float32', np.float32)
float64 = np_export.np_export_constant(__name__, 'float64', np.float64)
inexact = np_export.np_export_constant(__name__, 'inexact', np.inexact)
int_ = np_export.np_export_constant(__name__, 'int_', np.int_)
int16 = np_export.np_export_constant(__name__, 'int16', np.int16)
int32 = np_export.np_export_constant(__name__, 'int32', np.int32)
int64 = np_export.np_export_constant(__name__, 'int64', np.int64)
int8 = np_export.np_export_constant(__name__, 'int8', np.int8)
object_ = np_export.np_export_constant(__name__, 'object_', np.object_)
string_ = np_export.np_export_constant(__name__, 'string_', np.bytes_)
uint16 = np_export.np_export_constant(__name__, 'uint16', np.uint16)
uint32 = np_export.np_export_constant(__name__, 'uint32', np.uint32)
uint64 = np_export.np_export_constant(__name__, 'uint64', np.uint64)
uint8 = np_export.np_export_constant(__name__, 'uint8', np.uint8)
unicode_ = np_export.np_export_constant(__name__, 'unicode_', np.str_)
47
48
# Re-export numpy's `iinfo` under the same public name.
iinfo = np_export.np_export_constant(
    __name__, 'iinfo', np.iinfo)


# `np_export.np_export('issubdtype')` yields a callable that is applied to
# numpy's own `issubdtype`; the result is what this module exposes.
issubdtype = np_export.np_export(
    'issubdtype')(np.issubdtype)
53
54
# Narrowing table used by `canonicalize_dtype`: each double-precision
# floating/complex dtype maps to its single-precision counterpart.
_to_float32 = dict(
    (np.dtype(wide), np.dtype(narrow))
    for wide, narrow in (
        ('float64', 'float32'),
        ('complex128', 'complex64'),
    )
)


# Memo filled lazily by `_get_cached_dtype`; starts out empty.
_cached_np_dtypes = dict()


# `_prefer_float32` versus `_allow_float64`: the former only picks the dtype
# used for Python floats, whereas the latter controls whether float64 dtypes
# may appear in programs at all. Disallowing float64 is therefore the more
# restrictive of the two settings.
_prefer_float32 = False


# TODO(b/178862061): Consider removing this knob
_allow_float64 = True
73
74
def is_prefer_float32():
  """Returns the current value of the module-level `_prefer_float32` flag."""
  return _prefer_float32
77
78
def set_prefer_float32(b):
  """Stores `b` as the module-level `_prefer_float32` flag.

  Args:
    b: The new flag value; it is stored as given, without conversion.
  """
  global _prefer_float32
  _prefer_float32 = b
82
83
def is_allow_float64():
  """Returns the current value of the module-level `_allow_float64` flag."""
  return _allow_float64
86
87
def set_allow_float64(b):
  """Stores `b` as the module-level `_allow_float64` flag.

  Args:
    b: The new flag value; it is stored as given, without conversion.
  """
  global _allow_float64
  _allow_float64 = b
91
92
def canonicalize_dtype(dtype):
  """Narrows `dtype` through `_to_float32` when float64 is disallowed.

  Args:
    dtype: The dtype to canonicalize; it is used as a key into `_to_float32`.

  Returns:
    `dtype` unchanged while `_allow_float64` is truthy or when `dtype` has no
    entry in `_to_float32`; otherwise the mapped entry.
  """
  if _allow_float64:
    return dtype
  # Dtypes without a narrower counterpart fall through unchanged.
  return _to_float32.get(dtype, dtype)
100
101
def _result_type(*arrays_and_dtypes):
  """Computes the promoted dtype of the given operands.

  Args:
    *arrays_and_dtypes: Values forwarded to `np.result_type`, after `float`
      and `complex` instances are narrowed when `is_prefer_float32()` is true.

  Returns:
    The result of `dtypes.as_dtype` applied to the canonicalized promoted
    dtype.
  """
  def _narrow_scalar(value):
    # Only narrow when float32 is preferred; everything else passes through.
    if not is_prefer_float32():
      return value
    # NOTE: np.float64 / np.complex128 scalars subclass Python's float /
    # complex, so they match these checks as well.
    if isinstance(value, float):
      return np.float32(value)
    if isinstance(value, complex):
      return np.complex64(value)
    return value

  operands = list(map(_narrow_scalar, arrays_and_dtypes))
  promoted = np.result_type(*operands)
  return dtypes.as_dtype(canonicalize_dtype(promoted))
114
115
def _get_cached_dtype(dtype):
  """Looks up (and memoizes) the np.dtype for a TensorFlow DType.

  Args:
    dtype: A hashable object exposing `as_numpy_dtype`; used as the cache key.

  Returns:
    The `np.dtype` built from `dtype.as_numpy_dtype`. Repeated calls with the
    same key return the object stored in `_cached_np_dtypes`.
  """
  # The cache dict is only mutated, never rebound, so no `global` is needed.
  if dtype not in _cached_np_dtypes:
    _cached_np_dtypes[dtype] = np.dtype(dtype.as_numpy_dtype)
  return _cached_np_dtypes[dtype]
126
127
def default_float_type():
  """Gets the default float type.

  Returns:
    float32 when `is_prefer_float32()` is true or `is_allow_float64()` is
    false; float64 in the remaining case (float32 not preferred and float64
    allowed).
  """
  # `is_allow_float64()` is consulted only when float32 is not preferred.
  if is_prefer_float32() or not is_allow_float64():
    return float32
  return float64
139