# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""

import functools
import threading

from tensorflow.python import tf2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.trackable import base as tracking
from tensorflow.python.util import keras_deps
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

_call_context = threading.local()


def create_mean_metric(value, name=None):
  # Importing keras would import base_layer and then this module, and metrics
  # rely on base_layer, which results in a cyclic dependency; import locally.
  from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
  return metric_obj, metric_obj(value)


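# Illustrative usage sketch (editor's addition, not part of the original
# module): how `create_mean_metric` might be used to track a running mean of a
# scalar value. Assumes an eager TF2 context; the value and name below are
# hypothetical, and the helper is never called at import time.
def _example_create_mean_metric():
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top
  metric_obj, result = create_mean_metric(
      constant_op.constant(2.0), name='example_mean')
  # `metric_obj` holds the accumulated state; `result` is the updated mean.
  return metric_obj, result
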
def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Temporary util to create a variable (relies on `variable_scope.variable`).

  Some reuse-related technicalities prevent us from using
  `variable_scope.get_variable()` directly, so we use a subcomponent
  that has fewer constraints (`variable_scope.variable()`).

  In the longer term, it seems like a similar "default variable creator" method
  should exist in `Trackable` instead. When this happens, we can get
  rid of this temporary solution.

  TODO(fchollet): remove this method when no longer needed.

  Args:
    name: Variable name.
    shape: Variable shape.
    dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
    initializer: Initializer instance (callable).
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. variables, biases)
      or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
      Note, if the current variable scope is marked as non-trainable
      then this parameter is ignored and any added variables are also
      marked as non-trainable. `trainable` defaults to `True` unless
      `synchronization` is set to `ON_READ`.
    caching_device: Passed to `tf.Variable`.
    validate_shape: Passed to `tf.Variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  initializing_from_value = False
  if initializer is not None and not callable(initializer):
    initializing_from_value = True

  if initializing_from_value:
    init_val = initializer
    variable_dtype = None
  else:
    # Instantiate initializer if provided initializer is a type object.
    if tf_inspect.isclass(initializer):
      initializer = initializer()
    init_val = functools.partial(initializer, shape, dtype=dtype)
    variable_dtype = dtype.base_dtype
  if use_resource is None:
    use_resource = True

  # TODO(apassos,rohanj) figure out how to remove collections from here so we
  # can remove the V1.
  variable_shape = tensor_shape.TensorShape(shape)
  return tf_variables.VariableV1(
      initial_value=init_val,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=use_resource,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=variable_shape if variable_shape else None)


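# Illustrative usage sketch (editor's addition, not part of the original
# module): creating a small float32 weight with `make_variable` and an
# initializer, similar to what `Layer.add_weight` does internally. The names
# below are hypothetical, and the helper is never called at import time.
def _example_make_variable():
  from tensorflow.python.ops import init_ops_v2  # pylint: disable=g-import-not-at-top
  return make_variable(
      'example_kernel',
      shape=(2, 2),
      dtype=dtypes.float32,
      # An initializer class is also accepted; `make_variable` instantiates it.
      initializer=init_ops_v2.GlorotUniform(),
      trainable=True)
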
def collect_previous_mask(input_tensors):
  """Retrieves the output mask(s) of the previous node.

  Args:
      input_tensors: An arbitrary structure of Tensors.

  Returns:
      A mask tensor or list of mask tensors.
  """

  def _collect_previous_mask(x):
    return getattr(x, '_keras_mask', None)

  return nest.map_structure(_collect_previous_mask, input_tensors)


def have_all_keras_metadata(tensors):
  return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))


def generate_placeholders_from_shape(shape):
  return array_ops.placeholder(shape=shape, dtype=backend.floatx())


def create_keras_history(tensors):
  """Wraps TensorFlow Operations for compatibility with the Functional API.

  This method checks to see if a Tensor in `tensors` is missing Keras metadata
  and has its origin in a Keras `Input` Layer. If so, this method will replace
  the raw TensorFlow Operations that created this tensor with
  `TensorFlowOpLayer` instances that create identical operations.

  Any Tensors not originating from a Keras `Input` Layer will be treated as
  constants when constructing `TensorFlowOpLayer` instances.

  Args:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw TensorFlow operations.
  """
  _, created_layers = _create_keras_history_helper(tensors, set(), [])
  return created_layers


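# Illustrative usage sketch (editor's addition, not part of the original
# module): wrapping a raw TF op applied to a Keras `Input` so its output can be
# used with the Functional API. Assumes a v1-style graph (eager disabled), as
# `create_keras_history` requires; the names below are hypothetical, and the
# helper is never called at import time.
def _example_create_keras_history():
  from tensorflow.python.keras.engine import input_layer  # pylint: disable=g-import-not-at-top
  from tensorflow.python.ops import math_ops  # pylint: disable=g-import-not-at-top
  inp = input_layer.Input(shape=(4,))
  raw = math_ops.square(inp)  # plain TF op output, no `_keras_history` yet
  created_layers = create_keras_history(raw)  # wraps the op in an op layer
  return raw, created_layers
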
# Unsafe internal attribute.
# If True, Keras will not evaluate the constant-foldable inputs to TF op
# layers in TF1 graphs. This *might* speed up model construction time in
# certain settings, but it means the models will not be serializable or
# deserializable via get_config (only via SavedModel). It may also change the
# semantics of whether random numbers are generated once and re-used, or
# recomputed each time.
# Note: this path triggers for TPUEstimators / XLA-compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False


def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Args:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped in
      `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.
  """
  if ops.executing_eagerly_outside_functions():
    raise ValueError(
        '`create_keras_history` should only be called if eager is disabled!')
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  sparse_ops = []
  ragged_tensors = []
  for tensor in tensor_list:
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    if isinstance(
        tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      sparse_ops.append(tensor.op)
      continue
    if tf_utils.is_ragged(tensor):
      # Ragged tensors don't have an op property
      ragged_tensors.append(tensor)
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              distribution_strategy_context.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              ops.get_default_graph())
          if ds_with_session or using_xla or _UNSAFE_GRAPH_OP_LAYER_CREATION:
            # In legacy graph mode, evaluating here would configure the
            # Session improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
            constants[i] = op_input
          else:
            with ops.init_scope():
              constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
          args=(layer_inputs,),
          kwargs={},
          outputs=op.outputs)
      processed_ops.update([op])
  if sparse_ops or ragged_tensors:
    lambda_example = """
    weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
    output = tf.keras.layers.Lambda(weights_mult)(input)
    """
    raise ValueError(
        'TensorFlow ops that generate ragged or sparse tensor '
        'outputs are currently not supported by Keras automatic '
        'op wrapping. Please wrap these ops in a Lambda layer: '
        '\n\n```\n{example}\n```\n'
        'Sparse ops encountered: {sparse_ops}\n'
        'Ragged tensors encountered: {ragged_tensors}\n'.format(
            example=lambda_example,
            sparse_ops=str(sparse_ops),
            ragged_tensors=str(ragged_tensors)))
  return processed_ops, created_layers


def unnest_if_single_tensor(input_tensors):
  # Preserve compatibility with older configs
  flat_input_tensors = nest.flatten(input_tensors)
  # If this is a single element but not a dict, unwrap. If this is a dict,
  # assume the first layer expects a dict (as is the case with a
  # DenseFeatures layer); pass through.
  if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
    input_tensors = flat_input_tensors[0]
  return input_tensors


def needs_keras_history(tensors, ignore_call_context=False):
  """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

  This will never return True inside a sublayer, because sublayers
  do not need to create Keras History. Otherwise, this returns True
  if one or more of `tensors` originates from a `keras.Input` and
  does not have `_keras_history` set.

  Args:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to ignore the check of whether we are
      currently outside of a `call` context. This is `True` when creating
      KerasHistory inside `Node`, where we always know that Tensors
      are being used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  input_tensors = nest.flatten(tensors)
  if call_context().in_call and not ignore_call_context:
    return False
  if all(
      getattr(tensor, '_keras_history', None) is not None
      for tensor in input_tensors):
    # KerasHistory already set.
    return False
  return uses_keras_history(tensors)


def is_in_keras_graph():
  """Returns whether currently executing inside of a Keras graph."""
  return call_context().in_keras_graph


def is_in_eager_or_tf_function():
  """Returns whether in eager mode or inside of a tf.function."""
  return context.executing_eagerly() or is_in_tf_function()


def is_in_tf_function():
  """Returns whether inside of a tf.function."""
  # Check if running in V1 graph mode.
  if not ops.executing_eagerly_outside_functions():
    return False
  if not ops.inside_function():
    return False
  # Check if inside Keras FuncGraph.
  if is_in_keras_graph():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  graph = ops.get_default_graph()
  if (getattr(graph, 'name', False) and
      graph.name.startswith('wrapped_function')):
    return False
  return True


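# Illustrative usage sketch (editor's addition, not part of the original
# module): `is_in_tf_function` is True only while tracing a user `tf.function`,
# not when executing eagerly and not inside Keras' own FuncGraph. Assumes TF2
# defaults; the helper below is hypothetical and never called at import time.
def _example_is_in_tf_function():
  from tensorflow.python.eager import def_function  # pylint: disable=g-import-not-at-top

  @def_function.function
  def traced():
    return is_in_tf_function()  # True while tracing the FuncGraph

  eager_value = is_in_tf_function()  # False when executing eagerly
  traced_value = traced()            # True (as a constant) inside the trace
  return eager_value, traced_value
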
def uses_keras_history(tensors):
  """Check if at least one Tensor originates from a `keras.Input`.

  This is `True` if at least one Tensor has its origin in a `keras.Input`.
  Any Tensor that originates from a `keras.Input` will have a dependency
  Tensor with a `_keras_history` attribute attached. Tensors that have
  already been checked to not originate from a `keras.Input`
  are marked as `_keras_history_checked`.

  Args:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  checked_tensors = set()
  tensors_to_check = nest.flatten(tensors)

  while tensors_to_check:
    new_tensors_to_check = []
    for tensor in tensors_to_check:
      if id(tensor) in checked_tensors:
        continue

      checked_tensors.add(id(tensor))

      if getattr(tensor, '_keras_history_checked', None) is not None:
        continue
      if getattr(tensor, '_keras_history', None) is not None:
        return True

      try:
        new_tensors_to_check.extend(tensor.op.inputs)
      except AttributeError:
        # In case `tensor` is a Variable created in an Eager context.
        pass

    tensors_to_check = new_tensors_to_check

  # Mark that these Tensors have been checked once for `_keras_history`,
  # and should not be checked again for performance reasons.
  mark_checked(tensors)
  return False


def mark_checked(tensors):
  """Marks that these Tensors should not be tracked.

  This prevents Layers from attempting to create TensorFlowOpLayers
  for these Tensors.

  Args:
    tensors: An arbitrary structure of Tensors.
  """

  def _mark_checked(tensor):
    tensor._keras_history_checked = True  # pylint: disable=protected-access

  nest.map_structure(_mark_checked, tensors)


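# Illustrative usage sketch (editor's addition, not part of the original
# module): `uses_keras_history` walks a tensor's op inputs looking for
# `_keras_history`, and `mark_checked` caches a negative result so the walk is
# not repeated. Assumes a v1-style graph (placeholders require eager to be
# disabled); the names below are hypothetical.
def _example_uses_keras_history():
  from tensorflow.python.ops import math_ops  # pylint: disable=g-import-not-at-top
  x = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 3))
  y = math_ops.square(x)
  first = uses_keras_history(y)   # False: no `keras.Input` in y's history
  second = uses_keras_history(y)  # False again, but cheaper: y is now marked
  return first, second
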
def call_context():
  """Returns currently active `CallContext`."""
  call_ctx = getattr(_call_context, 'call_context', None)
  if call_ctx is None:
    call_ctx = CallContext()
    _call_context.call_context = call_ctx
  return call_ctx


# Inject the call_context function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_call_context_function(call_context)


class CallContext(object):
  """Keeps track of properties currently inside a Layer/Model's `call`.

  Attributes:
    in_call: Whether currently inside the `call` of a Layer.
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.
    frozen: Whether currently executing inside a `Layer` with `trainable` set
      to `False`.
    in_keras_graph: Whether executing inside the Keras Graph.
  """

  def __init__(self):
    # Handle `in_call` separately as it is the most-read attr and reading it is
    # on the hot path.
    self.in_call = False
    self._state = {
        'layer': None,
        'inputs': None,
        'build_graph': False,
        'training': None,
        'saving': None
    }
    # TODO(b/150169018): This logic can be replaced after the Functional API
    # refactor.
    self._in_keras_graph = False

  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Push a Layer and its inputs and state onto the current call context.

    Args:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      saving: Whether currently saving to SavedModel.

    Returns:
      Context manager.
    """
    state = {
        'layer': layer,
        'inputs': inputs,
        'build_graph': build_graph,
        'training': training,
        'saving': saving
    }
    return CallContextManager(self, state)

  @property
  def layer(self):
    return self._state['layer']

  @property
  def inputs(self):
    return self._state['inputs']

  @property
  def build_graph(self):
    return self._state['build_graph']

  @property
  def training(self):
    return self._state['training']

  @property
  def saving(self):
    return self._state['saving']

  @property
  def frozen(self):
    layer = self._state['layer']
    if not layer:
      return False
    return not layer.trainable

  @property
  def in_keras_graph(self):
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if context.executing_eagerly():
      return False
    return (self._in_keras_graph or
            getattr(backend.get_graph(), 'name', None) == 'keras_graph')


class CallContextManager(object):
  """Context manager for `CallContext`."""

  def __init__(self, call_ctx, state):
    self._call_ctx = call_ctx
    self._state = state
    self._build_graph = state['build_graph']

  def __enter__(self):
    call_ctx = self._call_ctx
    self._prev_in_call = call_ctx.in_call
    self._prev_state = call_ctx._state

    call_ctx.in_call = True
    call_ctx._state = self._state

    # TODO(b/150169018): This logic can be removed after the Functional API
    # refactor.
    if self._build_graph:
      self._prev_in_keras_graph = call_ctx._in_keras_graph
      call_ctx._in_keras_graph = (
          call_ctx._in_keras_graph or
          getattr(backend.get_graph(), 'name', None) == 'keras_graph')

  def __exit__(self, *exc_info):
    call_ctx = self._call_ctx
    call_ctx.in_call = self._prev_in_call
    call_ctx._state = self._prev_state

    if self._build_graph:
      call_ctx._in_keras_graph = self._prev_in_keras_graph


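# Illustrative usage sketch (editor's addition, not part of the original
# module): how the base `Layer.__call__` wraps the body of a call with the
# thread-local `CallContext`. The `layer` and `inputs` arguments below are
# hypothetical stand-ins.
def _example_call_context(layer, inputs):
  ctx = call_context()
  with ctx.enter(layer=layer, inputs=inputs, build_graph=False, training=True):
    assert ctx.in_call and ctx.layer is layer
    # ... the layer's `call` would run here ...
  assert not ctx.in_call  # the previous state is restored on exit
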
def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`."""
  # `argspec.args` starts with ['self', 'inputs']
  full_args = dict(zip(argspec.args[2:], args))
  full_args.update(kwargs)
  return 'training' in full_args and full_args['training'] is not None


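# Illustrative usage sketch (editor's addition, not part of the original
# module): `training_arg_passed_to_call` matches positional `args` against the
# call signature after `(self, inputs)`. The toy signature below is
# hypothetical.
def _example_training_arg_passed_to_call():
  import inspect  # pylint: disable=g-import-not-at-top

  def call(self, inputs, training=None, mask=None):  # toy `call` signature
    del self, inputs, training, mask

  argspec = inspect.getfullargspec(call)
  # The single positional arg lines up with `training`, so this returns True.
  return training_arg_passed_to_call(argspec, args=[True], kwargs={})
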
def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model."""
  return (layer.__module__.find('keras.engine') == -1 and
          layer.__module__.find('keras.layers') == -1)


def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel."""
  return layer.__module__.find('keras.saving.saved_model') != -1


def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to an `add_*` method match the Keras graph.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  We need to raise clear error messages in such cases.

  Args:
    tensor: Tensor to check, or `False` if it is known that an error
      should be raised.
    method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  if (force_raise or
      (ops.executing_eagerly_outside_functions() and
       hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)):
    if method == 'activity_regularizer':
      bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
      correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
      raise RuntimeError(
          'You are using a layer with `activity_regularizer` in a control flow '
          'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
          'Please move your call to the layer with `activity_regularizer` out '
          'of the control flow branch, e.g.:\n{correct_example}\n'
          'You can also resolve this by marking your outer model/layer dynamic'
          ' (eager-only) by passing `dynamic=True` to the layer constructor. '
          'Any kind of control flow is supported with dynamic layers. '
          'Note that using `dynamic=True` requires you to implement static '
          'shape inference in the `compute_output_shape(input_shape)` '
          'method.'.format(
              bad_example=bad_example, correct_example=correct_example))

    if method == 'add_metric':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    elif method == 'add_loss':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
    else:
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
    raise RuntimeError(
        'You are using the method `{method}` in a control flow branch '
        'in your layer, e.g.:\n{bad_example}\n'
        'This is not currently supported. '
        'Please move your call to {method} out of the control flow branch, '
        'e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your layer '
        'as dynamic (eager-only) by passing '
        '`dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you '
        'to implement static shape inference '
        'in the `compute_output_shape(input_shape)` method.'.format(
            method=method,
            bad_example=bad_example,
            correct_example=correct_example))


def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps."""

  def _mark_as_return(tensor):
    """Marks `tensor` as the return value for automatic control deps."""
    if not tensor_util.is_tf_type(tensor):
      return tensor

    # pylint: disable=protected-access
    return_tensor = acd.mark_as_return(tensor)
    if getattr(tensor, '_keras_mask', None) is not None:
      return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
    else:
      return_tensor._keras_mask = None

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    if getattr(tensor, '_tfp_distribution', None) is not None:
      return_tensor._tfp_distribution = tensor._tfp_distribution

    return return_tensor
    # pylint: enable=protected-access

  return nest.map_structure(_mark_as_return, outputs)


V2_DTYPE_BEHAVIOR = None


@keras_export(v1=['keras.layers.enable_v2_dtype_behavior'])
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2, so this function
  is only useful if `tf.compat.v1.disable_v2_behavior` has been called. Since
  mixed precision requires V2 dtype behavior to be enabled, this function allows
  you to use mixed precision in Keras layers if `disable_v2_behavior` has been
  called.

  When enabled, the dtype of Keras layers defaults to floatx (which is typically
  float32) instead of None. In addition, layers will automatically cast
  floating-point inputs to the layer's dtype.

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
  >>> print(y.dtype.name)
  float32

  A layer author can opt their layer out of the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting to
  floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype
  will default to the global policy instead of floatx. Layers will automatically
  cast inputs to the policy's compute_dtype.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True


@keras_export(v1=['keras.layers.disable_v2_dtype_behavior'])
def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled."""
  if V2_DTYPE_BEHAVIOR is None:
    return tf2.enabled()
  return V2_DTYPE_BEHAVIOR


class TrackableWeightHandler(object):
  """Keras wrapper for handling tracking.Trackable object saving and restoring.

  This class handles Trackables in both V1 and V2 modes, ensuring that they can
  be saved and restored with the correct data and without adding additional ops
  on every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = distribution_strategy_context.get_strategy()

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    # 'Saveables' won't exist when we're passed a legacy TF1 table like
    # a StaticHashTable.
    if not saveables:
      self._num_tensors = 0
      self._setter = lambda weights: None
      self._getter = lambda: []

    elif len(saveables) == 1:
      saveable = list(saveables)[0]

      if ops.executing_eagerly_outside_functions():
        # If we're in eager mode, we need to defer calling the Trackable's
        # saveable() callable until data export time.
        # However, it is safe to call the saveable as many times as we want, so
        # we will call it now to figure out how many tensors this Trackable will
        # produce.
        self._saveable = saveable
        self._num_tensors = len(self._saveable().specs)
        self._setter = lambda weights: self._saveable().restore(weights, None)
        self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
      else:
        # If we're in Graph mode, we need to evaluate the Saveable only once and
        # cache the resulting restore graph. Failing to do this will result in
        # new assignment ops being added to the graph each time set_weights() is
        # called.
        self._placeholder_tensors = []
        self._saveable = saveable()
        self._num_tensors = len(self._saveable.specs)
        for spec in self._saveable.specs:
          tensor = spec.tensor
          self._placeholder_tensors.append(
              array_ops.placeholder(tensor.dtype, tensor.shape))
        self._assign_op = self._saveable.restore(self._placeholder_tensors,
                                                 None)
        self._setter = self._set_weights_v1
        self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
    else:
      raise ValueError('Only Trackables with one Saveable are supported. '
                       'The Trackable %s has %d Saveables.' %
                       (trackable, len(saveables)))

  @property
  def num_tensors(self):
    return self._num_tensors

  def set_weights(self, weights):
    if len(weights) != self._num_tensors:
      raise ValueError(
          ('Weight handler for trackable %s received the wrong number of ' +
           'weights: expected %s, got %s.') %
          (self._trackable, self._num_tensors, len(weights)))
    self._setter(weights)

  def get_tensors(self):
    return self._getter()

  def _set_weights_v1(self, weights):
    feed_dict = {}
    for idx, tensor in enumerate(weights):
      feed_dict[self._placeholder_tensors[idx]] = tensor
    backend.get_session().run(self._assign_op, feed_dict)


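# Illustrative usage sketch (editor's addition, not part of the original
# module): wrapping a mutable lookup table so its contents can round-trip
# through Keras-style get/set weights. Assumes eager (TF2) mode and that
# `MutableHashTable` still exposes exactly one checkpoint saveable at this
# revision; the table configuration below is hypothetical.
def _example_trackable_weight_handler():
  from tensorflow.python.ops import lookup_ops  # pylint: disable=g-import-not-at-top
  table = lookup_ops.MutableHashTable(
      key_dtype=dtypes.string, value_dtype=dtypes.int64, default_value=-1)
  handler = TrackableWeightHandler(table)
  weights = handler.get_tensors()  # e.g. [keys, values] for this table
  handler.set_weights(weights)     # restores the same contents
  return handler.num_tensors
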
class StaticTableHandler(TrackableWeightHandler):
  """Wrapper for handling weight collection for static hash tables."""

  def __init__(self, getter_lambda):  # pylint: disable=super-init-not-called
    self._num_tensors = 2
    self._getter = getter_lambda
    self._distribute_strategy = distribution_strategy_context.get_strategy()

    def raise_error(_):
      raise RuntimeError('This layer contains a static lookup table, which '
                         'cannot be changed via set_weights().')

    self._setter = raise_error


def no_ragged_support(inputs, layer_name):
  input_list = nest.flatten(inputs)
  if any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list):
    raise ValueError('Layer %s does not support RaggedTensors as input. '
                     'Inputs received: %s. You can try converting your '
                     'input to a uniform tensor.' % (layer_name, inputs))


def is_split_variable(v):
  """Returns True if `v` is a PartitionedVariable or a ShardedVariable."""
  return hasattr(v, '_variable_list') or hasattr(v, '_variables')


def has_weights(obj):
  obj_type = type(obj)
  return (hasattr(obj_type, 'trainable_weights') and
          hasattr(obj_type, 'non_trainable_weights') and
          not isinstance(obj, type))


# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes issues
# in eager mode because the child layers may have graph losses
# (thus model.losses returns a mix of Eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    'This layer\'s losses have been added to the parent layer.')