xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/tensor_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities to create TensorProtos."""
16import numpy as np
17
18from tensorflow.core.framework import tensor_pb2
19from tensorflow.core.framework import tensor_shape_pb2
20from tensorflow.python.client import pywrap_tf_session as c_api
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import errors_impl
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.types import core
27from tensorflow.python.types import internal
28from tensorflow.python.util import compat
29from tensorflow.python.util import nest
30from tensorflow.python.util.tf_export import tf_export
31
32# Fallback in case fast_tensor_util is not properly compiled.
33# pylint: disable=g-import-not-at-top
34try:
35  from tensorflow.python.framework import fast_tensor_util
36  _FAST_TENSOR_UTIL_AVAILABLE = True
37except ImportError:
38  _FAST_TENSOR_UTIL_AVAILABLE = False
39# pylint: enable=g-import-not-at-top
40
41
def ExtractBitsFromFloat16(x):
  """Returns the raw 16-bit pattern of `x` encoded as `np.float16`.

  Args:
    x: A value convertible to a single `np.float16` element.

  Returns:
    The half-precision bit pattern as a Python `int`.
  """
  as_half = np.asarray(x, dtype=np.float16)
  raw_bits = as_half.view(np.uint16)
  return raw_bits.item()
44
45
def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values to `tensor_proto.half_val` one at a time.

  Pure-Python path: every value is turned into its 16-bit pattern with
  `ExtractBitsFromFloat16` before the whole batch is appended.
  """
  bit_patterns = list(map(ExtractBitsFromFloat16, proto_values))
  tensor_proto.half_val.extend(bit_patterns)
49
50
def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values to `tensor_proto` through `fast_tensor_util`.

  The values are cast to float16 and reinterpreted as uint16 bit patterns
  first, because (per the TODO below) the compiled helper cannot take a
  float16 array directly.
  """
  # TODO: Remove the conversion if cython supports np.float16_t
  half_bits = np.asarray(proto_values, dtype=np.float16).view(np.uint16)
  fast_tensor_util.AppendFloat16ArrayToTensorProto(tensor_proto, half_bits)
56
57
def ExtractBitsFromBFloat16(x):
  """Returns the raw 16-bit pattern of `x` encoded as bfloat16, as an int."""
  bf16_type = dtypes.bfloat16.as_numpy_dtype
  as_bf16 = np.asarray(x, dtype=bf16_type)
  return as_bf16.view(np.uint16).item()
61
62
def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values to `tensor_proto.half_val` element by element."""
  bit_patterns = [ExtractBitsFromBFloat16(value) for value in proto_values]
  tensor_proto.half_val.extend(bit_patterns)
66
67
def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values to `tensor_proto` through `fast_tensor_util`.

  The values are cast to bfloat16 and handed to the compiled helper as their
  uint16 bit patterns.
  """
  bf16_values = np.asarray(
      proto_values, dtype=dtypes.bfloat16.as_numpy_dtype)
  fast_tensor_util.AppendBFloat16ArrayToTensorProto(
      tensor_proto, bf16_values.view(np.uint16))
72
73
# Dispatch table from numpy scalar type to the function that appends a flat
# array of that type to a TensorProto. When the compiled `fast_tensor_util`
# module imported successfully, its helpers are used; otherwise the
# pure-Python fallbacks defined in the `else` branch are used.
if _FAST_TENSOR_UTIL_AVAILABLE:
  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          FastAppendBFloat16ArrayToTensorProto,
      np.float16:
          _MediumAppendFloat16ArrayToTensorProto,
      np.float32:
          fast_tensor_util.AppendFloat32ArrayToTensorProto,
      np.float64:
          fast_tensor_util.AppendFloat64ArrayToTensorProto,
      np.int32:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      np.int64:
          fast_tensor_util.AppendInt64ArrayToTensorProto,
      np.uint8:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      np.uint16:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      np.uint32:
          fast_tensor_util.AppendUInt32ArrayToTensorProto,
      np.uint64:
          fast_tensor_util.AppendUInt64ArrayToTensorProto,
      np.int8:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      np.int16:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      np.complex64:
          fast_tensor_util.AppendComplex64ArrayToTensorProto,
      np.complex128:
          fast_tensor_util.AppendComplex128ArrayToTensorProto,
      np.object_:
          fast_tensor_util.AppendObjectArrayToTensorProto,
      np.bool_:
          fast_tensor_util.AppendBoolArrayToTensorProto,
      # Quantized types reuse the appender of their underlying integer width.
      dtypes.qint8.as_numpy_dtype:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      # NOTE(review): an older note here read "Intentionally no way to feed a
      # DT_BFLOAT16", but bfloat16 has an entry at the top of this table; the
      # note appears stale or to refer to something else — confirm.
  }
else:

  def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    """Appends float32 values to `tensor_proto.float_val`."""
    tensor_proto.float_val.extend([x.item() for x in proto_values])

  def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends float64 values to `tensor_proto.double_val`."""
    tensor_proto.double_val.extend([x.item() for x in proto_values])

  def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    """Appends integer values to `tensor_proto.int_val`.

    Shared by int8/int16/int32 and uint8/uint16 in the table below.
    """
    tensor_proto.int_val.extend([x.item() for x in proto_values])

  def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends int64 values to `tensor_proto.int64_val`."""
    tensor_proto.int64_val.extend([x.item() for x in proto_values])

  def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    """Appends quantized values to `tensor_proto.int_val`."""
    # Takes the first field of each element's `.item()`; presumably the
    # quantized numpy types are single-field records so `.item()` is a
    # 1-tuple — TODO confirm.
    tensor_proto.int_val.extend([x.item()[0] for x in proto_values])

  def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    """Appends uint32 values to `tensor_proto.uint32_val`."""
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])

  def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends uint64 values to `tensor_proto.uint64_val`."""
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])

  def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends complex64 values to `scomplex_val` as (real, imag) pairs."""
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    """Appends complex128 values to `dcomplex_val` as (real, imag) pairs."""
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    """Appends string-like values to `string_val` via `compat.as_bytes`."""
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])

  def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    """Appends bool values to `tensor_proto.bool_val`."""
    tensor_proto.bool_val.extend([x.item() for x in proto_values])

  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto,
      np.float16: SlowAppendFloat16ArrayToTensorProto,
      np.float32: SlowAppendFloat32ArrayToTensorProto,
      np.float64: SlowAppendFloat64ArrayToTensorProto,
      np.int32: SlowAppendIntArrayToTensorProto,
      np.int64: SlowAppendInt64ArrayToTensorProto,
      np.uint8: SlowAppendIntArrayToTensorProto,
      np.uint16: SlowAppendIntArrayToTensorProto,
      np.uint32: SlowAppendUInt32ArrayToTensorProto,
      np.uint64: SlowAppendUInt64ArrayToTensorProto,
      np.int8: SlowAppendIntArrayToTensorProto,
      np.int16: SlowAppendIntArrayToTensorProto,
      np.complex64: SlowAppendComplex64ArrayToTensorProto,
      np.complex128: SlowAppendComplex128ArrayToTensorProto,
      np.object_: SlowAppendObjectArrayToTensorProto,
      np.bool_: SlowAppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      # NOTE(review): an older note here read "Intentionally no way to feed a
      # DT_BFLOAT16", but bfloat16 has an entry at the top of this table; the
      # note appears stale or to refer to something else — confirm.
  }
181
182
def GetFromNumpyDTypeDict(dtype_dict, dtype):
  """Looks up `dtype` in `dtype_dict` by equality instead of by hash.

  A plain `dtype_dict.get(dtype)` always misses here. The keys are numpy
  scalar types, while `dtype` is an `np.dtype`; they compare equal but do not
  hash alike.

  Args:
    dtype_dict: Mapping keyed by numpy scalar types.
    dtype: The `np.dtype` to look up.

  Returns:
    The value of the first key comparing equal to `dtype`, or None.
  """
  return next(
      (value for key, value in dtype_dict.items() if key == dtype), None)
189
190
def GetNumpyAppendFn(dtype):
  """Returns the function that appends an array of `dtype` to a TensorProto.

  numpy string dtypes are variable-length, so no single constant key can match
  them. Instead `dtype.type` is inspected, and every bytes/unicode dtype is
  routed to the object-array appender.

  Args:
    dtype: A `np.dtype`.

  Returns:
    A function taking `(tensor_proto, proto_values)`, or None if `dtype` has
    no registered appender.
  """
  if dtype.type not in (np.bytes_, np.str_):
    return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
  if _FAST_TENSOR_UTIL_AVAILABLE:
    return fast_tensor_util.AppendObjectArrayToTensorProto
  return SlowAppendObjectArrayToTensorProto
202
203
def TensorShapeProtoToList(shape):
  """Returns the dimension sizes of a TensorShapeProto as a list.

  Args:
    shape: A TensorShapeProto.

  Returns:
    A list containing the `size` of each entry of `shape.dim`, in order.
  """
  sizes = []
  for dimension in shape.dim:
    sizes.append(dimension.size)
  return sizes
214
215
def _GetDenseDimensions(list_of_lists):
  """Infers the dense shape of a nested list/tuple from its first elements.

  Only the first element at each nesting level is inspected, so ragged inputs
  are not detected here.

  Args:
    list_of_lists: A (possibly nested) list or tuple, or a non-sequence leaf.

  Returns:
    A list of dimension sizes. A non-list/tuple input yields `[]`. An empty
    list/tuple contributes a final `0` and stops the descent.
  """
  dims = []
  level = list_of_lists
  while isinstance(level, (list, tuple)):
    if not level:
      dims.append(0)
      break
    dims.append(len(level))
    level = level[0]
  return dims
224
225
def _FlattenToStrings(nested_strings):
  """Yields the leaves of an arbitrarily nested list/tuple, depth first.

  Anything that is not a list or tuple, including str and bytes, is treated as
  a leaf and yielded unchanged.
  """
  if not isinstance(nested_strings, (list, tuple)):
    yield nested_strings
    return
  for inner in nested_strings:
    yield from _FlattenToStrings(inner)
233
234
# DTypes that make_tensor_proto may serialize as raw array bytes in the
# `tensor_content` field instead of the typed repeated value fields.
_TENSOR_CONTENT_TYPES = frozenset({
    # Floating point.
    dtypes.float16, dtypes.float32, dtypes.float64,
    # Signed integers.
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    # Unsigned integers.
    dtypes.uint8, dtypes.uint32, dtypes.uint64,
    # Quantized integers.
    dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16, dtypes.qint32,
})
240
241
242# pylint: disable=invalid-name
def _check_failed(v):
  """Reports a failed type check by raising `ValueError(v)`.

  The offending value is the exception's only argument, so the caller can
  unpack it. None of the `_check_*` helpers raise ValueError on their own, so
  catching it is unambiguous.
  """
  error = ValueError(v)
  raise error
247
248
def _check_quantized(values):
  """Validates nested quantized values.

  Quantized values are nested lists whose leaves are tuples of integers.
  Because the leaves are themselves tuples, `nest` cannot be used to flatten
  them. Raises ValueError (through `_check_failed` or `_check_int`) on the
  first element that does not fit this structure.
  """
  if not isinstance(values, (list, tuple)):
    _check_failed(values)
  checker = _check_int if isinstance(values, tuple) else _check_quantized
  for item in values:
    checker(item)
257
258
def _generate_isinstance_check(expected_types):
  """Builds a validator that accepts only leaves of `expected_types`.

  The returned function flattens its argument with `nest.flatten`. It calls
  `_check_failed` on the first leaf that is neither an instance of
  `expected_types` nor a numpy array whose element type subclasses
  `expected_types`.

  Args:
    expected_types: A type or tuple of types, as accepted by `isinstance`.

  Returns:
    A one-argument validation function.
  """

  def _leaf_matches(leaf):
    if isinstance(leaf, expected_types):
      return True
    return (isinstance(leaf, np.ndarray) and
            issubclass(leaf.dtype.type, expected_types))

  def inner(values):
    for leaf in nest.flatten(values):
      if not _leaf_matches(leaf):
        _check_failed(leaf)

  return inner


# Per-category validators, looked up by `_AssertCompatible` via `_TF_TO_IS_OK`
# or chosen from the dtype's flags.
_check_int = _generate_isinstance_check(
    (compat.integral_types, tensor_shape.Dimension))
_check_float = _generate_isinstance_check(compat.real_types)
_check_complex = _generate_isinstance_check(compat.complex_types)
_check_str = _generate_isinstance_check(compat.bytes_or_text_types)
_check_bool = _generate_isinstance_check(bool)
275
276
def _check_not_tensor(values):
  """Calls `_check_failed` on the first `ops.Tensor` leaf of `values`."""
  for leaf in nest.flatten(values):
    if isinstance(leaf, ops.Tensor):
      _check_failed(leaf)
280# pylint: enable=invalid-name
281
# Validator that `_AssertCompatible` runs on non-numpy Python values before
# they are converted for a given TensorFlow dtype. Each validator calls
# `_check_failed` (raising ValueError) on the first offending element. Dtypes
# missing from this table fall back to a validator chosen from the dtype's
# is_integer / is_floating / is_complex / is_quantized flags.
_TF_TO_IS_OK = {
    dtypes.bool: _check_bool,
    dtypes.complex128: _check_complex,
    dtypes.complex64: _check_complex,
    dtypes.float16: _check_float,
    dtypes.float32: _check_float,
    dtypes.float64: _check_float,
    dtypes.int16: _check_int,
    dtypes.int32: _check_int,
    dtypes.int64: _check_int,
    dtypes.int8: _check_int,
    # Quantized values are nested lists whose leaves are tuples of integers.
    dtypes.qint16: _check_quantized,
    dtypes.qint32: _check_quantized,
    dtypes.qint8: _check_quantized,
    dtypes.quint16: _check_quantized,
    dtypes.quint8: _check_quantized,
    dtypes.string: _check_str,
    dtypes.uint16: _check_int,
    dtypes.uint8: _check_int,
    dtypes.uint32: _check_int,
    dtypes.uint64: _check_int,
}
304
305
def _AssertCompatible(values, dtype):
  """Checks that Python `values` can be converted to `dtype`.

  Args:
    values: A (possibly nested) Python value that is about to be converted.
    dtype: A `DType`, or None to only reject tensors.

  Raises:
    TypeError: If an element of `values` is incompatible with `dtype`, or if
      `dtype` is None and `values` contains a tensor.
  """
  if dtype is None:
    validator = _check_not_tensor
  else:
    validator = _TF_TO_IS_OK.get(dtype)
    if validator is None:
      # No dedicated validator for this dtype: use the closest generic one.
      if dtype.is_integer:
        validator = _check_int
      elif dtype.is_floating:
        validator = _check_float
      elif dtype.is_complex:
        validator = _check_complex
      elif dtype.is_quantized:
        validator = _check_quantized
      else:
        validator = _check_not_tensor

  try:
    validator(values)
  except ValueError as e:
    # The `_check_*` helpers carry the offending value as the only argument.
    [mismatch] = e.args
    if dtype is not None:
      raise TypeError(f"Expected {dtype.name}, but got {mismatch} of type "
                      f"'{type(mismatch).__name__}'.")
    raise TypeError("Expected any non-tensor type, but got a tensor instead.")
334
335
def _is_array_like(obj):  # pylint: disable=invalid-name
  """Returns True if `obj` should be converted with `np.asarray`.

  Graph (non-eager) tensors are rejected even though they define `__array__`.
  They implement it only to tell the user they are not valid arrays.

  Any other object qualifies if it meets one of these conditions:
    * it has a callable `__array__`;
    * it has a dict-valued `__array_interface__`;
    * it supports `memoryview` and is not `bytes`.
  """
  is_graph_tensor = (
      isinstance(obj, ops.Tensor) and
      not isinstance(obj, ops._EagerTensorBase))  # pylint: disable=protected-access
  if is_graph_tensor:
    return False

  # TODO(slebedev): an object could also implement C-level array interface.
  if callable(getattr(obj, "__array__", None)):
    return True
  if isinstance(getattr(obj, "__array_interface__", None), dict):
    return True

  try:
    memoryview(obj)
  except TypeError:
    return False
  return not isinstance(obj, bytes)
354
355
356# pylint: disable=invalid-name
@tf_export("make_tensor_proto")
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False,
                      allow_broadcast=False):
  """Create a TensorProto.

  In TensorFlow 2.0, representing tensors as protos should no longer be a
  common workflow. That said, this utility function is still useful for
  generating TF Serving request protos:

  ```python
    request = tensorflow_serving.apis.predict_pb2.PredictRequest()
    request.model_spec.name = "my_model"
    request.model_spec.signature_name = "serving_default"
    request.inputs["images"].CopyFrom(tf.make_tensor_proto(X_new))
  ```

  `make_tensor_proto` accepts "values" of a python scalar, a python list, a
  numpy ndarray, or a numpy scalar. Other array-like objects are first
  converted with `np.asarray`.

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. Otherwise, the resulting numpy array has a data type
  compatible with the given dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto-converted) must have a type compatible with dtype.

  `make_tensor_proto` then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  cannot have more elements than what "shape" specifies.

  Args:
    values:         Values to put in the TensorProto.
    dtype:          Optional `DType`, or any value accepted by
        `dtypes.as_dtype` (such as a tensor_pb2 DataType value).
    shape:          List of integers representing the dimensions of tensor.
    verify_shape:   Boolean that enables verification of a shape of values.
    allow_broadcast:  Boolean that enables allowing scalars and 1 length vector
        broadcasting. Cannot be true when verify_shape is true.

  Returns:
    A `TensorProto`. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with `tf.make_ndarray(proto)`.

    If `values` is a `TensorProto`, it is immediately returned; `dtype` and
    `shape` are ignored.

  Raises:
    TypeError:  if unsupported types are provided, or `values` are not
      compatible with `dtype`. Also raised when `verify_shape` is True and the
      shape of `values` differs from `shape`, or when `allow_broadcast` is True
      and `values` is neither a scalar, a length-1 vector, nor of the same
      size as `shape`.
    ValueError: raised in these cases:
      * `verify_shape` and `allow_broadcast` are both True;
      * `values` is None;
      * nested non-quantized Python `values` are not dense;
      * `values` has more elements than `shape` allows (without
        `allow_broadcast`);
      * the serialized content would reach 2GB.

  """
  if allow_broadcast and verify_shape:
    raise ValueError("allow_broadcast and verify_shape are not both allowed.")
  # An existing TensorProto is passed through untouched.
  if isinstance(values, tensor_pb2.TensorProto):
    return values

  if dtype:
    dtype = dtypes.as_dtype(dtype)

  # Quantized values arrive as nested tuples, and numpy converts them to a
  # non-quantized dtype; both facts need special handling below.
  is_quantized = (
      dtype in [
          dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16,
          dtypes.qint32
      ])

  if _is_array_like(values):
    values = np.asarray(values)

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype and dtype.is_numpy_compatible:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    if dtype and dtype.is_numpy_compatible:
      np_dt = dtype.as_numpy_dtype
    else:
      np_dt = None
    # If shape is None, numpy.prod returns None when dtype is not set, but
    # raises exception when dtype is set to np.int64
    if shape is not None and np.prod(shape, dtype=np.int64) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      # We need to pass in quantized values as tuples, so don't apply the shape
      # check to them.
      if (list(nparray.shape) != _GetDenseDimensions(values) and
          not is_quantized):
        raise ValueError(f"Expected values {values} to be a dense tensor with "
                         f"shape {_GetDenseDimensions(values)}, but got shape "
                         f"{list(nparray.shape)}.")

    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      downcasted_array = nparray.astype(np.int32)
      # Do not down cast if it leads to precision loss.
      if np.array_equal(downcasted_array, nparray):
        nparray = downcasted_array

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = dtypes.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError(f"Unrecognized data type: {nparray.dtype}.")

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if is_quantized:
    numpy_dtype = dtype

  if dtype is not None and (not hasattr(dtype, "base_dtype") or
                            dtype.base_dtype != numpy_dtype.base_dtype):
    raise TypeError(f"`dtype` {dtype} is not compatible with {values} of "
                    f"dtype {nparray.dtype}.")

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape, dtype=np.int64)
    # When fewer values than `shape` needs are given, only those values are
    # stored below; MakeNdarray fills the rest by repeating the last one.
    is_same_size = shape_size == nparray.size

    if allow_broadcast:
      if nparray.shape == (1,) or nparray.shape == tuple():
        pass
      elif nparray.size != shape_size:
        raise TypeError(f"Expected Tensor's shape: {tuple(shape)}, but got "
                        f"{nparray.shape}.")

    else:
      if verify_shape and nparray.shape != tuple(shape):
        raise TypeError(f"Expected Tensor's shape: {tuple(shape)}, but got "
                        f"{nparray.shape}.")

      if nparray.size > shape_size:
        raise ValueError("Too many elements provided. Takes at most "
                         f"{shape_size:d}, but got {nparray.size:d}.")

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=tensor_shape.as_shape(shape).as_proto())

  # Fully specified multi-element data of a fixed-width numeric dtype is
  # stored as raw bytes in `tensor_content`. Everything else (including
  # single-element tensors) goes through the typed value fields below.
  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    if nparray.size * nparray.itemsize >= (1 << 31):
      raise ValueError(
          "Cannot create a tensor proto whose content is larger than 2GB.")
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)

    # At this point, values may be a list of objects that we could not
    # identify a common type for (hence it was inferred as
    # np.object_/dtypes.string).  If we are unable to convert it to a
    # string, we raise a more helpful error message.
    #
    # Ideally, we'd be able to convert the elements of the list to a
    # common type, but this type inference requires some thinking and
    # so we defer it for now.
    try:
      str_values = [compat.as_bytes(x) for x in proto_values]
    except TypeError:
      raise TypeError(f"Failed to convert elements of {values} to Tensor. "
                      "Consider casting elements to a supported type. See "
                      "https://www.tensorflow.org/api_docs/python/tf/dtypes "
                      "for supported TF dtypes.")
    tensor_proto.string_val.extend(str_values)
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  # Remaining values are appended by the dtype-specific append function.
  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError(
        f"Element type not supported in TensorProto: {numpy_dtype.name}.")
  append_fn(tensor_proto, proto_values)

  return tensor_proto
562# pylint: enable=invalid-name
563
564
@tf_export("make_ndarray")
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor.

  For example:

  ```python
  # Tensor a has shape (2,3)
  a = tf.constant([[1,2,3],[4,5,6]])
  proto_tensor = tf.make_tensor_proto(a)  # convert `tensor a` to a proto tensor
  tf.make_ndarray(proto_tensor) # output: array([[1, 2, 3],
  #                                              [4, 5, 6]], dtype=int32)
  # output has shape (2,3)
  ```

  The proto may store fewer values than its shape requires. `make_tensor_proto`
  produces such protos, for example when given a scalar and a larger `shape`.
  In that case the last stored value is repeated to fill the remaining
  elements. If no values are stored at all, the result is zero-filled: empty
  bytes for string tensors, zeros otherwise.

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.
    ValueError: if the proto stores more values than its shape allows.

  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  num_elements = np.prod(shape, dtype=np.int64)
  tensor_dtype = dtypes.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  if tensor.tensor_content:
    # Raw bytes: copy so the result is writable and does not alias the
    # proto's buffer.
    return (np.frombuffer(tensor.tensor_content,
                          dtype=dtype).copy().reshape(shape))

  if tensor_dtype == dtypes.string:
    # np.pad throws on these arrays of type np.object_, so pad by hand.
    values = list(tensor.string_val)
    padding = num_elements - len(values)
    if padding > 0:
      # `string_val` holds bytes, so pad with empty bytes (not str) when no
      # value was stored.
      last = values[-1] if values else b""
      values.extend([last] * padding)
    return np.array(values, dtype=dtype).reshape(shape)

  if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
    # The half_val field stores the raw 16-bit patterns; reinterpret them as
    # the half-precision type without converting values. `view` is used
    # instead of assigning to `ndarray.dtype`, which NumPy discourages.
    values = np.fromiter(tensor.half_val, dtype=np.uint16).view(
        tensor_dtype.as_numpy_dtype)
  elif tensor_dtype == dtypes.float32:
    values = np.fromiter(tensor.float_val, dtype=dtype)
  elif tensor_dtype == dtypes.float64:
    values = np.fromiter(tensor.double_val, dtype=dtype)
  elif tensor_dtype in [
      dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
      dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16
  ]:
    values = np.fromiter(tensor.int_val, dtype=dtype)
  elif tensor_dtype == dtypes.int64:
    values = np.fromiter(tensor.int64_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint32:
    values = np.fromiter(tensor.uint32_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint64:
    values = np.fromiter(tensor.uint64_val, dtype=dtype)
  elif tensor_dtype == dtypes.complex64:
    # Stored as interleaved (real, imag) pairs.
    it = iter(tensor.scomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.complex128:
    # Stored as interleaved (real, imag) pairs.
    it = iter(tensor.dcomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.bool:
    values = np.fromiter(tensor.bool_val, dtype=dtype)
  else:
    raise TypeError(f"Unsupported tensor type: {tensor.dtype}. See "
                    "https://www.tensorflow.org/api_docs/python/tf/dtypes "
                    "for supported TF dtypes.")

  if values.size == 0:
    return np.zeros(shape, dtype)

  if values.size > num_elements:
    # Previously surfaced as an opaque np.pad error about negative widths.
    raise ValueError(
        f"TensorProto holds {values.size} values, but its shape {shape} "
        f"only has room for {num_elements}.")
  if values.size < num_elements:
    # Fewer values than elements: repeat the last value to fill the rest.
    values = np.pad(values, (0, num_elements - values.size), "edge")

  return values.reshape(shape)
650
651
def ShapeEquals(tensor_proto, shape):
  """Returns True if "tensor_proto" has the given "shape".

  The shapes match only if they have the same rank and every dimension size
  is equal.

  Args:
    tensor_proto: A TensorProto.
    shape: A tensor shape, expressed as a TensorShapeProto, list, or tuple.

  Returns:
    True if "tensor_proto" has the given "shape", otherwise False.

  Raises:
    TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
      TensorShapeProto, list, or tuple.
  """
  if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    raise TypeError("`tensor_proto` must be a tensor_pb2.TensorProto object, "
                    f"but got type {type(tensor_proto)}.")
  if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    shape = [d.size for d in shape.dim]
  elif not isinstance(shape, (list, tuple)):
    raise TypeError("`shape` must be a list or tuple, but got type "
                    f"{type(shape)}.")
  tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
  # `zip` stops at the shorter sequence, so the ranks must be compared
  # explicitly. Otherwise a shape that is a prefix of the other (e.g. a scalar
  # proto vs. [5]) would wrongly compare equal.
  if len(tensor_shape_list) != len(shape):
    return False
  return all(x == y for x, y in zip(tensor_shape_list, shape))
676
677
def _ConstantValue(tensor, partial):
  """Statically evaluates `tensor` from the op that produces it, if possible.

  Handles a fixed set of op types by recursively calling `constant_value` on
  the op's inputs and evaluating the op with numpy. The supported types are
  Const, Shape, Size, Rank, Range, Cast, Concat, ConcatV2, Pack, Unpack,
  Split, Fill, Equal, NotEqual, StopGradient, CheckNumericsV2,
  DebugIdentityV2 and Identity.

  Args:
    tensor: The `ops.Tensor` to evaluate.
    partial: Forwarded to the recursive `constant_value` calls made for Pack,
      Unpack, Split (the value input), StopGradient and the identity-like ops.
      For Pack it additionally lets inputs that evaluate to None be kept as
      None entries instead of aborting.

  Returns:
    A numpy value, or None if the value cannot be determined statically.

  Raises:
    TypeError: If `tensor` is not an `ops.Tensor`.
  """
  # TODO(touts): Support Variables?
  if not isinstance(tensor, ops.Tensor):
    raise TypeError(f"{tensor!r} must be a Tensor, but got {type(tensor)}.")
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
  elif tensor.op.type == "Shape":
    # Known only when the input's static shape is fully defined.
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array(
          [dim.value for dim in input_shape.dims],
          dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Size":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      # NOTE(review): unlike the Shape branch this ignores `tensor.dtype` and
      # always yields int32 — confirm intended.
      return np.prod([dim.value for dim in input_shape.dims], dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Rank":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.ndims is not None:
      # A 0-d int32 array holding the rank.
      return np.ndarray(
          shape=(),
          buffer=np.array([input_shape.ndims], dtype=np.int32),
          dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Range":
    start = constant_value(tensor.op.inputs[0])
    if start is None:
      return None
    limit = constant_value(tensor.op.inputs[1])
    if limit is None:
      return None
    delta = constant_value(tensor.op.inputs[2])
    if delta is None:
      return None
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
  elif tensor.op.type == "Cast":
    pre_cast = constant_value(tensor.op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)
  elif tensor.op.type == "Concat":
    # Legacy Concat: the axis is the first input.
    dim = constant_value(tensor.op.inputs[0])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[1:]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "ConcatV2":
    # ConcatV2: the axis is the last input.
    dim = constant_value(tensor.op.inputs[-1])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[:-1]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "Pack":
    values = []
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not tensor.op.inputs:
      return None
    # We can't handle axis != 0 Packs at the moment.
    if tensor.op.get_attr("axis") != 0:
      return None
    for x in tensor.op.inputs:
      value = constant_value(x, partial)
      # With `partial`, unknown inputs stay in the result as None entries.
      if value is None and not partial:
        return None
      values.append(value)
    return np.array(values)
  elif tensor.op.type == "Unpack":
    # We can't handle axis != 0 Unpacks at the moment.
    if tensor.op.get_attr("axis") != 0:
      return None
    value = constant_value(tensor.op.inputs[0], partial)
    if value is None:
      return None
    # This tensor is output number `value_index` of the Unpack.
    return value[tensor.value_index]
  elif tensor.op.type == "Split":
    # Split: input 0 is the axis, input 1 the value being split.
    dim = constant_value(tensor.op.inputs[0])
    value = constant_value(tensor.op.inputs[1], partial)
    if value is None or dim is None:
      return None
    split = np.split(value, tensor.op.get_attr("num_split"), dim)
    return split[tensor.value_index]
  elif tensor.op.type == "Fill":
    # Uses the output tensor's static shape rather than evaluating the dims
    # input; only the fill value (input 1) is evaluated.
    fill_shape = tensor.shape
    fill_value = constant_value(tensor.op.inputs[1])
    if fill_shape.is_fully_defined() and fill_value is not None:
      return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    else:
      return None
  elif tensor.op.type == "Equal":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.equal(value1, value2)
  elif tensor.op.type == "NotEqual":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.not_equal(value1, value2)
  elif tensor.op.type == "StopGradient":
    return constant_value(tensor.op.inputs[0], partial)
  elif tensor.op.type in ("CheckNumericsV2", "DebugIdentityV2", "Identity"):
    # Pass-through ops: the value equals the first input's value.
    return constant_value(tensor.op.inputs[0], partial)
  else:
    return None
806
807
@tf_export("get_static_value")
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  Example usage:

  >>> a = tf.constant(10)
  >>> tf.get_static_value(a)
  10
  >>> b = tf.constant(20)
  >>> tf.get_static_value(tf.add(a, b))
  30

  >>> # `tf.Variable` is not supported.
  >>> c = tf.Variable(30)
  >>> print(tf.get_static_value(c))
  None

  Using `partial` option is most relevant when calling `get_static_value` inside
  a `tf.function`. Setting it to `True` will return the results but for the
  values that cannot be evaluated will be `None`. For example:

  ```python
  class Foo:
    def __init__(self):
      self.a = tf.Variable(1)
      self.b = tf.constant(2)

    @tf.function
    def bar(self, partial):
      packed = tf.raw_ops.Pack(values=[self.a, self.b])
      static_val = tf.get_static_value(packed, partial=partial)
      tf.print(static_val)

  f = Foo()
  f.bar(partial=True)  # `array([None, array(2, dtype=int32)], dtype=object)`
  f.bar(partial=False)  # `None`
  ```

  Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
  will no longer be possible to feed a different value for `tensor`. This allows
  the result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated. Values that are not tensor-like (as
      determined by `tf.is_tensor`) are accepted and returned unchanged.
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated. Eager tensors return the result of
    `tensor.numpy()`, or None if that is not implemented for them. Non
    tensor-like inputs are returned as-is. Tensor-like objects that are not
    `tf.Tensor`s (for example composite types such as `tf.SparseTensor`) yield
    None rather than raising an error.
  """
  if isinstance(tensor, ops.EagerTensor):
    try:
      return tensor.numpy()
    except errors_impl.UnimplementedError:
      # Some EagerTensors may not implement .numpy/resolve, e.g. parallel
      # tensors with multiple components on different devices.
      return None
  if not is_tensor(tensor):
    # Plain Python / numpy values are already concrete; pass them through.
    return tensor
  if not isinstance(tensor, ops.Tensor):
    # Other tensor-like types cannot be evaluated here; report "unknown".
    return None
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
    # The caller may now depend on the constant value of `tensor`, so we
    # conservatively prevent it from being fed.
    tensor.graph.prevent_feeding(tensor)
  return ret
884
885
def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
  """A version of `constant_value()` that returns a `TensorShape`.

  This version should be used when a constant tensor value is
  interpreted as a (possibly partial) shape, e.g. in the shape
  function for `tf.reshape()`. By explicitly requesting a
  `TensorShape` as the return value, it is possible to represent
  unknown dimensions; by contrast, `constant_value()` is
  all-or-nothing.

  Entries equal to -1 are interpreted as unknown dimensions. For symbolic
  tensors, the producing op (e.g. `Shape`, `Pack`, `Concat`/`ConcatV2`,
  `StridedSlice`, `Cast`) is inspected so that a partially known shape can be
  recovered even when the full value is not a constant.

  Args:
    tensor: The rank-0 or rank-1 Tensor to be evaluated. Eager and symbolic
      tensors are both accepted. A rank-0 tensor must be statically known to
      be -1 and then denotes a shape of unknown rank.

  Returns:
    A `TensorShape` based on the constant value of the given `tensor`.

  Raises:
    ValueError: If the shape is rank-0 and is not statically known to be -1.
  """
  # The rank-0 case is checked first, for eager and symbolic tensors alike.
  # Iterating over the value of an eager scalar would otherwise fail with a
  # TypeError instead of following the documented scalar contract below.
  # `constant_value` handles eager tensors directly.
  if tensor.get_shape().ndims == 0:
    value = constant_value(tensor)
    if value is None:
      raise ValueError(
          "Received a scalar with unknown value as shape; require a statically "
          "known scalar with value '-1' to describe an unknown shape.")
    if value != -1:
      raise ValueError(
          f"Received a scalar value '{value}' as shape; require a statically "
          "known scalar with value '-1' to describe an unknown shape.")
    return tensor_shape.unknown_shape()

  if isinstance(tensor, ops.EagerTensor):
    # Eager tensors carry a concrete value; only -1 maps to an unknown dim.
    return tensor_shape.TensorShape(
        [dim if dim != -1 else None for dim in tensor.numpy()])

  shape = tensor.get_shape().with_rank(1)
  if shape == [0]:
    return tensor_shape.TensorShape([])
  elif tensor.op.type == "Cast":
    pre_cast = constant_value_as_shape(tensor.op.inputs[0])
    if pre_cast.dims is None:
      # the input to cast has a totally undefined shape; just return that.
      return pre_cast
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    if cast_dtype not in (dtypes.int32, dtypes.int64):
      return tensor_shape.unknown_shape(shape.dims[0].value)
    # Re-encode unknown dims as -1, apply the cast in numpy so values wrap the
    # same way the op would, then map negative results back to unknown.
    dest_dtype_shape_array = np.array(
        [x if x is not None else -1 for x in pre_cast.as_list()]).astype(
            cast_dtype.as_numpy_dtype)
    return tensor_shape.TensorShape([
        x if x >= 0 else None
        for x in dest_dtype_shape_array])
  elif tensor.op.type == "Shape":
    # The value of `tf.shape(x)` is exactly the static shape of `x`.
    return tensor.op.inputs[0].get_shape()
  elif tensor.op.type == "Pack":
    ret = tensor_shape.TensorShape([])  # Empty list.
    # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it
    # would not be rank 1.
    assert tensor.op.get_attr("axis") == 0
    for pack_input in tensor.op.inputs:
      # `pack_input` must be a scalar. Attempt to evaluate it, and append it
      # to `ret`.
      pack_input_val = constant_value(pack_input)
      if pack_input_val is None or pack_input_val < 0:
        new_dim = tensor_shape.Dimension(None)
      else:
        new_dim = tensor_shape.Dimension(pack_input_val)
      ret = ret.concatenate([new_dim])
    return ret
  elif tensor.op.type == "Concat":
    # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.TensorShape([])  # Empty list.
    for concat_input in tensor.op.inputs[1:]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "ConcatV2":
    # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.TensorShape([])  # Empty list.
    for concat_input in tensor.op.inputs[:-1]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "StridedSlice":
    try:
      begin = constant_value(tensor.op.inputs[1])
      end = constant_value(tensor.op.inputs[2])
      strides = constant_value(tensor.op.inputs[3])
      if begin is not None and end is not None and strides is not None:
        begin = begin[0]
        end = end[0]
        strides = strides[0]
        # A mask bit of 1 on the only (first) axis means "use the full range"
        # on that side, which maps to an open Python slice bound.
        begin_mask = tensor.op.get_attr("begin_mask")
        if begin_mask == 1:
          begin = None
        end_mask = tensor.op.get_attr("end_mask")
        if end_mask == 1:
          end = None

        ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
        new_axis_mask = tensor.op.get_attr("new_axis_mask")
        shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
        valid_attributes = (not ellipsis_mask and not new_axis_mask and
                            not shrink_axis_mask and (not begin_mask or
                                                      (begin_mask == 1)) and
                            (not end_mask or (end_mask == 1)))
        if valid_attributes:  # additional inputs not supported
          prev = constant_value_as_shape(tensor.op.inputs[0])
          prev = prev[begin:end:strides]
          ret = tensor_shape.TensorShape(prev)
          return ret

    except (ValueError, TypeError):
      # ValueError can come from get_attr or slicing `prev`; TypeError can come
      # from slicing `prev`. Either way, fall back to the generic path below.
      pass
  elif (tensor.op.type == "Placeholder" and
        tensor.op.graph.building_function and
        hasattr(tensor.op.graph, "internal_captures")):
    # If we are inside a FuncGraph try to lookup the constant value of the
    # corresponding external capture. Note that we only look at captures and
    # not the fed inputs because those can be fed different values in different
    # instantiations of the function call or different iterations of a
    # tf.while_loop.
    for i, capture in enumerate(tensor.op.graph.internal_captures):
      if capture is tensor:
        external_capture = tensor.op.graph.external_captures[i]
        return constant_value_as_shape(external_capture)

  # Generic fallback: a vector of statically known length but unknown entries,
  # refined by the constant value when one is available.
  ret = tensor_shape.unknown_shape(shape.dims[0].value)
  value = constant_value(tensor)
  if value is not None:
    ret = ret.merge_with(
        tensor_shape.TensorShape([d if d >= 0 else None for d in value]))
  return ret
1027
1028
# TODO(mdan): Deprecate in favor of more static-friendly types.
@tf_export("is_tensor")
def is_tf_type(x):  # pylint: disable=invalid-name
  """Checks whether `x` is a TF-native type that can be passed to many TF ops.

  Use `is_tensor` to differentiate types that can be ingested by TensorFlow ops
  without any conversion (e.g., `tf.Tensor`, `tf.SparseTensor`, and
  `tf.RaggedTensor`) from types that need to be converted into tensors before
  they are ingested (e.g., numpy `ndarray` and Python scalars).

  For example, in the following code block:

  ```python
  if not tf.is_tensor(t):
    t = tf.convert_to_tensor(t)
  return t.shape, t.dtype
  ```

  we check to make sure that `t` is a tensor (and convert it if not) before
  accessing its `shape` and `dtype`.  (But note that not all TensorFlow native
  types have shapes or dtypes; `tf.data.Dataset` is an example of a TensorFlow
  native type that has neither shape nor dtype.)

  Args:
    x: A python object to check.

  Returns:
    `True` if `x` is an instance of a TensorFlow-native type. Otherwise, the
    value of `x.is_tensor_like` when that attribute exists, else `False`.
  """
  native_types = (internal.NativeObject, core.Tensor)
  if isinstance(x, native_types):
    return True
  # Objects can opt in by exposing an `is_tensor_like` attribute; its value is
  # returned as-is.
  return getattr(x, "is_tensor_like", False)


# Deprecated alias for tensor_util.is_tf_type.
is_tensor = is_tf_type
1065
1066
def shape_tensor(shape):  # pylint: disable=invalid-name
  """Converts `shape` into a tensor named "shape".

  An empty tuple or list is converted with dtype int32 (its dtype could not be
  inferred otherwise). For non-empty tuples/lists, any `Dimension` entries are
  first unwrapped to plain values. In all other cases the dtype is left for
  `ops.convert_to_tensor` to infer (normally int32 or int64 for integer
  shapes).

  Args:
    shape: A tuple/list of dimensions, or any value accepted by
      `ops.convert_to_tensor`.

  Returns:
    The converted shape tensor.
  """
  if isinstance(shape, (tuple, list)):
    if not shape:
      return ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")
    # Mixing v1 and v2 TensorShape objects during partial conversions can yield
    # shapes such as (1, 2, Dimension(5)); such mixed content is not
    # convertible to a Tensor, so every Dimension is unwrapped first.
    shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
  return ops.convert_to_tensor(shape, dtype=None, name="shape")
1080
1081
# DO NOT USE: For testing only.
# Module-level switch read by `maybe_set_static_shape`; when False that
# function does nothing.
_ENABLE_MAYBE_SET_STATIC_SHAPE = True


def maybe_set_static_shape(tensor, shape):  # pylint: disable=invalid-name
  """Sets the shape of `tensor` to the `shape`'s constant value, if inferrable.

  This is a temporary workaround to fix shape inference across functional op
  boundaries. E.g.

  ```python
  shape = tf.constant([3])
  @tf.function
  def f():
    u = tf.random.uniform(shape)
    return u
  ```

  Relying only on C++ shape inference, the shape of `u` inside `f` would be
  unknown. C++ shape inference does not see the outer graph: when it
  backtracks the captured tensor for `shape`, all it finds is a Placeholder
  node. `maybe_set_static_shape` computes the static value of `shape` by
  traversing the `FuncGraph` boundaries, then sets that shape on `tensor`.

  A longer term solution would be to fix C++ shape inference.

  This only acts while building a function graph (not eagerly), when
  `tensor`'s shape is not already fully defined, and when `shape` is a tensor.

  Args:
    tensor: A tensor.
    shape: A shape tensor.
  """
  if not _ENABLE_MAYBE_SET_STATIC_SHAPE:
    return
  if context.executing_eagerly():
    return
  if not ops.get_default_graph().building_function:
    return
  if tensor.shape.is_fully_defined() or not is_tensor(shape):
    return
  static_shape = constant_value_as_shape(shape_tensor(shape))
  tensor.set_shape(static_shape)
1119
1120
def try_evaluate_constant(tensor):  # pylint: disable=invalid-name
  """Evaluates a symbolic tensor as a constant.

  The work is delegated to the C API's `TF_TryEvaluateConstant_wrapper`,
  called on the C graph that owns `tensor`.

  Args:
    tensor: a symbolic Tensor.

  Returns:
    ndarray if the evaluation succeeds, or None if it fails.
  """
  # pylint: disable=protected-access
  with tensor.graph._c_graph.get() as c_graph:
    tf_output = tensor._as_tf_output()
    evaluated = c_api.TF_TryEvaluateConstant_wrapper(c_graph, tf_output)
  # pylint: enable=protected-access
  return evaluated
1134