1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities to create TensorProtos.""" 16import numpy as np 17 18from tensorflow.core.framework import tensor_pb2 19from tensorflow.core.framework import tensor_shape_pb2 20from tensorflow.python.client import pywrap_tf_session as c_api 21from tensorflow.python.eager import context 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import errors_impl 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.types import core 27from tensorflow.python.types import internal 28from tensorflow.python.util import compat 29from tensorflow.python.util import nest 30from tensorflow.python.util.tf_export import tf_export 31 32# Fallback in case fast_tensor_util is not properly compiled. 
33# pylint: disable=g-import-not-at-top 34try: 35 from tensorflow.python.framework import fast_tensor_util 36 _FAST_TENSOR_UTIL_AVAILABLE = True 37except ImportError: 38 _FAST_TENSOR_UTIL_AVAILABLE = False 39# pylint: enable=g-import-not-at-top 40 41 42def ExtractBitsFromFloat16(x): 43 return np.asarray(x, dtype=np.float16).view(np.uint16).item() 44 45 46def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values): 47 tensor_proto.half_val.extend( 48 [ExtractBitsFromFloat16(x) for x in proto_values]) 49 50 51def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values): 52 # TODO: Remove the conversion if cython supports np.float16_t 53 fast_tensor_util.AppendFloat16ArrayToTensorProto( 54 tensor_proto, 55 np.asarray(proto_values, dtype=np.float16).view(np.uint16)) 56 57 58def ExtractBitsFromBFloat16(x): 59 return np.asarray( 60 x, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16).item() 61 62 63def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): 64 tensor_proto.half_val.extend( 65 [ExtractBitsFromBFloat16(x) for x in proto_values]) 66 67 68def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): 69 fast_tensor_util.AppendBFloat16ArrayToTensorProto( 70 tensor_proto, np.asarray( 71 proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16)) 72 73 74if _FAST_TENSOR_UTIL_AVAILABLE: 75 _NP_TO_APPEND_FN = { 76 dtypes.bfloat16.as_numpy_dtype: 77 FastAppendBFloat16ArrayToTensorProto, 78 np.float16: 79 _MediumAppendFloat16ArrayToTensorProto, 80 np.float32: 81 fast_tensor_util.AppendFloat32ArrayToTensorProto, 82 np.float64: 83 fast_tensor_util.AppendFloat64ArrayToTensorProto, 84 np.int32: 85 fast_tensor_util.AppendInt32ArrayToTensorProto, 86 np.int64: 87 fast_tensor_util.AppendInt64ArrayToTensorProto, 88 np.uint8: 89 fast_tensor_util.AppendUInt8ArrayToTensorProto, 90 np.uint16: 91 fast_tensor_util.AppendUInt16ArrayToTensorProto, 92 np.uint32: 93 fast_tensor_util.AppendUInt32ArrayToTensorProto, 94 np.uint64: 95 
fast_tensor_util.AppendUInt64ArrayToTensorProto, 96 np.int8: 97 fast_tensor_util.AppendInt8ArrayToTensorProto, 98 np.int16: 99 fast_tensor_util.AppendInt16ArrayToTensorProto, 100 np.complex64: 101 fast_tensor_util.AppendComplex64ArrayToTensorProto, 102 np.complex128: 103 fast_tensor_util.AppendComplex128ArrayToTensorProto, 104 np.object_: 105 fast_tensor_util.AppendObjectArrayToTensorProto, 106 np.bool_: 107 fast_tensor_util.AppendBoolArrayToTensorProto, 108 dtypes.qint8.as_numpy_dtype: 109 fast_tensor_util.AppendInt8ArrayToTensorProto, 110 dtypes.quint8.as_numpy_dtype: 111 fast_tensor_util.AppendUInt8ArrayToTensorProto, 112 dtypes.qint16.as_numpy_dtype: 113 fast_tensor_util.AppendInt16ArrayToTensorProto, 114 dtypes.quint16.as_numpy_dtype: 115 fast_tensor_util.AppendUInt16ArrayToTensorProto, 116 dtypes.qint32.as_numpy_dtype: 117 fast_tensor_util.AppendInt32ArrayToTensorProto, 118 # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. 119 } 120else: 121 122 def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values): 123 tensor_proto.float_val.extend([x.item() for x in proto_values]) 124 125 def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values): 126 tensor_proto.double_val.extend([x.item() for x in proto_values]) 127 128 def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values): 129 tensor_proto.int_val.extend([x.item() for x in proto_values]) 130 131 def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values): 132 tensor_proto.int64_val.extend([x.item() for x in proto_values]) 133 134 def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values): 135 tensor_proto.int_val.extend([x.item()[0] for x in proto_values]) 136 137 def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values): 138 tensor_proto.uint32_val.extend([x.item() for x in proto_values]) 139 140 def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values): 141 tensor_proto.uint64_val.extend([x.item() for x in proto_values]) 142 143 def 
SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values): 144 tensor_proto.scomplex_val.extend( 145 [v.item() for x in proto_values for v in [x.real, x.imag]]) 146 147 def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values): 148 tensor_proto.dcomplex_val.extend( 149 [v.item() for x in proto_values for v in [x.real, x.imag]]) 150 151 def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values): 152 tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values]) 153 154 def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values): 155 tensor_proto.bool_val.extend([x.item() for x in proto_values]) 156 157 _NP_TO_APPEND_FN = { 158 dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto, 159 np.float16: SlowAppendFloat16ArrayToTensorProto, 160 np.float32: SlowAppendFloat32ArrayToTensorProto, 161 np.float64: SlowAppendFloat64ArrayToTensorProto, 162 np.int32: SlowAppendIntArrayToTensorProto, 163 np.int64: SlowAppendInt64ArrayToTensorProto, 164 np.uint8: SlowAppendIntArrayToTensorProto, 165 np.uint16: SlowAppendIntArrayToTensorProto, 166 np.uint32: SlowAppendUInt32ArrayToTensorProto, 167 np.uint64: SlowAppendUInt64ArrayToTensorProto, 168 np.int8: SlowAppendIntArrayToTensorProto, 169 np.int16: SlowAppendIntArrayToTensorProto, 170 np.complex64: SlowAppendComplex64ArrayToTensorProto, 171 np.complex128: SlowAppendComplex128ArrayToTensorProto, 172 np.object_: SlowAppendObjectArrayToTensorProto, 173 np.bool_: SlowAppendBoolArrayToTensorProto, 174 dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 175 dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 176 dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 177 dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 178 dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 179 # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. 
180 } 181 182 183def GetFromNumpyDTypeDict(dtype_dict, dtype): 184 # NOTE: dtype_dict.get(dtype) always returns None. 185 for key, val in dtype_dict.items(): 186 if key == dtype: 187 return val 188 return None 189 190 191def GetNumpyAppendFn(dtype): 192 # numpy dtype for strings are variable length. We can not compare 193 # dtype with a single constant (np.string does not exist) to decide 194 # dtype is a "string" type. We need to compare the dtype.type to be 195 # sure it's a string type. 196 if dtype.type == np.bytes_ or dtype.type == np.str_: 197 if _FAST_TENSOR_UTIL_AVAILABLE: 198 return fast_tensor_util.AppendObjectArrayToTensorProto 199 else: 200 return SlowAppendObjectArrayToTensorProto 201 return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype) 202 203 204def TensorShapeProtoToList(shape): 205 """Convert a TensorShape to a list. 206 207 Args: 208 shape: A TensorShapeProto. 209 210 Returns: 211 List of integers representing the dimensions of the tensor. 212 """ 213 return [dim.size for dim in shape.dim] 214 215 216def _GetDenseDimensions(list_of_lists): 217 """Returns the inferred dense dimensions of a list of lists.""" 218 if not isinstance(list_of_lists, (list, tuple)): 219 return [] 220 elif not list_of_lists: 221 return [0] 222 else: 223 return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0]) 224 225 226def _FlattenToStrings(nested_strings): 227 if isinstance(nested_strings, (list, tuple)): 228 for inner in nested_strings: 229 for flattened_string in _FlattenToStrings(inner): 230 yield flattened_string 231 else: 232 yield nested_strings 233 234 235_TENSOR_CONTENT_TYPES = frozenset([ 236 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, dtypes.uint8, 237 dtypes.int16, dtypes.int8, dtypes.int64, dtypes.qint8, dtypes.quint8, 238 dtypes.qint16, dtypes.quint16, dtypes.qint32, dtypes.uint32, dtypes.uint64 239]) 240 241 242# pylint: disable=invalid-name 243def _check_failed(v): 244 # NB. 
none of the _check_* functions could raise a ValueError, so 245 # it is safe to use here. 246 raise ValueError(v) 247 248 249def _check_quantized(values): 250 # Cannot rely on `nest` because the leaves are tuples. 251 if not isinstance(values, (list, tuple)): 252 _check_failed(values) 253 if isinstance(values, tuple): 254 _ = [_check_int(v) for v in values] 255 else: 256 _ = [_check_quantized(v) for v in values] 257 258 259def _generate_isinstance_check(expected_types): 260 def inner(values): 261 for v in nest.flatten(values): 262 if not (isinstance(v, expected_types) or 263 (isinstance(v, np.ndarray) and 264 issubclass(v.dtype.type, expected_types))): 265 _check_failed(v) 266 267 return inner 268 269_check_int = _generate_isinstance_check( 270 (compat.integral_types, tensor_shape.Dimension)) 271_check_float = _generate_isinstance_check(compat.real_types) 272_check_complex = _generate_isinstance_check(compat.complex_types) 273_check_str = _generate_isinstance_check(compat.bytes_or_text_types) 274_check_bool = _generate_isinstance_check(bool) 275 276 277def _check_not_tensor(values): 278 _ = [_check_failed(v) for v in nest.flatten(values) 279 if isinstance(v, ops.Tensor)] 280# pylint: enable=invalid-name 281 282_TF_TO_IS_OK = { 283 dtypes.bool: _check_bool, 284 dtypes.complex128: _check_complex, 285 dtypes.complex64: _check_complex, 286 dtypes.float16: _check_float, 287 dtypes.float32: _check_float, 288 dtypes.float64: _check_float, 289 dtypes.int16: _check_int, 290 dtypes.int32: _check_int, 291 dtypes.int64: _check_int, 292 dtypes.int8: _check_int, 293 dtypes.qint16: _check_quantized, 294 dtypes.qint32: _check_quantized, 295 dtypes.qint8: _check_quantized, 296 dtypes.quint16: _check_quantized, 297 dtypes.quint8: _check_quantized, 298 dtypes.string: _check_str, 299 dtypes.uint16: _check_int, 300 dtypes.uint8: _check_int, 301 dtypes.uint32: _check_int, 302 dtypes.uint64: _check_int, 303} 304 305 306def _AssertCompatible(values, dtype): 307 if dtype is None: 308 fn = 
_check_not_tensor 309 else: 310 try: 311 fn = _TF_TO_IS_OK[dtype] 312 except KeyError: 313 # There isn't a specific fn, so we try to do the best possible. 314 if dtype.is_integer: 315 fn = _check_int 316 elif dtype.is_floating: 317 fn = _check_float 318 elif dtype.is_complex: 319 fn = _check_complex 320 elif dtype.is_quantized: 321 fn = _check_quantized 322 else: 323 fn = _check_not_tensor 324 325 try: 326 fn(values) 327 except ValueError as e: 328 [mismatch] = e.args 329 if dtype is None: 330 raise TypeError("Expected any non-tensor type, but got a tensor instead.") 331 else: 332 raise TypeError(f"Expected {dtype.name}, but got {mismatch} of type " 333 f"'{type(mismatch).__name__}'.") 334 335 336def _is_array_like(obj): # pylint: disable=invalid-name 337 """Check if a given object is array-like.""" 338 if isinstance(obj, ops.Tensor) and not isinstance(obj, ops._EagerTensorBase): # pylint: disable=protected-access 339 # Tensor implements __array__ only so it can inform the user that it is not 340 # a valid array. 341 return False 342 343 # TODO(slebedev): an object could also implement C-level array interface. 344 if (callable(getattr(obj, "__array__", None)) or 345 isinstance(getattr(obj, "__array_interface__", None), dict)): 346 return True 347 348 try: 349 memoryview(obj) 350 except TypeError: 351 return False 352 else: 353 return not isinstance(obj, bytes) 354 355 356# pylint: disable=invalid-name 357@tf_export("make_tensor_proto") 358def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False, 359 allow_broadcast=False): 360 """Create a TensorProto. 361 362 In TensorFlow 2.0, representing tensors as protos should no longer be a 363 common workflow. 
That said, this utility function is still useful for 364 generating TF Serving request protos: 365 366 ```python 367 request = tensorflow_serving.apis.predict_pb2.PredictRequest() 368 request.model_spec.name = "my_model" 369 request.model_spec.signature_name = "serving_default" 370 request.inputs["images"].CopyFrom(tf.make_tensor_proto(X_new)) 371 ``` 372 373 `make_tensor_proto` accepts "values" of a python scalar, a python list, a 374 numpy ndarray, or a numpy scalar. 375 376 If "values" is a python scalar or a python list, make_tensor_proto 377 first convert it to numpy ndarray. If dtype is None, the 378 conversion tries its best to infer the right numpy data 379 type. Otherwise, the resulting numpy array has a compatible data 380 type with the given dtype. 381 382 In either case above, the numpy ndarray (either the caller provided 383 or the auto-converted) must have the compatible type with dtype. 384 385 `make_tensor_proto` then converts the numpy array to a tensor proto. 386 387 If "shape" is None, the resulting tensor proto represents the numpy 388 array precisely. 389 390 Otherwise, "shape" specifies the tensor's shape and the numpy array 391 can not have more elements than what "shape" specifies. 392 393 Args: 394 values: Values to put in the TensorProto. 395 dtype: Optional tensor_pb2 DataType value. 396 shape: List of integers representing the dimensions of tensor. 397 verify_shape: Boolean that enables verification of a shape of values. 398 allow_broadcast: Boolean that enables allowing scalars and 1 length vector 399 broadcasting. Cannot be true when verify_shape is true. 400 401 Returns: 402 A `TensorProto`. Depending on the type, it may contain data in the 403 "tensor_content" attribute, which is not directly useful to Python programs. 404 To access the values you should convert the proto back to a numpy ndarray 405 with `tf.make_ndarray(proto)`. 406 407 If `values` is a `TensorProto`, it is immediately returned; `dtype` and 408 `shape` are ignored. 
409 410 Raises: 411 TypeError: if unsupported types are provided. 412 ValueError: if arguments have inappropriate values or if verify_shape is 413 True and shape of values is not equals to a shape from the argument. 414 415 """ 416 if allow_broadcast and verify_shape: 417 raise ValueError("allow_broadcast and verify_shape are not both allowed.") 418 if isinstance(values, tensor_pb2.TensorProto): 419 return values 420 421 if dtype: 422 dtype = dtypes.as_dtype(dtype) 423 424 is_quantized = ( 425 dtype in [ 426 dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16, 427 dtypes.qint32 428 ]) 429 430 if _is_array_like(values): 431 values = np.asarray(values) 432 433 # We first convert value to a numpy array or scalar. 434 if isinstance(values, (np.ndarray, np.generic)): 435 if dtype and dtype.is_numpy_compatible: 436 nparray = values.astype(dtype.as_numpy_dtype) 437 else: 438 nparray = values 439 else: 440 if values is None: 441 raise ValueError("None values not supported.") 442 # if dtype is provided, forces numpy array to be the type 443 # provided if possible. 444 if dtype and dtype.is_numpy_compatible: 445 np_dt = dtype.as_numpy_dtype 446 else: 447 np_dt = None 448 # If shape is None, numpy.prod returns None when dtype is not set, but 449 # raises exception when dtype is set to np.int64 450 if shape is not None and np.prod(shape, dtype=np.int64) == 0: 451 nparray = np.empty(shape, dtype=np_dt) 452 else: 453 _AssertCompatible(values, dtype) 454 nparray = np.array(values, dtype=np_dt) 455 # check to them. 456 # We need to pass in quantized values as tuples, so don't apply the shape 457 if (list(nparray.shape) != _GetDenseDimensions(values) and 458 not is_quantized): 459 raise ValueError(f"Expected values {values} to be a dense tensor with " 460 f"shape {_GetDenseDimensions(values)}, but got shape " 461 f"{list(nparray.shape)}.") 462 463 # python/numpy default float type is float64. We prefer float32 instead. 
464 if (nparray.dtype == np.float64) and dtype is None: 465 nparray = nparray.astype(np.float32) 466 # python/numpy default int type is int64. We prefer int32 instead. 467 elif (nparray.dtype == np.int64) and dtype is None: 468 downcasted_array = nparray.astype(np.int32) 469 # Do not down cast if it leads to precision loss. 470 if np.array_equal(downcasted_array, nparray): 471 nparray = downcasted_array 472 473 # if dtype is provided, it must be compatible with what numpy 474 # conversion says. 475 numpy_dtype = dtypes.as_dtype(nparray.dtype) 476 if numpy_dtype is None: 477 raise TypeError(f"Unrecognized data type: {nparray.dtype}.") 478 479 # If dtype was specified and is a quantized type, we convert 480 # numpy_dtype back into the quantized version. 481 if is_quantized: 482 numpy_dtype = dtype 483 484 if dtype is not None and (not hasattr(dtype, "base_dtype") or 485 dtype.base_dtype != numpy_dtype.base_dtype): 486 raise TypeError(f"`dtype` {dtype} is not compatible with {values} of " 487 f"dtype {nparray.dtype}.") 488 489 # If shape is not given, get the shape from the numpy array. 490 if shape is None: 491 shape = nparray.shape 492 is_same_size = True 493 shape_size = nparray.size 494 else: 495 shape = [int(dim) for dim in shape] 496 shape_size = np.prod(shape, dtype=np.int64) 497 is_same_size = shape_size == nparray.size 498 499 if allow_broadcast: 500 if nparray.shape == (1,) or nparray.shape == tuple(): 501 pass 502 elif nparray.size != shape_size: 503 raise TypeError(f"Expected Tensor's shape: {tuple(shape)}, but got " 504 f"{nparray.shape}.") 505 506 else: 507 if verify_shape and nparray.shape != tuple(shape): 508 raise TypeError(f"Expected Tensor's shape: {tuple(shape)}, but got " 509 f"{nparray.shape}.") 510 511 if nparray.size > shape_size: 512 raise ValueError("Too many elements provided. 
Takes at most " 513 f"{shape_size:d}, but got {nparray.size:d}.") 514 515 tensor_proto = tensor_pb2.TensorProto( 516 dtype=numpy_dtype.as_datatype_enum, 517 tensor_shape=tensor_shape.as_shape(shape).as_proto()) 518 519 if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1: 520 if nparray.size * nparray.itemsize >= (1 << 31): 521 raise ValueError( 522 "Cannot create a tensor proto whose content is larger than 2GB.") 523 tensor_proto.tensor_content = nparray.tobytes() 524 return tensor_proto 525 526 # If we were not given values as a numpy array, compute the proto_values 527 # from the given values directly, to avoid numpy trimming nulls from the 528 # strings. Since values could be a list of strings, or a multi-dimensional 529 # list of lists that might or might not correspond to the given shape, 530 # we flatten it conservatively. 531 if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray): 532 proto_values = _FlattenToStrings(values) 533 534 # At this point, values may be a list of objects that we could not 535 # identify a common type for (hence it was inferred as 536 # np.object_/dtypes.string). If we are unable to convert it to a 537 # string, we raise a more helpful error message. 538 # 539 # Ideally, we'd be able to convert the elements of the list to a 540 # common type, but this type inference requires some thinking and 541 # so we defer it for now. 542 try: 543 str_values = [compat.as_bytes(x) for x in proto_values] 544 except TypeError: 545 raise TypeError(f"Failed to convert elements of {values} to Tensor. " 546 "Consider casting elements to a supported type. See " 547 "https://www.tensorflow.org/api_docs/python/tf/dtypes " 548 "for supported TF dtypes.") 549 tensor_proto.string_val.extend(str_values) 550 return tensor_proto 551 552 # TensorFlow expects C order (a.k.a., eigen row major). 
553 proto_values = nparray.ravel() 554 555 append_fn = GetNumpyAppendFn(proto_values.dtype) 556 if append_fn is None: 557 raise TypeError( 558 f"Element type not supported in TensorProto: {numpy_dtype.name}.") 559 append_fn(tensor_proto, proto_values) 560 561 return tensor_proto 562# pylint: enable=invalid-name 563 564 565@tf_export("make_ndarray") 566def MakeNdarray(tensor): 567 """Create a numpy ndarray from a tensor. 568 569 Create a numpy ndarray with the same shape and data as the tensor. 570 571 For example: 572 573 ```python 574 # Tensor a has shape (2,3) 575 a = tf.constant([[1,2,3],[4,5,6]]) 576 proto_tensor = tf.make_tensor_proto(a) # convert `tensor a` to a proto tensor 577 tf.make_ndarray(proto_tensor) # output: array([[1, 2, 3], 578 # [4, 5, 6]], dtype=int32) 579 # output has shape (2,3) 580 ``` 581 582 Args: 583 tensor: A TensorProto. 584 585 Returns: 586 A numpy array with the tensor contents. 587 588 Raises: 589 TypeError: if tensor has unsupported type. 590 591 """ 592 shape = [d.size for d in tensor.tensor_shape.dim] 593 num_elements = np.prod(shape, dtype=np.int64) 594 tensor_dtype = dtypes.as_dtype(tensor.dtype) 595 dtype = tensor_dtype.as_numpy_dtype 596 597 if tensor.tensor_content: 598 return (np.frombuffer(tensor.tensor_content, 599 dtype=dtype).copy().reshape(shape)) 600 601 if tensor_dtype == dtypes.string: 602 # np.pad throws on these arrays of type np.object_. 
603 values = list(tensor.string_val) 604 padding = num_elements - len(values) 605 if padding > 0: 606 last = values[-1] if values else "" 607 values.extend([last] * padding) 608 return np.array(values, dtype=dtype).reshape(shape) 609 610 if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16: 611 # the half_val field of the TensorProto stores the binary representation 612 # of the fp16: we need to reinterpret this as a proper float16 613 values = np.fromiter(tensor.half_val, dtype=np.uint16) 614 values.dtype = tensor_dtype.as_numpy_dtype 615 elif tensor_dtype == dtypes.float32: 616 values = np.fromiter(tensor.float_val, dtype=dtype) 617 elif tensor_dtype == dtypes.float64: 618 values = np.fromiter(tensor.double_val, dtype=dtype) 619 elif tensor_dtype in [ 620 dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8, 621 dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16 622 ]: 623 values = np.fromiter(tensor.int_val, dtype=dtype) 624 elif tensor_dtype == dtypes.int64: 625 values = np.fromiter(tensor.int64_val, dtype=dtype) 626 elif tensor_dtype == dtypes.uint32: 627 values = np.fromiter(tensor.uint32_val, dtype=dtype) 628 elif tensor_dtype == dtypes.uint64: 629 values = np.fromiter(tensor.uint64_val, dtype=dtype) 630 elif tensor_dtype == dtypes.complex64: 631 it = iter(tensor.scomplex_val) 632 values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype) 633 elif tensor_dtype == dtypes.complex128: 634 it = iter(tensor.dcomplex_val) 635 values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype) 636 elif tensor_dtype == dtypes.bool: 637 values = np.fromiter(tensor.bool_val, dtype=dtype) 638 else: 639 raise TypeError(f"Unsupported tensor type: {tensor.dtype}. 
See " 640 "https://www.tensorflow.org/api_docs/python/tf/dtypes " 641 "for supported TF dtypes.") 642 643 if values.size == 0: 644 return np.zeros(shape, dtype) 645 646 if values.size != num_elements: 647 values = np.pad(values, (0, num_elements - values.size), "edge") 648 649 return values.reshape(shape) 650 651 652def ShapeEquals(tensor_proto, shape): 653 """Returns True if "tensor_proto" has the given "shape". 654 655 Args: 656 tensor_proto: A TensorProto. 657 shape: A tensor shape, expressed as a TensorShape, list, or tuple. 658 659 Returns: 660 True if "tensor_proto" has the given "shape", otherwise False. 661 662 Raises: 663 TypeError: If "tensor_proto" is not a TensorProto, or shape is not a 664 TensorShape, list, or tuple. 665 """ 666 if not isinstance(tensor_proto, tensor_pb2.TensorProto): 667 raise TypeError("`tensor_proto` must be a tensor_pb2.TensorProto object, " 668 f"but got type {type(tensor_proto)}.") 669 if isinstance(shape, tensor_shape_pb2.TensorShapeProto): 670 shape = [d.size for d in shape.dim] 671 elif not isinstance(shape, (list, tuple)): 672 raise TypeError("`shape` must be a list or tuple, but got type " 673 f"{type(shape)}.") 674 tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim] 675 return all(x == y for x, y in zip(tensor_shape_list, shape)) 676 677 678def _ConstantValue(tensor, partial): 679 # TODO(touts): Support Variables? 
680 if not isinstance(tensor, ops.Tensor): 681 raise TypeError(f"{tensor!r} must be a Tensor, but got {type(tensor)}.") 682 if tensor.op.type == "Const": 683 return MakeNdarray(tensor.op.get_attr("value")) 684 elif tensor.op.type == "Shape": 685 input_shape = tensor.op.inputs[0].get_shape() 686 if input_shape.is_fully_defined(): 687 return np.array( 688 [dim.value for dim in input_shape.dims], 689 dtype=tensor.dtype.as_numpy_dtype) 690 else: 691 return None 692 elif tensor.op.type == "Size": 693 input_shape = tensor.op.inputs[0].get_shape() 694 if input_shape.is_fully_defined(): 695 return np.prod([dim.value for dim in input_shape.dims], dtype=np.int32) 696 else: 697 return None 698 elif tensor.op.type == "Rank": 699 input_shape = tensor.op.inputs[0].get_shape() 700 if input_shape.ndims is not None: 701 return np.ndarray( 702 shape=(), 703 buffer=np.array([input_shape.ndims], dtype=np.int32), 704 dtype=np.int32) 705 else: 706 return None 707 elif tensor.op.type == "Range": 708 start = constant_value(tensor.op.inputs[0]) 709 if start is None: 710 return None 711 limit = constant_value(tensor.op.inputs[1]) 712 if limit is None: 713 return None 714 delta = constant_value(tensor.op.inputs[2]) 715 if delta is None: 716 return None 717 return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype) 718 elif tensor.op.type == "Cast": 719 pre_cast = constant_value(tensor.op.inputs[0]) 720 if pre_cast is None: 721 return None 722 cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT")) 723 return pre_cast.astype(cast_dtype.as_numpy_dtype) 724 elif tensor.op.type == "Concat": 725 dim = constant_value(tensor.op.inputs[0]) 726 if dim is None: 727 return None 728 values = [] 729 for x in tensor.op.inputs[1:]: 730 value = constant_value(x) 731 if value is None: 732 return None 733 values.append(value) 734 return np.concatenate(values, axis=dim) 735 elif tensor.op.type == "ConcatV2": 736 dim = constant_value(tensor.op.inputs[-1]) 737 if dim is None: 738 return None 739 
values = [] 740 for x in tensor.op.inputs[:-1]: 741 value = constant_value(x) 742 if value is None: 743 return None 744 values.append(value) 745 return np.concatenate(values, axis=dim) 746 elif tensor.op.type == "Pack": 747 values = [] 748 # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid 749 # and shouldn't be produced, but to deal sensibly with them here we check 750 # and return None. 751 if not tensor.op.inputs: 752 return None 753 # We can't handle axis != 0 Packs at the moment. 754 if tensor.op.get_attr("axis") != 0: 755 return None 756 for x in tensor.op.inputs: 757 value = constant_value(x, partial) 758 if value is None and not partial: 759 return None 760 values.append(value) 761 return np.array(values) 762 elif tensor.op.type == "Unpack": 763 # We can't handle axis != 0 Unpacks at the moment. 764 if tensor.op.get_attr("axis") != 0: 765 return None 766 value = constant_value(tensor.op.inputs[0], partial) 767 if value is None: 768 return None 769 return value[tensor.value_index] 770 elif tensor.op.type == "Split": 771 dim = constant_value(tensor.op.inputs[0]) 772 value = constant_value(tensor.op.inputs[1], partial) 773 if value is None or dim is None: 774 return None 775 split = np.split(value, tensor.op.get_attr("num_split"), dim) 776 return split[tensor.value_index] 777 elif tensor.op.type == "Fill": 778 fill_shape = tensor.shape 779 fill_value = constant_value(tensor.op.inputs[1]) 780 if fill_shape.is_fully_defined() and fill_value is not None: 781 return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype) 782 else: 783 return None 784 elif tensor.op.type == "Equal": 785 value1 = constant_value(tensor.op.inputs[0]) 786 if value1 is None: 787 return None 788 value2 = constant_value(tensor.op.inputs[1]) 789 if value2 is None: 790 return None 791 return np.equal(value1, value2) 792 elif tensor.op.type == "NotEqual": 793 value1 = constant_value(tensor.op.inputs[0]) 794 if value1 is None: 795 return None 796 value2 = 
constant_value(tensor.op.inputs[1]) 797 if value2 is None: 798 return None 799 return np.not_equal(value1, value2) 800 elif tensor.op.type == "StopGradient": 801 return constant_value(tensor.op.inputs[0], partial) 802 elif tensor.op.type in ("CheckNumericsV2", "DebugIdentityV2", "Identity"): 803 return constant_value(tensor.op.inputs[0], partial) 804 else: 805 return None 806 807 808@tf_export("get_static_value") 809def constant_value(tensor, partial=False): # pylint: disable=invalid-name 810 """Returns the constant value of the given tensor, if efficiently calculable. 811 812 This function attempts to partially evaluate the given tensor, and 813 returns its value as a numpy ndarray if this succeeds. 814 815 Example usage: 816 817 >>> a = tf.constant(10) 818 >>> tf.get_static_value(a) 819 10 820 >>> b = tf.constant(20) 821 >>> tf.get_static_value(tf.add(a, b)) 822 30 823 824 >>> # `tf.Variable` is not supported. 825 >>> c = tf.Variable(30) 826 >>> print(tf.get_static_value(c)) 827 None 828 829 Using `partial` option is most relevant when calling `get_static_value` inside 830 a `tf.function`. Setting it to `True` will return the results but for the 831 values that cannot be evaluated will be `None`. For example: 832 833 ```python 834 class Foo: 835 def __init__(self): 836 self.a = tf.Variable(1) 837 self.b = tf.constant(2) 838 839 @tf.function 840 def bar(self, partial): 841 packed = tf.raw_ops.Pack(values=[self.a, self.b]) 842 static_val = tf.get_static_value(packed, partial=partial) 843 tf.print(static_val) 844 845 f = Foo() 846 f.bar(partial=True) # `array([None, array(2, dtype=int32)], dtype=object)` 847 f.bar(partial=False) # `None` 848 ``` 849 850 Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it 851 will no longer be possible to feed a different value for `tensor`. This allows 852 the result of this function to influence the graph that is constructed, and 853 permits static shape optimizations. 
854 855 Args: 856 tensor: The Tensor to be evaluated. 857 partial: If True, the returned numpy array is allowed to have partially 858 evaluated values. Values that can't be evaluated will be None. 859 860 Returns: 861 A numpy ndarray containing the constant value of the given `tensor`, 862 or None if it cannot be calculated. 863 864 Raises: 865 TypeError: if tensor is not an ops.Tensor. 866 """ 867 if isinstance(tensor, ops.EagerTensor): 868 try: 869 return tensor.numpy() 870 except errors_impl.UnimplementedError: 871 # Some EagerTensors may not implement .numpy/resolve, e.g. parallel 872 # tensors with multiple components on different devices. 873 return None 874 if not is_tensor(tensor): 875 return tensor 876 if not isinstance(tensor, ops.Tensor): 877 return None 878 ret = _ConstantValue(tensor, partial) 879 if ret is not None: 880 # The caller may now depend on the constant value of `tensor`, so we 881 # conservatively prevent it from being fed. 882 tensor.graph.prevent_feeding(tensor) 883 return ret 884 885 886def constant_value_as_shape(tensor): # pylint: disable=invalid-name 887 """A version of `constant_value()` that returns a `TensorShape`. 888 889 This version should be used when a constant tensor value is 890 interpreted as a (possibly partial) shape, e.g. in the shape 891 function for `tf.reshape()`. By explicitly requesting a 892 `TensorShape` as the return value, it is possible to represent 893 unknown dimensions; by contrast, `constant_value()` is 894 all-or-nothing. 895 896 Args: 897 tensor: The rank-0 or rank-1 Tensor to be evaluated. 898 899 Returns: 900 A `TensorShape` based on the constant value of the given `tensor`. 901 902 Raises: 903 ValueError: If the shape is rank-0 and is not statically known to be -1. 
904 """ 905 if isinstance(tensor, ops.EagerTensor): 906 return tensor_shape.TensorShape( 907 [dim if dim != -1 else None for dim in tensor.numpy()]) 908 909 if tensor.get_shape().ndims == 0: 910 value = constant_value(tensor) 911 if value is None: 912 raise ValueError( 913 "Received a scalar with unknown value as shape; require a statically " 914 "known scalar with value '-1' to describe an unknown shape.") 915 if value != -1: 916 raise ValueError( 917 f"Received a scalar value '{value}' as shape; require a statically " 918 "known scalar with value '-1' to describe an unknown shape.") 919 return tensor_shape.unknown_shape() 920 921 shape = tensor.get_shape().with_rank(1) 922 if shape == [0]: 923 return tensor_shape.TensorShape([]) 924 elif tensor.op.type == "Cast": 925 pre_cast = constant_value_as_shape(tensor.op.inputs[0]) 926 if pre_cast.dims is None: 927 # the input to cast has a totally undefined shape; just return that. 928 return pre_cast 929 cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT")) 930 if cast_dtype not in (dtypes.int32, dtypes.int64): 931 return tensor_shape.unknown_shape(shape.dims[0].value) 932 dest_dtype_shape_array = np.array( 933 [x if x is not None else -1 for x in pre_cast.as_list()]).astype( 934 cast_dtype.as_numpy_dtype) 935 return tensor_shape.TensorShape([ 936 x if x >= 0 else None 937 for x in dest_dtype_shape_array]) 938 elif tensor.op.type == "Shape": 939 return tensor.op.inputs[0].get_shape() 940 elif tensor.op.type == "Pack": 941 ret = tensor_shape.TensorShape([]) # Empty list. 942 # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it 943 # would not be rank 1. 944 assert tensor.op.get_attr("axis") == 0 945 for pack_input in tensor.op.inputs: 946 # `pack_input` must be a scalar. Attempt to evaluate it, and append it 947 # to `ret`. 
948 pack_input_val = constant_value(pack_input) 949 if pack_input_val is None or pack_input_val < 0: 950 new_dim = tensor_shape.Dimension(None) 951 else: 952 new_dim = tensor_shape.Dimension(pack_input_val) 953 ret = ret.concatenate([new_dim]) 954 return ret 955 elif tensor.op.type == "Concat": 956 # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is 957 # the only legal value when concatenating vectors, and it will 958 # have been checked by a previous shape function. 959 ret = tensor_shape.TensorShape([]) # Empty list. 960 for concat_input in tensor.op.inputs[1:]: 961 # `concat_input` must be a vector. Attempt to evaluate it as a shape, 962 # and concatenate it with `ret`. 963 ret = ret.concatenate(constant_value_as_shape(concat_input)) 964 return ret 965 elif tensor.op.type == "ConcatV2": 966 # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is 967 # the only legal value when concatenating vectors, and it will 968 # have been checked by a previous shape function. 969 ret = tensor_shape.TensorShape([]) # Empty list. 970 for concat_input in tensor.op.inputs[:-1]: 971 # `concat_input` must be a vector. Attempt to evaluate it as a shape, 972 # and concatenate it with `ret`. 
973 ret = ret.concatenate(constant_value_as_shape(concat_input)) 974 return ret 975 elif tensor.op.type == "StridedSlice": 976 try: 977 begin = constant_value(tensor.op.inputs[1]) 978 end = constant_value(tensor.op.inputs[2]) 979 strides = constant_value(tensor.op.inputs[3]) 980 if begin is not None and end is not None and strides is not None: 981 begin = begin[0] 982 end = end[0] 983 strides = strides[0] 984 begin_mask = tensor.op.get_attr("begin_mask") 985 if begin_mask == 1: 986 begin = None 987 end_mask = tensor.op.get_attr("end_mask") 988 if end_mask == 1: 989 end = None 990 991 ellipsis_mask = tensor.op.get_attr("ellipsis_mask") 992 new_axis_mask = tensor.op.get_attr("new_axis_mask") 993 shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask") 994 valid_attributes = (not ellipsis_mask and not new_axis_mask and 995 not shrink_axis_mask and (not begin_mask or 996 (begin_mask == 1)) and 997 (not end_mask or (end_mask == 1))) 998 if valid_attributes: # additional inputs not supported 999 prev = constant_value_as_shape(tensor.op.inputs[0]) 1000 prev = prev[begin:end:strides] 1001 ret = tensor_shape.TensorShape(prev) 1002 return ret 1003 1004 except ValueError: # Could come from get_attr or slicing prev. 1005 pass 1006 except TypeError: # Could come from slicing prev. 1007 pass 1008 elif (tensor.op.type == "Placeholder" and 1009 tensor.op.graph.building_function and 1010 hasattr(tensor.op.graph, "internal_captures")): 1011 # If we are inside a FuncGraph try to lookup the constant value of the 1012 # corresponding external capture. Note that we only look at captures and 1013 # not the fed inputs because those can be fed different values in different 1014 # instantiations of the function call or different iterations of a 1015 # tf.while_loop. 
1016 for i, capture in enumerate(tensor.op.graph.internal_captures): 1017 if capture is tensor: 1018 external_capture = tensor.op.graph.external_captures[i] 1019 return constant_value_as_shape(external_capture) 1020 1021 ret = tensor_shape.unknown_shape(shape.dims[0].value) 1022 value = constant_value(tensor) 1023 if value is not None: 1024 ret = ret.merge_with( 1025 tensor_shape.TensorShape([d if d >= 0 else None for d in value])) 1026 return ret 1027 1028 1029# TODO(mdan): Deprecate in favor of more static-friendly types. 1030@tf_export("is_tensor") 1031def is_tf_type(x): # pylint: disable=invalid-name 1032 """Checks whether `x` is a TF-native type that can be passed to many TF ops. 1033 1034 Use `is_tensor` to differentiate types that can ingested by TensorFlow ops 1035 without any conversion (e.g., `tf.Tensor`, `tf.SparseTensor`, and 1036 `tf.RaggedTensor`) from types that need to be converted into tensors before 1037 they are ingested (e.g., numpy `ndarray` and Python scalars). 1038 1039 For example, in the following code block: 1040 1041 ```python 1042 if not tf.is_tensor(t): 1043 t = tf.convert_to_tensor(t) 1044 return t.shape, t.dtype 1045 ``` 1046 1047 we check to make sure that `t` is a tensor (and convert it if not) before 1048 accessing its `shape` and `dtype`. (But note that not all TensorFlow native 1049 types have shapes or dtypes; `tf.data.Dataset` is an example of a TensorFlow 1050 native type that has neither shape nor dtype.) 1051 1052 Args: 1053 x: A python object to check. 1054 1055 Returns: 1056 `True` if `x` is a TensorFlow-native type. 1057 """ 1058 return (isinstance(x, internal.NativeObject) or 1059 isinstance(x, core.Tensor) or 1060 getattr(x, "is_tensor_like", False)) 1061 1062 1063# Deprecated alias for tensor_util.is_tf_type. 
is_tensor = is_tf_type


def shape_tensor(shape):  # pylint: disable=invalid-name
  """Converts `shape` into a shape tensor named "shape".

  An empty list/tuple is converted with an explicit int32 dtype; in every
  other case the dtype is left for `ops.convert_to_tensor` to infer.

  Args:
    shape: A list/tuple of dimension values (ints, None, or `Dimension`
      objects), or any other value accepted by `ops.convert_to_tensor`.

  Returns:
    The tensor produced by `ops.convert_to_tensor`.
  """
  if not isinstance(shape, (tuple, list)):
    return ops.convert_to_tensor(shape, dtype=None, name="shape")
  if not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")
  # Strip `Dimension` wrappers. Mixing v1 and v2 `TensorShape`s in partial
  # conversions can yield shapes such as (1, 2, Dimension(5)), whose mixed
  # content is not convertible to a Tensor.
  unwrapped = tuple(tensor_shape.dimension_value(dim) for dim in shape)
  return ops.convert_to_tensor(unwrapped, dtype=None, name="shape")


# DO NOT USE: For testing only.
_ENABLE_MAYBE_SET_STATIC_SHAPE = True


def maybe_set_static_shape(tensor, shape):  # pylint: disable=invalid-name
  """Sets `tensor`'s shape to the constant value of `shape`, when inferrable.

  This is a temporary workaround for shape inference across functional op
  boundaries. Consider:

  ```python
  shape = tf.constant([3])
  @tf.function
  def f():
    u = tf.random.uniform(shape)
    return u
  ```

  C++ shape inference alone cannot determine the shape of `u` inside `f`: it
  does not see the outer graph, only a Placeholder node when it backtraces the
  captured `shape` tensor. This helper computes the static value of `shape`
  with `constant_value_as_shape`, which can follow `FuncGraph` captures, and
  applies it to `tensor`.

  The longer-term fix is to improve C++ shape inference.

  The update is attempted only when all of the following hold: the
  `_ENABLE_MAYBE_SET_STATIC_SHAPE` flag is set, execution is not eager, the
  default graph is building a function, `tensor`'s shape is not already fully
  defined, and `shape` is a TF type per `is_tensor`.

  Args:
    tensor: A tensor.
    shape: A shape tensor.
  """
  if not _ENABLE_MAYBE_SET_STATIC_SHAPE:
    return
  if context.executing_eagerly():
    return
  if not ops.get_default_graph().building_function:
    return
  if tensor.shape.is_fully_defined():
    return
  if not is_tensor(shape):
    return
  static_shape = constant_value_as_shape(shape_tensor(shape))
  tensor.set_shape(static_shape)


def try_evaluate_constant(tensor):  # pylint: disable=invalid-name
  """Attempts to evaluate the symbolic `tensor` as a constant via the C API.

  Args:
    tensor: a symbolic Tensor.

  Returns:
    ndarray if the evaluation succeeds, or None if it fails.
  """
  # pylint: disable=protected-access
  with tensor.graph._c_graph.get() as c_graph:
    tf_output = tensor._as_tf_output()
    result = c_api.TF_TryEvaluateConstant_wrapper(c_graph, tf_output)
  # pylint: enable=protected-access
  return result