xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/utils/tf_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow-related utilities."""

import collections
import copy
import numpy as np

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


def is_tensor_or_tensor_list(v):
  """Returns True if `v` is a Tensor or a list whose first element is one."""
  v = nest.flatten(v)
  if v and isinstance(v[0], ops.Tensor):
    return True
  else:
    return False


def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Stops early if all targets have been found (`targets` is optional).

  Only valid in symbolic (graph) mode, not eager mode.

  Args:
    inputs: List of tensors.
    targets: List of tensors.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs themselves).
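
  Example (an illustrative sketch, assuming graph mode via the public
  `tf.compat.v1` API; this utility is not meant for eager tensors):

  ```python
  with tf.Graph().as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=(None, 4))
    y = tf.nn.relu(x)
    z = tf.reduce_sum(y)
    reachable = get_reachable_from_inputs([x], targets=[z])
    # `reachable` now contains `x`, the Relu/Sum ops, and their output tensors.
  ```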
  """
  inputs = nest.flatten(inputs, expand_composites=True)
  reachable = object_identity.ObjectIdentitySet(inputs)
  if targets:
    remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
  queue = collections.deque(inputs)

  while queue:
    x = queue.pop()
    if isinstance(x, tuple(_user_convertible_tensor_types)):
      # Can't find consumers of user-specific types.
      continue

    if isinstance(x, ops.Operation):
      outputs = x.outputs[:] or []
      outputs += x._control_outputs  # pylint: disable=protected-access
    elif isinstance(x, variables.Variable):
      try:
        outputs = [x.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        outputs = []
    elif tensor_util.is_tf_type(x):
      outputs = x.consumers()
    else:
      raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

    for y in outputs:
      if y not in reachable:
        reachable.add(y)
        if targets:
          remaining_targets.discard(y)
        queue.appendleft(y)

    if targets and not remaining_targets:
      return reachable

  return reachable


# This function needs access to private functions of `nest`.
#  pylint: disable=protected-access
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Maps the atomic elements of a nested structure.

  Args:
    is_atomic_fn: A function that determines if an element of `nested` is
      atomic.
    map_fn: The function to apply to atomic elements of `nested`.
    nested: A nested structure.

  Returns:
    The nested structure, with atomic elements mapped according to `map_fn`.

  Raises:
    ValueError: If an element that is neither atomic nor a sequence is
      encountered.
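
  Example (an illustrative sketch; the lambdas below are arbitrary choices, not
  part of any API):

  ```python
  # Treat ints as atomic and double them; dicts and lists are recursed into.
  map_structure_with_atomic(
      lambda x: isinstance(x, int),
      lambda x: 2 * x,
      {'a': [1, 2], 'b': 3})
  # -> {'a': [2, 4], 'b': 6}
  ```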
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  # Recursively convert.
  if not nest.is_nested(nested):
    raise ValueError(
        'Received non-atomic and non-sequence element: {}'.format(nested))
  if nest.is_mapping(nested):
    values = [nested[k] for k in sorted(nested.keys())]
  elif nest.is_attrs(nested):
    values = _astuple(nested)
  else:
    values = nested
  mapped_values = [
      map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
  ]
  return nest._sequence_like(nested, mapped_values)


def get_shapes(tensors):
  """Gets shapes from tensors."""
  return nest.map_structure(lambda x: x.shape, tensors)


#  pylint: enable=protected-access


def convert_shapes(input_shape, to_tuples=True):
  """Converts nested shape representations to desired format.

  Performs:

  TensorShapes -> tuples if `to_tuples=True`.
  tuples of int or None -> TensorShapes if `to_tuples=False`.

  Valid objects to be converted are:
  - TensorShapes
  - tuples with elements of type int or None.
  - ints
  - None

  Args:
    input_shape: A nested structure of objects to be converted to TensorShapes.
    to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
      all tuples representing shapes to TensorShapes.

  Returns:
    Nested structure of shapes in desired format.

  Raises:
    ValueError: when the input tensor shape can't be converted to tuples, e.g.
      an unknown tensor shape.
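
  Example (an illustrative sketch, assuming the public `tf` alias for
  `TensorShape`):

  ```python
  convert_shapes((None, 10), to_tuples=False)   # -> TensorShape([None, 10])
  convert_shapes(tf.TensorShape([None, 10]))    # -> (None, 10)
  convert_shapes([(3,), (None, 5)], to_tuples=False)
  # -> [TensorShape([3]), TensorShape([None, 5])]
  ```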
  """

  def _is_shape_component(value):
    return value is None or isinstance(value, (int, tensor_shape.Dimension))

  def _is_atomic_shape(input_shape):
    # Ex: TensorShape or (None, 10, 32) or 5 or `None`
    if _is_shape_component(input_shape):
      return True
    if isinstance(input_shape, tensor_shape.TensorShape):
      return True
    if (isinstance(input_shape, (tuple, list)) and
        all(_is_shape_component(ele) for ele in input_shape)):
      return True
    return False

  def _convert_shape(input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if to_tuples:
      input_shape = tuple(input_shape.as_list())
    return input_shape

  return map_structure_with_atomic(_is_atomic_shape, _convert_shape,
                                   input_shape)


class ListWrapper(object):
  """A wrapper for lists to be treated as elements for `nest`."""

  def __init__(self, list_to_wrap):
    self._list = list_to_wrap

  def as_list(self):
    return self._list


def convert_inner_node_data(nested, wrap=False):
  """Either wraps or unwraps innermost node data lists in `ListWrapper` objects.

  Args:
    nested: A nested data structure.
    wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
      unwraps `ListWrapper` objects into lists.

  Returns:
    Structure of same type as nested, with lists wrapped/unwrapped.
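
  Example (an illustrative sketch; the node data below is made up):

  ```python
  wrapped = convert_inner_node_data([['a', 0, 0], ['b', 0, 0]], wrap=True)
  # Each inner `[layer_name, node_id, tensor_id]` list is now a `ListWrapper`.
  convert_inner_node_data(wrapped, wrap=False)
  # -> [['a', 0, 0], ['b', 0, 0]]
  ```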
  """

  def _is_serialized_node_data(nested):
    # Node data can be of form `[layer_name, node_id, tensor_id]` or
    # `[layer_name, node_id, tensor_id, kwargs]`.
    if (isinstance(nested, list) and (len(nested) in [3, 4]) and
        isinstance(nested[0], str)):
      return True
    return False

  def _is_atomic_nested(nested):
    """Returns `True` if `nested` is a list representing node data."""
    if isinstance(nested, ListWrapper):
      return True
    if _is_serialized_node_data(nested):
      return True
    return not nest.is_nested(nested)

  def _convert_object_or_list(nested):
    """Converts between `ListWrapper` objects and list representations."""
    if wrap:
      if isinstance(nested, ListWrapper):
        return nested
      if _is_serialized_node_data(nested):
        return ListWrapper(nested)
      return nested
    else:
      if isinstance(nested, ListWrapper):
        return nested.as_list()
      return nested

  return map_structure_with_atomic(_is_atomic_nested, _convert_object_or_list,
                                   nested)


def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`.

  Args:
    fn: function to wrap.

  Returns:
    Wrapped function.
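
  Example (an illustrative sketch; `MyLayer` is a hypothetical layer):

  ```python
  class MyLayer(tf.keras.layers.Layer):

    @shape_type_conversion
    def compute_output_shape(self, input_shape):
      # `input_shape` arrives as a tuple such as (None, 16); the tuple that is
      # returned is converted back to a `TensorShape` by the decorator.
      return (input_shape[0], 8)
  ```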
  """

  def wrapper(instance, input_shape):
    # Pass shapes as tuples to `fn`
    # This preserves compatibility with external Keras.
    if input_shape is not None:
      input_shape = convert_shapes(input_shape, to_tuples=True)
    output_shape = fn(instance, input_shape)
    # Return shapes from `fn` as TensorShapes.
    if output_shape is not None:
      output_shape = convert_shapes(output_shape, to_tuples=False)
    return output_shape

  return wrapper


def are_all_symbolic_tensors(tensors):
  return all(map(is_symbolic_tensor, tensors))


_user_convertible_tensor_types = set()


def is_extension_type(tensor):
  """Returns whether a tensor is of an ExtensionType.

  See github.com/tensorflow/community/pull/269.
  Currently it works by checking if `tensor` is a `CompositeTensor` instance,
  but this will be changed to use an appropriate ExtensionType protocol
  check once ExtensionType is made public.

  Args:
    tensor: An object to test.

  Returns:
    True if the tensor is an extension type object, False if not.
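
  Example (an illustrative sketch, assuming the public `tf` API):

  ```python
  is_extension_type(tf.constant([1, 2]))                # False
  is_extension_type(tf.ragged.constant([[1], [2, 3]]))  # True
  is_extension_type(tf.sparse.from_dense([[1, 0]]))     # True
  ```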
  """
  return isinstance(tensor, composite_tensor.CompositeTensor)


def is_symbolic_tensor(tensor):
  """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.

  A Variable can be seen as either symbolic or eager: it is considered symbolic
  when we are in a graph scope, and eager when we are in an eager scope.

  Args:
    tensor: A tensor instance to test.

  Returns:
    True for symbolic tensors, False for eager tensors.
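
  Example (an illustrative sketch; assumes TF2 behavior, where eager execution
  is enabled by default):

  ```python
  is_symbolic_tensor(tf.constant(1.0))    # False: an eager tensor.
  with tf.Graph().as_default():
    is_symbolic_tensor(tf.constant(1.0))  # True: the tensor belongs to a graph.
  ```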
  """
  if isinstance(tensor, ops.Tensor):
    return hasattr(tensor, 'graph')
  elif is_extension_type(tensor):
    component_tensors = nest.flatten(tensor, expand_composites=True)
    return any(hasattr(t, 'graph') for t in component_tensors)
  elif isinstance(tensor, variables.Variable):
    # Variables that are outputs of a Keras Layer in Functional API mode
    # should be considered symbolic.
    # TODO(omalleyt): We need a better way to check this in order to
    # enable `run_eagerly=True` for Models containing Layers that
    # return Variables as outputs.
    return (getattr(tensor, '_keras_history', False) or
            not context.executing_eagerly())
  elif isinstance(tensor, tuple(_user_convertible_tensor_types)):
    tensor = ops.convert_to_tensor_or_composite(tensor)
    return is_symbolic_tensor(tensor)
  else:
    return False


@keras_export('keras.__internal__.utils.register_symbolic_tensor_type', v1=[])
def register_symbolic_tensor_type(cls):
  """Allows users to specify types regarded as symbolic `Tensor`s.

  Used in conjunction with `tf.register_tensor_conversion_function`, calling
  `tf.keras.__internal__.utils.register_symbolic_tensor_type(cls)`
  allows non-`Tensor` objects to be plumbed through Keras layers.

  Example:

  ```python
  # One-time setup.
  class Foo(object):
    def __init__(self, input_):
      self._input = input_
    def value(self):
      return tf.constant(42.)

  tf.register_tensor_conversion_function(
      Foo, lambda x, *args, **kwargs: x.value())

  tf.keras.__internal__.utils.register_symbolic_tensor_type(Foo)

  # User-land.
  layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
  ```

  Args:
    cls: A `class` type which shall be regarded as a symbolic `Tensor`.
  """
  global _user_convertible_tensor_types
  if cls not in _user_convertible_tensor_types:
    keras_tensor.register_keras_tensor_specialization(
        cls, keras_tensor.UserRegisteredTypeKerasTensor)
  _user_convertible_tensor_types.add(cls)


def type_spec_from_value(value):
  """Grab type_spec without converting array-likes to tensors."""
  if is_extension_type(value):
    return value._type_spec  # pylint: disable=protected-access
  # Get a TensorSpec for array-like data without
  # converting the data to a Tensor
  if hasattr(value, 'shape') and hasattr(value, 'dtype'):
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  else:
    return type_spec.type_spec_from_value(value)


def is_ragged(tensor):
  """Returns true if `tensor` is a ragged tensor or ragged tensor value."""
  return isinstance(
      tensor,
      (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue))


def is_sparse(tensor):
  """Returns true if `tensor` is a sparse tensor or sparse tensor value."""
  return isinstance(
      tensor,
      (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue))


def is_tensor_or_variable(x):
  return tensor_util.is_tf_type(x) or isinstance(x, variables.Variable)


def assert_no_legacy_layers(layers):
  """Prevent tf.layers.Layers from being used with Keras.

  Certain legacy layers inherit from their keras analogs; however, they are
  not supported with keras and can lead to subtle and hard-to-diagnose bugs.

  Args:
    layers: A list of layers to check.

  Raises:
    TypeError: If any elements of `layers` are tf.layers.Layers.
  """

  # isinstance check for tf.layers.Layer introduces a circular dependency.
  legacy_layers = [l for l in layers if getattr(l, '_is_legacy_layer', None)]
  if legacy_layers:
    layer_str = '\n'.join('  ' + str(l) for l in legacy_layers)
    raise TypeError(
        'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
        'framework (for instance using the Network, Model, or Sequential '
        'classes), please use the tf.keras.layers implementation instead. '
        '(Or, if writing custom layers, subclass from tf.keras.layers rather '
        'than tf.layers)'.format(layer_str))


@tf_contextlib.contextmanager
def maybe_init_scope(layer):
  """Open an `init_scope` if in V2 mode and using the keras graph.

  Args:
    layer: The Layer/Model that is currently active.

  Yields:
    None
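
  Example (an illustrative sketch of typical usage inside a Layer's `build`;
  `MyDense` below is hypothetical):

  ```python
  class MyDense(tf.keras.layers.Layer):

    def build(self, input_shape):
      # Variables created here may be lifted out of a `FuncGraph`.
      with maybe_init_scope(self):
        self.kernel = self.add_weight('kernel', shape=(input_shape[-1], 4))
  ```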
  """
  # Don't open an init_scope in V1 mode or when using legacy tf.layers.
  if (ops.executing_eagerly_outside_functions() and
      getattr(layer, '_keras_style', True)):
    with ops.init_scope():
      yield
  else:
    yield


@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
  """Returns graph context manager if any of the inputs is a symbolic tensor."""
  if any(is_symbolic_tensor(v) for v in list(args) + list(kwargs.values())):
    with K.get_graph().as_default():
      yield
  else:
    yield


def dataset_is_infinite(dataset):
  """True if the passed dataset is infinite."""
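  # Illustrative usage (a sketch, assuming the public `tf.data` API):
  #   dataset_is_infinite(tf.data.Dataset.range(3).repeat())  # infinite
  #   dataset_is_infinite(tf.data.Dataset.range(3))           # finite
  # In V2 (eager outside functions) this returns a boolean `Tensor`; in V1
  # graph mode the cardinality is first evaluated with the backend session.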
  if ops.executing_eagerly_outside_functions():
    return math_ops.equal(
        cardinality.cardinality(dataset), cardinality.INFINITE)
  else:
    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
    return dataset_size == cardinality.INFINITE


def get_tensor_spec(t, dynamic_batch=False, name=None):
  """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
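  # Illustrative usage (a sketch; the exact `TensorSpec` repr may differ):
  #   get_tensor_spec(tf.constant([[1.0, 2.0]]))
  #     -> TensorSpec(shape=(1, 2), dtype=tf.float32)
  #   get_tensor_spec(tf.constant([[1.0, 2.0]]), dynamic_batch=True)
  #     -> TensorSpec(shape=(None, 2), dtype=tf.float32)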
  # pylint: disable=protected-access
  if isinstance(t, type_spec.TypeSpec):
    spec = t
  elif is_extension_type(t):
    # TODO(b/148821952): Should these specs have a name attr?
    spec = t._type_spec
  elif (hasattr(t, '_keras_history') and
        hasattr(t._keras_history[0], '_type_spec')):
    return t._keras_history[0]._type_spec
  elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
    spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
  else:
    return None  # Allow non-Tensors to pass through.

  if not dynamic_batch:
    return spec

  dynamic_batch_spec = copy.deepcopy(spec)
  # RaggedTensorSpec only has a private _shape.
  shape = dynamic_batch_spec._shape
  if shape.rank is not None and shape.rank > 0:
    shape_list = shape.as_list()
    shape_list[0] = None
    dynamic_batch_spec._shape = tensor_shape.TensorShape(shape_list)
  return dynamic_batch_spec
  # pylint: enable=protected-access


def sync_to_numpy_or_python_type(tensors):
  """Syncs and converts a structure of `Tensor`s to `NumPy` arrays or Python scalar types.

  For each tensor, it calls `tensor.numpy()`. If the result is a scalar value,
  it converts it to a Python type, such as a float or int, by calling
  `result.item()`.

  Numpy scalars are converted, as Python types are often more convenient to deal
  with. This is especially useful for bfloat16 Numpy scalars, which don't
  support as many operations as other Numpy values.

  Async strategies (such as `TPUStrategy` and `ParameterServerStrategy`) are
  forced to sync during this process.

  Args:
    tensors: A structure of tensors.

  Returns:
    `tensors`, but scalar tensors are converted to Python types and non-scalar
    tensors are converted to Numpy arrays.
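
  Example (an illustrative sketch, assuming eager tensors):

  ```python
  sync_to_numpy_or_python_type(
      {'loss': tf.constant(0.25), 'preds': tf.constant([[1.0, 0.0]])})
  # -> {'loss': 0.25, 'preds': array([[1., 0.]], dtype=float32)}
  ```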
  """
  if isinstance(tensors, coordinator_lib.RemoteValue):
    return tensors.fetch()

  def _to_single_numpy_or_python_type(t):
    if isinstance(t, ops.Tensor):
      x = t.numpy()
      return x.item() if np.ndim(x) == 0 else x
    return t  # Don't turn ragged or sparse tensors to NumPy.

  return nest.map_structure(_to_single_numpy_or_python_type, tensors)


def _astuple(attrs):
  """Converts the given attrs to tuple non-recursively."""
  cls = type(attrs)
  fields = getattr(cls, '__attrs_attrs__', None)
  if fields is None:
    raise ValueError('%r is not an attrs-decorated class.' % cls)
  values = []
  for field in fields:
    values.append(getattr(attrs, field.name))
  return tuple(values)