xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/base_layer_v1.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Contains the base Layer class, from which all layers inherit."""
17
18import collections
19import functools
20import itertools
21import threading
22import warnings
23
24import numpy as np
25
26from tensorflow.python.autograph.core import ag_ctx
27from tensorflow.python.autograph.impl import api as autograph
28from tensorflow.python.distribute import distribution_strategy_context as ds_context
29from tensorflow.python.eager import context
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors
32from tensorflow.python.framework import func_graph
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import sparse_tensor
35from tensorflow.python.framework import tensor_spec
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.keras import backend
38from tensorflow.python.keras import constraints
39from tensorflow.python.keras import initializers
40from tensorflow.python.keras import regularizers
41from tensorflow.python.keras.engine import base_layer
42from tensorflow.python.keras.engine import base_layer_utils
43from tensorflow.python.keras.engine import input_spec
44from tensorflow.python.keras.mixed_precision import autocast_variable
45from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
46from tensorflow.python.keras.mixed_precision import policy
47from tensorflow.python.keras.saving.saved_model import layer_serialization
48from tensorflow.python.keras.utils import generic_utils
49from tensorflow.python.keras.utils import layer_utils
50from tensorflow.python.keras.utils import object_identity
51from tensorflow.python.keras.utils import tf_inspect
52from tensorflow.python.keras.utils import tf_utils
53# A module that only depends on `keras.layers` import these from here.
54from tensorflow.python.keras.utils.generic_utils import to_snake_case  # pylint: disable=unused-import
55from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list  # pylint: disable=unused-import
56from tensorflow.python.module import module
57from tensorflow.python.ops import array_ops
58from tensorflow.python.ops import math_ops
59from tensorflow.python.ops import variables as tf_variables
60from tensorflow.python.ops.ragged import ragged_tensor
61from tensorflow.python.platform import tf_logging
62from tensorflow.python.trackable import autotrackable
63from tensorflow.python.trackable import base as trackable
64from tensorflow.python.trackable import data_structures
65from tensorflow.python.util import nest
66from tensorflow.tools.docs import doc_controls
67
68
69# pylint: disable=g-classes-have-attributes
70class Layer(base_layer.Layer):
71  """Base layer class.
72
73  This is the class from which all layers inherit.
74
75  A layer is a class implementing common neural networks operations, such
76  as convolution, batch norm, etc. These operations require managing weights,
77  losses, updates, and inter-layer connectivity.
78
79  Users will just instantiate a layer and then treat it as a callable.
80
81  We recommend that descendants of `Layer` implement the following methods:
82
83  * `__init__()`: Save configuration in member variables
84  * `build()`: Called once from `__call__`, when we know the shapes of inputs
85    and `dtype`. Should have the calls to `add_weight()`, and then
86    call the super's `build()` (which sets `self.built = True`, which is
87    nice in case the user wants to call `build()` manually before the
88    first `__call__`).
89  * `call()`: Called in `__call__` after making sure `build()` has been called
90    once. Should actually perform the logic of applying the layer to the
91    input tensors (which should be passed in as the first argument).
92
93  Args:
94    trainable: Boolean, whether the layer's variables should be trainable.
95    name: String name of the layer.
96    dtype: The dtype of the layer's computations and weights (default of
97      `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type
98      of the first input in TensorFlow 1).
99    dynamic: Set this to `True` if your layer should only be run eagerly, and
100      should not be used to generate a static computation graph.
101      This would be the case for a Tree-RNN or a recursive network,
102      for example, or generally for any layer that manipulates tensors
103      using Python control flow. If `False`, we assume that the layer can
104      safely be used to generate a static computation graph.
105
106  Attributes:
107    name: The name of the layer (string).
108    dtype: The dtype of the layer's computations and weights. If mixed
109      precision is used with a `tf.keras.mixed_precision.Policy`, this is
110      instead just the dtype of the layer's weights, as the computations are
111      done in a different dtype.
112    updates: List of update ops of this layer.
113    losses: List of losses added by this layer.
114    trainable_weights: List of variables to be included in backprop.
115    non_trainable_weights: List of variables that should not be
116      included in backprop.
117    weights: The concatenation of the lists trainable_weights and
118      non_trainable_weights (in this order).
119    trainable: Whether the layer should be trained (boolean).
120    input_spec: Optional (list of) `InputSpec` object(s) specifying the
121      constraints on inputs that can be accepted by the layer.
122
123  Each layer has a dtype, which is typically the dtype of the layer's
124  computations and variables. A layer's dtype can be queried via the
125  `Layer.dtype` property. The dtype is specified with the `dtype` constructor
126  argument. In TensorFlow 2, the dtype defaults to `tf.keras.backend.floatx()`
127  if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
128  layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed
129  precision is used, layers may have different computation and variable dtypes.
130  See `tf.keras.mixed_precision.Policy` for details on layer dtypes.
131  """
132
133  # See tf.Module for the usage of this property.
134  # The key for _obj_reference_counts_dict is a Trackable, which could be a
135  # variable or layer etc. tf.Module._flatten will fail to flatten the key
136  # since it is trying to convert Trackable to a string. This attribute can be
137  # ignored even after the fix of nest lib, since the trackable object should
138  # already been available as individual attributes. _obj_reference_counts_dict
139  # just contains a copy of them.
140  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
141      ('_obj_reference_counts_dict',),
142      module.Module._TF_MODULE_IGNORED_PROPERTIES
143  ))
144
145  @trackable.no_automatic_dependency_tracking
146  def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
147               **kwargs):
148    self._instrument_layer_creation()
149
150    # These properties should be set by the user via keyword arguments.
151    # note that 'dtype', 'input_shape' and 'batch_input_shape'
152    # are only applicable to input layers: do not pass these keywords
153    # to non-input layers.
154    allowed_kwargs = {
155        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
156        'weights', 'activity_regularizer', 'autocast', 'implementation'
157    }
158    # Validate optional keyword arguments.
159    generic_utils.validate_kwargs(kwargs, allowed_kwargs)
160
161    # Mutable properties
162    # Indicates whether the layer's weights are updated during training
163    # and whether the layer's updates are run during training.
164    self._trainable = trainable
165    # A stateful layer is a layer whose updates are run during inference too,
166    # for instance stateful RNNs.
167    self._stateful = False
168    # Indicates whether `build` needs to be called upon layer call, to create
169    # the layer's weights.
170    self.built = False
171    self._build_input_shape = None
172    # Provides information about which inputs are compatible with the layer.
173    self._input_spec = None
174    self.supports_masking = False
175
176    self._init_set_name(name)
177    self._activity_regularizer = regularizers.get(
178        kwargs.pop('activity_regularizer', None))
179    self._maybe_create_attribute('_trainable_weights', [])
180    self._maybe_create_attribute('_non_trainable_weights', [])
181    self._updates = []
182    # Object to store all thread local layer properties.
183    self._thread_local = threading.local()
184    # A list of zero-argument lambdas which return Tensors, used for variable
185    # regularizers.
186    self._callable_losses = []
187    # A list of symbolic Tensors containing activity regularizers and losses
188    # manually added through `add_loss` in graph-building mode.
189    self._losses = []
190    # A list of metric instances corresponding to the symbolic metric tensors
191    # added using the `add_metric` API.
192    self._metrics = []
193
194    # Both graph and subclassed networks have a dtype policy. For graph
195    # networks, the policy's compute and variable dtypes are ignored. Such
196    # networks only use the policy if it is a PolicyV1, in which case it uses
197    # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
198    # subclassed networks, the compute and variable dtypes are used as like any
199    # ordinary layer.
200    self._set_dtype_policy(dtype)
201    # Boolean indicating whether the layer automatically casts its inputs to the
202    # layer's compute_dtype.
203    self._autocast = kwargs.get('autocast',
204                                base_layer_utils.v2_dtype_behavior_enabled())
205
206    # Dependencies tracked via attribute assignment.
207    # All layers in order of horizontal graph traversal.
208    # Entries are unique. For models includes input and output layers.
209    self._maybe_create_attribute('_self_tracked_trackables', [])
210
211    # These lists will be filled via successive calls
212    # to self._add_inbound_node().
213    # Used in symbolic mode only, only in conjunction with graph-networks
214    self._inbound_nodes_value = []
215    self._outbound_nodes_value = []
216
217    self._init_call_fn_args()
218
219    # Whether the `call` method can be used to build a TF graph without issues.
220    # This attribute has no effect if the model is created using the Functional
221    # API. Instead, `model.dynamic` is determined based on the internal layers.
222    self._dynamic = dynamic
223
224    # Manage input shape information if passed.
225    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
226      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
227      kwargs['input_shape'] = (kwargs['input_dim'],)
228    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
229      # In this case we will later create an input layer
230      # to insert before the current layer
231      if 'batch_input_shape' in kwargs:
232        batch_input_shape = tuple(kwargs['batch_input_shape'])
233      elif 'input_shape' in kwargs:
234        if 'batch_size' in kwargs:
235          batch_size = kwargs['batch_size']
236        else:
237          batch_size = None
238        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
239      self._batch_input_shape = batch_input_shape
240
241    # Manage initial weight values if passed.
242    self._initial_weights = kwargs.get('weights', None)
243
244    # Whether the layer will track any layers that is set as attribute on itself
245    # as sub-layers, the weights from the sub-layers will be included in the
246    # parent layer's variables() as well.
247    # Default to True, which means auto tracking is turned on. Certain subclass
248    # might want to turn it off, like Sequential model.
249    self._auto_track_sub_layers = True
250
251    # Mark this layer as having been originally built as a tf1 layer/model
252    self._originally_built_as_v1 = True
253
254    # For backwards compat reasons, most built-in layers do not guarantee
255    # That they will 100% preserve the structure of input args when saving
256    # / loading configs. E.g. they may un-nest an arg that is
257    # a list with one element.
258    self._preserve_input_structure_in_config = False
259
260  @trackable.no_automatic_dependency_tracking
261  @generic_utils.default
262  def build(self, input_shape):
263    """Creates the variables of the layer (optional, for subclass implementers).
264
265    This is a method that implementers of subclasses of `Layer` or `Model`
266    can override if they need a state-creation step in-between
267    layer instantiation and layer call.
268
269    This is typically used to create the weights of `Layer` subclasses.
270
271    Args:
272      input_shape: Instance of `TensorShape`, or list of instances of
273        `TensorShape` if the layer expects a list of inputs
274        (one instance per input).
275    """
276    if not hasattr(self.build, '_is_default'):
277      self._build_input_shape = input_shape
278    self.built = True
279
280  @doc_controls.for_subclass_implementers
281  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
282    """This is where the layer's logic lives.
283
284    Args:
285        inputs: Input tensor, or list/tuple of input tensors.
286        **kwargs: Additional keyword arguments.
287
288    Returns:
289        A tensor or list/tuple of tensors.
290    """
291    return inputs
292
293  @doc_controls.for_subclass_implementers
294  def _add_trackable(self, trackable_object, trainable):
295    """Adds a Trackable object to this layer's state.
296
297    Args:
298      trackable_object: The tf.tracking.Trackable object to add.
299      trainable: Boolean, whether the variable should be part of the layer's
300        "trainable_variables" (e.g. variables, biases) or
301        "non_trainable_variables" (e.g. BatchNorm mean and variance).
302
303    Returns:
304      The TrackableWeightHandler used to track this object.
305    """
306    if isinstance(trackable_object, base_layer_utils.TrackableWeightHandler):
307      handler = trackable_object
308    else:
309      handler = base_layer_utils.TrackableWeightHandler(trackable_object)
310    if trainable:
311      self._trainable_weights.append(handler)
312    else:
313      self._non_trainable_weights.append(handler)
314    return handler
315
316  @doc_controls.for_subclass_implementers
317  def add_weight(self,
318                 name=None,
319                 shape=None,
320                 dtype=None,
321                 initializer=None,
322                 regularizer=None,
323                 trainable=None,
324                 constraint=None,
325                 partitioner=None,
326                 use_resource=None,
327                 synchronization=tf_variables.VariableSynchronization.AUTO,
328                 aggregation=tf_variables.VariableAggregation.NONE,
329                 **kwargs):
330    """Adds a new variable to the layer.
331
332    Args:
333      name: Variable name.
334      shape: Variable shape. Defaults to scalar if unspecified.
335      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
336      initializer: Initializer instance (callable).
337      regularizer: Regularizer instance (callable).
338      trainable: Boolean, whether the variable should be part of the layer's
339        "trainable_variables" (e.g. variables, biases)
340        or "non_trainable_variables" (e.g. BatchNorm mean and variance).
341        Note that `trainable` cannot be `True` if `synchronization`
342        is set to `ON_READ`.
343      constraint: Constraint instance (callable).
344      partitioner: Partitioner to be passed to the `Trackable` API.
345      use_resource: Whether to use `ResourceVariable`.
346      synchronization: Indicates when a distributed a variable will be
347        aggregated. Accepted values are constants defined in the class
348        `tf.VariableSynchronization`. By default the synchronization is set to
349        `AUTO` and the current `DistributionStrategy` chooses
350        when to synchronize. If `synchronization` is set to `ON_READ`,
351        `trainable` must not be set to `True`.
352      aggregation: Indicates how a distributed variable will be aggregated.
353        Accepted values are constants defined in the class
354        `tf.VariableAggregation`.
355      **kwargs: Additional keyword arguments. Accepted values are `getter`,
356        `collections`, `experimental_autocast` and `caching_device`.
357
358    Returns:
359      The created variable. Usually either a `Variable` or `ResourceVariable`
360      instance. If `partitioner` is not `None`, a `PartitionedVariable`
361      instance is returned.
362
363    Raises:
364      RuntimeError: If called with partitioned variable regularization and
365        eager execution is enabled.
366      ValueError: When giving unsupported dtype and no initializer or when
367        trainable has been set to True with synchronization set as `ON_READ`.
368    """
369    if shape is None:
370      shape = ()
371    # Validate optional keyword arguments.
372    for kwarg in kwargs:
373      if kwarg not in ['getter', 'collections', 'experimental_autocast',
374                       'caching_device']:
375        raise TypeError('Unknown keyword argument:', kwarg)
376    has_custom_getter = 'getter' in kwargs
377    getter = kwargs.pop('getter', base_layer_utils.make_variable)
378    collections_arg = kwargs.pop('collections', None)
379    # 'experimental_autocast' can be set to False by the caller to indicate an
380    # AutoCastVariable should never be created.
381    autocast = kwargs.pop('experimental_autocast', True)
382    # See the docstring for tf.Variable about the details for caching_device.
383    caching_device = kwargs.pop('caching_device', None)
384
385    if dtype is None:
386      dtype = self.dtype or backend.floatx()
387    dtype = dtypes.as_dtype(dtype)
388    if self._dtype_policy.variable_dtype is None:
389      # The policy is "_infer", so we infer the policy from the variable dtype.
390      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
391    initializer = initializers.get(initializer)
392    regularizer = regularizers.get(regularizer)
393    constraint = constraints.get(constraint)
394
395    if synchronization == tf_variables.VariableSynchronization.ON_READ:
396      if trainable:
397        raise ValueError(
398            'Synchronization value can be set to '
399            'VariableSynchronization.ON_READ only for non-trainable variables. '
400            'You have specified trainable=True and '
401            'synchronization=VariableSynchronization.ON_READ.')
402      else:
403        # Set trainable to be false when variable is to be synced on read.
404        trainable = False
405    elif trainable is None:
406      trainable = True
407
408    # Initialize variable when no initializer provided
409    if initializer is None:
410      # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
411      if dtype.is_floating:
412        initializer = initializers.get('glorot_uniform')
413      # If dtype is DT_INT/DT_UINT, provide a default value `zero`
414      # If dtype is DT_BOOL, provide a default value `FALSE`
415      elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
416        initializer = initializers.zeros()
417      # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
418      elif not has_custom_getter:
419        # When `getter` is specified, it's possibly fine for `initializer` to be
420        # None since it's up to the custom `getter` to raise error in case it
421        # indeed needs `initializer`.
422        raise ValueError('An initializer for variable %s of type %s is required'
423                         ' for layer %s' % (name, dtype.base_dtype, self.name))
424
425    if (autocast and
426        self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
427        and dtype.is_floating):
428      # Wrap 'getter' with a version that returns an AutoCastVariable.
429      old_getter = getter
430      def getter(*args, **kwargs):  # pylint: disable=function-redefined
431        variable = old_getter(*args, **kwargs)
432        return autocast_variable.create_autocast_variable(variable)
433      # Also the caching_device does not work with the mixed precision API,
434      # disable it if it is specified.
435      # TODO(b/142020079): Reenable it once the bug is fixed.
436      if caching_device is not None:
437        tf_logging.warning(
438            '`caching_device` does not work with mixed precision API. Ignoring '
439            'user specified `caching_device`.')
440        caching_device = None
441
442    variable = self._add_variable_with_custom_getter(
443        name=name,
444        shape=shape,
445        # TODO(allenl): a `make_variable` equivalent should be added as a
446        # `Trackable` method.
447        getter=getter,
448        # Manage errors in Layer rather than Trackable.
449        overwrite=True,
450        initializer=initializer,
451        dtype=dtype,
452        constraint=constraint,
453        trainable=trainable,
454        partitioner=partitioner,
455        use_resource=use_resource,
456        collections=collections_arg,
457        synchronization=synchronization,
458        aggregation=aggregation,
459        caching_device=caching_device)
460    if regularizer is not None:
461      # TODO(fchollet): in the future, this should be handled at the
462      # level of variable creation, and weight regularization losses
463      # should be variable attributes.
464      name_in_scope = variable.name[:variable.name.find(':')]
465      self._handle_weight_regularization(name_in_scope,
466                                         variable,
467                                         regularizer)
468    if base_layer_utils.is_split_variable(variable):
469      for v in variable:
470        backend.track_variable(v)
471        if trainable:
472          self._trainable_weights.append(v)
473        else:
474          self._non_trainable_weights.append(v)
475    else:
476      backend.track_variable(variable)
477      if trainable:
478        self._trainable_weights.append(variable)
479      else:
480        self._non_trainable_weights.append(variable)
481    return variable
482
483  @generic_utils.default
484  def get_config(self):
485    """Returns the config of the layer.
486
487    A layer config is a Python dictionary (serializable)
488    containing the configuration of a layer.
489    The same layer can be reinstantiated later
490    (without its trained weights) from this configuration.
491
492    The config of a layer does not include connectivity
493    information, nor the layer class name. These are handled
494    by `Network` (one layer of abstraction above).
495
496    Returns:
497        Python dictionary.
498    """
499    all_args = tf_inspect.getfullargspec(self.__init__).args
500    config = {'name': self.name, 'trainable': self.trainable}
501    if hasattr(self, '_batch_input_shape'):
502      config['batch_input_shape'] = self._batch_input_shape
503    config['dtype'] = policy.serialize(self._dtype_policy)
504    if hasattr(self, 'dynamic'):
505      # Only include `dynamic` in the `config` if it is `True`
506      if self.dynamic:
507        config['dynamic'] = self.dynamic
508      elif 'dynamic' in all_args:
509        all_args.remove('dynamic')
510    expected_args = config.keys()
511    # Finds all arguments in the `__init__` that are not in the config:
512    extra_args = [arg for arg in all_args if arg not in expected_args]
513    # Check that either the only argument in the `__init__` is  `self`,
514    # or that `get_config` has been overridden:
515    if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
516      raise NotImplementedError('Layers with arguments in `__init__` must '
517                                'override `get_config`.')
518    return config
519
520  @classmethod
521  def from_config(cls, config):
522    """Creates a layer from its config.
523
524    This method is the reverse of `get_config`,
525    capable of instantiating the same layer from the config
526    dictionary. It does not handle layer connectivity
527    (handled by Network), nor weights (handled by `set_weights`).
528
529    Args:
530        config: A Python dictionary, typically the
531            output of get_config.
532
533    Returns:
534        A layer instance.
535    """
536    return cls(**config)
537
538  def compute_output_shape(self, input_shape):
539    """Computes the output shape of the layer.
540
541    If the layer has not been built, this method will call `build` on the
542    layer. This assumes that the layer will later be used with inputs that
543    match the input shape provided here.
544
545    Args:
546        input_shape: Shape tuple (tuple of integers)
547            or list of shape tuples (one per output tensor of the layer).
548            Shape tuples can include None for free dimensions,
549            instead of an integer.
550
551    Returns:
552        An input shape tuple.
553    """
554    if context.executing_eagerly():
555      # In this case we build the model first in order to do shape inference.
556      # This is acceptable because the framework only calls
557      # `compute_output_shape` on shape values that the layer would later be
558      # built for. It would however cause issues in case a user attempts to
559      # use `compute_output_shape` manually with shapes that are incompatible
560      # with the shape the Layer will be called on (these users will have to
561      # implement `compute_output_shape` themselves).
562      self._maybe_build(input_shape)
563      with ops.get_default_graph().as_default():
564        graph = func_graph.FuncGraph('graph')
565        with graph.as_default():
566          input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
567          inputs = nest.map_structure(
568              base_layer_utils.generate_placeholders_from_shape, input_shape)
569          try:
570            outputs = self(inputs, training=False)
571          except TypeError as e:
572            raise NotImplementedError(
573                'We could not automatically infer the static shape of the '
574                'layer\'s output. Please implement the '
575                '`compute_output_shape` method on your layer (%s).' %
576                self.__class__.__name__) from e
577      return nest.map_structure(lambda t: t.shape, outputs)
578    raise NotImplementedError
579
580  @doc_controls.for_subclass_implementers
581  def compute_output_signature(self, input_signature):
582    """Compute the output tensor signature of the layer based on the inputs.
583
584    Unlike a TensorShape object, a TensorSpec object contains both shape
585    and dtype information for a tensor. This method allows layers to provide
586    output dtype information if it is different from the input dtype.
587    For any layer that doesn't implement this function,
588    the framework will fall back to use `compute_output_shape`, and will
589    assume that the output dtype matches the input dtype.
590
591    Args:
592      input_signature: Single TensorSpec or nested structure of TensorSpec
593        objects, describing a candidate input for the layer.
594
595    Returns:
596      Single TensorSpec or nested structure of TensorSpec objects, describing
597        how the layer would transform the provided input.
598
599    Raises:
600      TypeError: If input_signature contains a non-TensorSpec object.
601    """
602    def check_type_return_shape(s):
603      if not isinstance(s, tensor_spec.TensorSpec):
604        raise TypeError('Only TensorSpec signature types are supported, '
605                        'but saw signature entry: {}.'.format(s))
606      return s.shape
607    input_shape = nest.map_structure(check_type_return_shape, input_signature)
608    output_shape = self.compute_output_shape(input_shape)
609    dtype = self._compute_dtype
610    if dtype is None:
611      input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
612      # Default behavior when self.dtype is None, is to use the first input's
613      # dtype.
614      dtype = input_dtypes[0]
615    return nest.map_structure(
616        lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
617        output_shape)
618
619  @generic_utils.default
620  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
621    """Computes an output mask tensor.
622
623    Args:
624        inputs: Tensor or list of tensors.
625        mask: Tensor or list of tensors.
626
627    Returns:
628        None or a tensor (or list of tensors,
629            one per output tensor of the layer).
630    """
631    if not self.supports_masking:
632      if any(m is not None for m in nest.flatten(mask)):
633        raise TypeError('Layer ' + self.name + ' does not support masking, '
634                        'but was passed an input_mask: ' + str(mask))
635      # masking not explicitly supported: return None as mask.
636      return None
637    # if masking is explicitly supported, by default
638    # carry over the input mask
639    return mask
640
641  def __call__(self, *args, **kwargs):
642    """Wraps `call`, applying pre- and post-processing steps.
643
644    Args:
645      *args: Positional arguments to be passed to `self.call`.
646      **kwargs: Keyword arguments to be passed to `self.call`.
647
648    Returns:
649      Output tensor(s).
650
651    Note:
652      - The following optional keyword arguments are reserved for specific uses:
653        * `training`: Boolean scalar tensor of Python boolean indicating
654          whether the `call` is meant for training or inference.
655        * `mask`: Boolean input mask.
656      - If the layer's `call` method takes a `mask` argument (as some Keras
657        layers do), its default value will be set to the mask generated
658        for `inputs` by the previous layer (if `input` did come from
659        a layer that generated a corresponding mask, i.e. if it came from
660        a Keras layer with masking support.
661
662    Raises:
663      ValueError: if the layer's `call` method returns None (an invalid value).
664      RuntimeError: if `super().__init__()` was not called in the constructor.
665    """
666    self._assert_built_as_v1()
667
668    if not hasattr(self, '_thread_local'):
669      raise RuntimeError(
670          'You must call `super().__init__()` in the layer constructor.')
671
672    # Grab the first positional or keyword argument.
673    if args:
674      inputs = args[0]
675      args = args[1:]
676    elif self._call_fn_args[0] in kwargs:
677      inputs = kwargs.pop(self._call_fn_args[0])
678    else:
679      raise ValueError(
680          'The first argument to `Layer.call` must always be passed.')
681
682    call_context = base_layer_utils.call_context()
683    input_list = nest.flatten(inputs)
684
685    # We will attempt to build a TF graph if & only if all inputs are symbolic.
686    # This is always the case in graph mode. It can also be the case in eager
687    # mode when all inputs can be traced back to `keras.Input()` (when building
688    # models using the functional API).
689    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
690
691    # Accept NumPy and scalar inputs by converting to Tensors.
692    if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
693      def _convert_non_tensor(x):
694        # Don't call `ops.convert_to_tensor` on all `inputs` because
695        # `SparseTensors` can't be converted to `Tensor`.
696        if isinstance(x, (np.ndarray, float, int)):
697          return ops.convert_to_tensor_v2_with_dispatch(x)
698        return x
699      inputs = nest.map_structure(_convert_non_tensor, inputs)
700      input_list = nest.flatten(inputs)
701
702    # Handle `mask` propagation from previous layer to current layer. Masks can
703    # be propagated explicitly via the `mask` argument, or implicitly via
704    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
705    # explicitly take priority.
706    mask_arg_passed_by_framework = False
707    input_masks = self._collect_input_masks(inputs, args, kwargs)
708    if (self._expects_mask_arg and input_masks is not None and
709        not self._call_arg_was_passed('mask', args, kwargs)):
710      mask_arg_passed_by_framework = True
711      kwargs['mask'] = input_masks
712
713    # If `training` argument is None or not explicitly passed,
714    # propagate `training` value from this layer's calling layer.
715    training_value = None
716    training_arg_passed_by_framework = False
717    # Priority 1: `training` was explicitly passed.
718    if self._call_arg_was_passed('training', args, kwargs):
719      training_value = self._get_call_arg_value('training', args, kwargs)
720      if not self._expects_training_arg:
721        kwargs.pop('training')
722
723    if training_value is None:
724      # Priority 2: `training` was passed to a parent layer.
725      if call_context.training is not None:
726        training_value = call_context.training
727      # Priority 3a: `learning_phase()` has been set.
728      elif backend.global_learning_phase_is_set():
729        training_value = backend.learning_phase()
730      # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
731      elif build_graph:
732        with backend.get_graph().as_default():
733          if base_layer_utils.is_in_keras_graph():
734            training_value = backend.learning_phase()
735
736      if self._expects_training_arg and training_value is not None:
737        # Force the training_value to be bool type which matches to the contract
738        # for layer/model call args.
739        if tensor_util.is_tf_type(training_value):
740          training_value = math_ops.cast(training_value, dtypes.bool)
741        else:
742          training_value = bool(training_value)
743        args, kwargs = self._set_call_arg_value(
744            'training', training_value, args, kwargs)
745        training_arg_passed_by_framework = True
746
747    # Only create Keras history if at least one tensor originates from a
748    # `keras.Input`. Otherwise this Layer may be being used outside the Keras
749    # framework.
750    if build_graph and base_layer_utils.needs_keras_history(inputs):
751      base_layer_utils.create_keras_history(inputs)
752
753    with call_context.enter(self, inputs, build_graph, training_value):
754      # Check input assumptions set after layer building, e.g. input shape.
755      if build_graph:
756        # Symbolic execution on symbolic tensors. We will attempt to build
757        # the corresponding TF subgraph inside `backend.get_graph()`
758        input_spec.assert_input_compatibility(self.input_spec, inputs,
759                                              self.name)
760        graph = backend.get_graph()
761        with graph.as_default(), backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
762          # Build layer if applicable (if the `build` method has been
763          # overridden).
764          self._maybe_build(inputs)
765          cast_inputs = self._maybe_cast_inputs(inputs)
766
767          # Wrapping `call` function in autograph to allow for dynamic control
768          # flow and control dependencies in call. We are limiting this to
769          # subclassed layers as autograph is strictly needed only for
770          # subclassed layers and models.
771          # tf_convert will respect the value of autograph setting in the
772          # enclosing tf.function, if any.
773          if (base_layer_utils.is_subclassed(self) and
774              not base_layer_utils.from_saved_model(self)):
775            call_fn = autograph.tf_convert(
776                self.call, ag_ctx.control_status_ctx())
777          else:
778            call_fn = self.call
779
780          if not self.dynamic:
781            try:
782              with autocast_variable.enable_auto_cast_variables(
783                  self._compute_dtype_object):
784                outputs = call_fn(cast_inputs, *args, **kwargs)
785
786            except errors.OperatorNotAllowedInGraphError as e:
787              raise TypeError('You are attempting to use Python control '
788                              'flow in a layer that was not declared to be '
789                              'dynamic. Pass `dynamic=True` to the class '
790                              'constructor.\nEncountered error:\n"""\n' +
791                              str(e) + '\n"""')
792          else:
793            # We will use static shape inference to return symbolic tensors
794            # matching the specifications of the layer outputs.
795            # Since `self.dynamic` is True, we will never attempt to
796            # run the underlying TF graph (which is disconnected).
797            # TODO(fchollet): consider py_func as an alternative, which
798            # would enable us to run the underlying graph if needed.
799            outputs = self._symbolic_call(inputs)
800
801          if outputs is None:
802            raise ValueError('A layer\'s `call` method should return a '
803                             'Tensor or a list of Tensors, not None '
804                             '(layer: ' + self.name + ').')
805          if base_layer_utils.have_all_keras_metadata(inputs):
806            if training_arg_passed_by_framework:
807              args, kwargs = self._set_call_arg_value(
808                  'training', None, args, kwargs, pop_kwarg_if_none=True)
809            if mask_arg_passed_by_framework:
810              kwargs.pop('mask')
811            outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
812                                                      outputs)
813          self._handle_activity_regularization(inputs, outputs)
814          self._set_mask_metadata(inputs, outputs, input_masks)
815          if hasattr(self, '_set_inputs') and not self.inputs:
816            # Subclassed network: explicitly set metadata normally set by
817            # a call to self._set_inputs().
818            # TODO(b/120997007): This should be done in Eager as well, but
819            # causes garbage collection issues because of the placeholders
820            # created on the default Keras graph.
821            self._set_inputs(inputs, outputs)
822      else:
823        # Eager execution on data tensors.
824        with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
825          self._maybe_build(inputs)
826          cast_inputs = self._maybe_cast_inputs(inputs)
827          with autocast_variable.enable_auto_cast_variables(
828              self._compute_dtype_object):
829            outputs = self.call(cast_inputs, *args, **kwargs)
830          self._handle_activity_regularization(inputs, outputs)
831          self._set_mask_metadata(inputs, outputs, input_masks)
832
833    return outputs
834
835  def _assert_built_as_v1(self):
836    if not hasattr(self, '_originally_built_as_v1'):
837      raise ValueError(
838          'Your Layer or Model is in an invalid state. '
839          'This can happen for the following cases:\n '
840          '1. You might be interleaving estimator/non-estimator models or '
841          'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
842          'with models/layers created outside of it. '
843          'Converting a model to an estimator (via model_to_estimator) '
844          'invalidates all models/layers made before the conversion (even '
845          'if they were not the model converted to an estimator). '
846          'Similarly, making a layer or a model inside a '
847          'a tf.compat.v1.Graph invalidates all layers/models you previously '
848          'made outside of the graph.\n'
849          '2. You might be using a custom keras layer implementation with '
850          ' custom __init__ which didn\'t call super().__init__. '
851          ' Please check the implementation of %s and its bases.' %
852          (type(self),))
853
854  @property
855  def dtype(self):
856    return self._dtype_policy.variable_dtype
857
858  @property
859  def name(self):
860    return self._name
861
862  @property
863  def dynamic(self):
864    return any(layer._dynamic for layer in self._flatten_layers())
865
866  @property
867  @doc_controls.do_not_generate_docs
868  def stateful(self):
869    return any(layer._stateful for layer in self._flatten_layers())
870
871  @stateful.setter
872  def stateful(self, value):
873    self._stateful = value
874
875  @property
876  def trainable(self):
877    return self._trainable
878
879  @trainable.setter
880  def trainable(self, value):
881    self._trainable = value
882    for layer in getattr(self, '_self_tracked_trackables', []):
883      layer.trainable = value
884
885  @property
886  def activity_regularizer(self):
887    """Optional regularizer function for the output of this layer."""
888    return self._activity_regularizer
889
890  @activity_regularizer.setter
891  def activity_regularizer(self, regularizer):
892    """Optional regularizer function for the output of this layer."""
893    self._activity_regularizer = regularizer
894
895  @property
896  def input_spec(self):
897    return self._input_spec
898
899  @input_spec.setter
900  # Must be decorated to prevent tracking, since the input_spec can be nested
901  # InputSpec objects.
902  @trackable.no_automatic_dependency_tracking
903  def input_spec(self, value):
904    for v in nest.flatten(value):
905      if v is not None and not isinstance(v, base_layer.InputSpec):
906        raise TypeError('Layer input_spec must be an instance of InputSpec. '
907                        'Got: {}'.format(v))
908    self._input_spec = value
909
910  @property
911  def updates(self):
912    collected_updates = []
913    all_layers = self._flatten_layers()
914    with backend.get_graph().as_default():
915      for layer in all_layers:
916        if not layer.trainable and not layer.stateful:
917          continue
918        for u in layer._updates:
919          if callable(u):
920            try:
921              u = u()
922            except ValueError as e:
923              if 'InaccessibleTensorError' in type(e).__name__:
924                # For one specific case of error we try to raise
925                # a more meaningful error message about the graph if we can.
926                # This error is an internal TF symbol that is not
927                # publicly exposed, so we check the name directly rather
928                # than using a direct import.
929                base_layer_utils.check_graph_consistency(
930                    method='add_update', force_raise=True)
931              raise  # check_graph_consistency may not always raise.
932          base_layer_utils.check_graph_consistency(u, method='add_update')
933          collected_updates.append(u)
934    return collected_updates
935
936  @property
937  def losses(self):
938    """Losses which are associated with this `Layer`.
939
940    Variable regularization tensors are created when this property is accessed,
941    so it is eager safe: accessing `losses` under a `tf.GradientTape` will
942    propagate gradients back to the corresponding variables.
943
944    Returns:
945      A list of tensors.
946    """
947    collected_losses = []
948    all_layers = self._flatten_layers()
949    for layer in all_layers:
950      # If any eager losses are present, we assume the model to be part of an
951      # eager training loop (either a custom one or the one used when
952      # `run_eagerly=True`) and so we always return just the eager losses.
953      collected_losses.extend(layer._losses)
954      for regularizer in layer._callable_losses:
955        loss_tensor = regularizer()
956        if loss_tensor is not None:
957          collected_losses.append(loss_tensor)
958    return collected_losses
959
960  @doc_controls.for_subclass_implementers
961  def add_loss(self, losses, inputs=None):
962    """Add loss tensor(s), potentially dependent on layer inputs.
963
964    Some losses (for instance, activity regularization losses) may be dependent
965    on the inputs passed when calling a layer. Hence, when reusing the same
966    layer on different inputs `a` and `b`, some entries in `layer.losses` may
967    be dependent on `a` and some on `b`. This method automatically keeps track
968    of dependencies.
969
970    This method can be used inside a subclassed layer or model's `call`
971    function, in which case `losses` should be a Tensor or list of Tensors.
972
973    Example:
974
975    ```python
976    class MyLayer(tf.keras.layers.Layer):
977      def call(inputs, self):
978        self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True)
979        return inputs
980    ```
981
982    This method can also be called directly on a Functional Model during
983    construction. In this case, any loss Tensors passed to this Model must
984    be symbolic and be able to be traced back to the model's `Input`s. These
985    losses become part of the model's topology and are tracked in `get_config`.
986
987    Example:
988
989    ```python
990    inputs = tf.keras.Input(shape=(10,))
991    x = tf.keras.layers.Dense(10)(inputs)
992    outputs = tf.keras.layers.Dense(1)(x)
993    model = tf.keras.Model(inputs, outputs)
994    # Activity regularization.
995    model.add_loss(tf.abs(tf.reduce_mean(x)))
996    ```
997
998    If this is not the case for your loss (if, for example, your loss references
999    a `Variable` of one of the model's layers), you can wrap your loss in a
1000    zero-argument lambda. These losses are not tracked as part of the model's
1001    topology since they can't be serialized.
1002
1003    Example:
1004
1005    ```python
1006    inputs = tf.keras.Input(shape=(10,))
1007    x = tf.keras.layers.Dense(10)(inputs)
1008    outputs = tf.keras.layers.Dense(1)(x)
1009    model = tf.keras.Model(inputs, outputs)
1010    # Weight regularization.
1011    model.add_loss(lambda: tf.reduce_mean(x.kernel))
1012    ```
1013
1014    The `get_losses_for` method allows to retrieve the losses relevant to a
1015    specific set of inputs.
1016
1017    Args:
1018      losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
1019        may also be zero-argument callables which create a loss tensor.
1020      inputs: Ignored when executing eagerly. If anything other than None is
1021        passed, it signals the losses are conditional on some of the layer's
1022        inputs, and thus they should only be run where these inputs are
1023        available. This is the case for activity regularization losses, for
1024        instance. If `None` is passed, the losses are assumed
1025        to be unconditional, and will apply across all dataflows of the layer
1026        (e.g. weight regularization losses).
1027    """
1028    def _tag_unconditional(loss):
1029      """Process the loss and tag it by setting loss._unconditional_loss."""
1030      if callable(loss):
1031        # We run the loss without autocasting, as regularizers are often
1032        # numerically unstable in float16.
1033        with autocast_variable.enable_auto_cast_variables(None):
1034          loss = loss()
1035      if loss is None:
1036        return None  # Will be filtered out when computing the .losses property
1037      if not tensor_util.is_tf_type(loss):
1038        loss = ops.convert_to_tensor_v2_with_dispatch(
1039            loss, dtype=backend.floatx())
1040      loss._unconditional_loss = (inputs is None)  # pylint: disable=protected-access
1041      return loss
1042
1043    losses = nest.flatten(losses)
1044
1045    callable_losses = []
1046    symbolic_losses = []
1047    for loss in losses:
1048      if callable(loss):
1049        callable_losses.append(functools.partial(_tag_unconditional, loss))
1050        continue
1051      if loss is None:
1052        continue
1053      if not tensor_util.is_tf_type(loss):
1054        loss = ops.convert_to_tensor_v2_with_dispatch(
1055            loss, dtype=backend.floatx())
1056      # TF Functions should take the eager path.
1057      if (tf_utils.is_symbolic_tensor(loss) and
1058          not base_layer_utils.is_in_tf_function()):
1059        symbolic_losses.append(_tag_unconditional(loss))
1060        base_layer_utils.check_graph_consistency(loss, method='add_loss')
1061
1062    self._callable_losses.extend(callable_losses)
1063
1064    in_call_context = base_layer_utils.call_context().in_call
1065
1066    if in_call_context:
1067      for symbolic_loss in symbolic_losses:
1068        self._losses.append(symbolic_loss)
1069    else:
1070      for symbolic_loss in symbolic_losses:
1071        if getattr(self, '_is_graph_network', False):
1072          self._graph_network_add_loss(symbolic_loss)
1073        else:
1074          # Possible a loss was added in a Layer's `build`.
1075          self._losses.append(symbolic_loss)
1076
1077  @property
1078  def metrics(self):
1079    collected_metrics = []
1080    for layer in self._flatten_layers():
1081      collected_metrics.extend(layer._metrics)
1082    return collected_metrics
1083
1084  @doc_controls.for_subclass_implementers
1085  def add_metric(self, value, aggregation=None, name=None):
1086    """Adds metric tensor to the layer.
1087
1088    Args:
1089      value: Metric tensor.
1090      aggregation: Sample-wise metric reduction function. If `aggregation=None`,
1091        it indicates that the metric tensor provided has been aggregated
1092        already. eg, `bin_acc = BinaryAccuracy(name='acc')` followed by
1093        `model.add_metric(bin_acc(y_true, y_pred))`. If aggregation='mean', the
1094        given metric tensor will be sample-wise reduced using `mean` function.
1095        eg, `model.add_metric(tf.reduce_sum(outputs), name='output_mean',
1096        aggregation='mean')`.
1097      name: String metric name.
1098
1099    Raises:
1100      ValueError: If `aggregation` is anything other than None or `mean`.
1101    """
1102    if aggregation is not None and aggregation != 'mean':
1103      raise ValueError(
1104          'We currently support only `mean` sample-wise metric aggregation. '
1105          'You provided aggregation=`%s`' % aggregation)
1106
1107    from_metric_obj = hasattr(value, '_metric_obj')
1108    is_symbolic = tf_utils.is_symbolic_tensor(value)
1109    in_call_context = base_layer_utils.call_context().in_call
1110
1111    if name is None and not from_metric_obj:
1112      # Eg. `self.add_metric(math_ops.reduce_sum(x), aggregation='mean')`
1113      # In eager mode, we use metric name to lookup a metric. Without a name,
1114      # a new Mean metric wrapper will be created on every model/layer call.
1115      # So, we raise an error when no name is provided.
1116      # We will do the same for symbolic mode for consistency although a name
1117      # will be generated if no name is provided.
1118
1119      # We will not raise this error in the foll use case for the sake of
1120      # consistency as name in provided in the metric constructor.
1121      # mean = metrics.Mean(name='my_metric')
1122      # model.add_metric(mean(outputs))
1123      raise ValueError('Please provide a name for your metric like '
1124                       '`self.add_metric(tf.reduce_sum(inputs), '
1125                       'name=\'mean_activation\', aggregation=\'mean\')`')
1126    elif from_metric_obj:
1127      name = value._metric_obj.name
1128
1129    if in_call_context:
1130      # TF Function path should take the eager path.
1131      self._symbolic_add_metric(value, aggregation, name)
1132    else:
1133      if not is_symbolic:
1134        raise ValueError('Expected a symbolic Tensor for the metric value, '
1135                         'received: ' + str(value))
1136
1137      # Possible a metric was added in a Layer's `build`.
1138      if not getattr(self, '_is_graph_network', False):
1139        with backend.get_graph().as_default():
1140          self._symbolic_add_metric(value, aggregation, name)
1141        return
1142
1143      if from_metric_obj:
1144        raise ValueError('Using the result of calling a `Metric` object '
1145                         'when calling `add_metric` on a Functional '
1146                         'Model is not supported. Please pass the '
1147                         'Tensor to monitor directly.')
1148
1149      # Insert layers into the Keras Graph Network.
1150      self._graph_network_add_metric(value, aggregation, name)
1151
1152  @doc_controls.for_subclass_implementers
1153  def add_update(self, updates, inputs=None):
1154    """Add update op(s), potentially dependent on layer inputs.
1155
1156    Weight updates (for instance, the updates of the moving mean and variance
1157    in a BatchNormalization layer) may be dependent on the inputs passed
1158    when calling a layer. Hence, when reusing the same layer on
1159    different inputs `a` and `b`, some entries in `layer.updates` may be
1160    dependent on `a` and some on `b`. This method automatically keeps track
1161    of dependencies.
1162
1163    The `get_updates_for` method allows to retrieve the updates relevant to a
1164    specific set of inputs.
1165
1166    This call is ignored when eager execution is enabled (in that case, variable
1167    updates are run on the fly and thus do not need to be tracked for later
1168    execution).
1169
1170    Args:
1171      updates: Update op, or list/tuple of update ops, or zero-arg callable
1172        that returns an update op. A zero-arg callable should be passed in
1173        order to disable running the updates by setting `trainable=False`
1174        on this Layer, when executing in Eager mode.
1175      inputs: Deprecated, will be automatically inferred.
1176    """
1177    if inputs is not None:
1178      tf_logging.warning(
1179          '`add_update` `inputs` kwarg has been deprecated. You no longer need '
1180          'to pass a value to `inputs` as it is being automatically inferred.')
1181    call_context = base_layer_utils.call_context()
1182
1183    if (ds_context.has_strategy() and
1184        ds_context.in_cross_replica_context() and
1185        # When saving the model, the distribution strategy context should be
1186        # ignored, following the default path for adding updates.
1187        not call_context.saving):
1188      # Updates don't need to be run in a cross-replica context.
1189      return
1190
1191    updates = generic_utils.to_list(updates)
1192
1193    if call_context.in_call:
1194      relevant_inputs = call_context.inputs
1195    else:
1196      inbound_nodes = getattr(self, '_inbound_nodes', [])
1197      relevant_inputs = [node.input_tensors for node in inbound_nodes]
1198
1199    def process_update(x):
1200      """Standardize update ops.
1201
1202      Args:
1203        x: Tensor, op, or callable.
1204
1205      Returns:
1206        An update op.
1207      """
1208      if callable(x):
1209        update = lambda: process_update(x())
1210        return update()
1211      elif isinstance(x, ops.Operation):
1212        update = x
1213      elif hasattr(x, 'op'):
1214        update = x.op
1215      else:
1216        update = ops.convert_to_tensor_v2_with_dispatch(x)
1217
1218      reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, [update])
1219      update._unconditional_update = update not in reachable
1220      return update
1221
1222    updates = [process_update(x) for x in updates]
1223    self._updates.extend(updates)
1224
1225  def set_weights(self, weights):
1226    """Sets the weights of the layer, from Numpy arrays.
1227
1228    The weights of a layer represent the state of the layer. This function
1229    sets the weight values from numpy arrays. The weight values should be
1230    passed in the order they are created by the layer. Note that the layer's
1231    weights must be instantiated before calling this function by calling
1232    the layer.
1233
1234    For example, a Dense layer returns a list of two values-- per-output
1235    weights and the bias value. These can be used to set the weights of another
1236    Dense layer:
1237
1238    >>> a = tf.keras.layers.Dense(1,
1239    ...   kernel_initializer=tf.constant_initializer(1.))
1240    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
1241    >>> a.get_weights()
1242    [array([[1.],
1243           [1.],
1244           [1.]], dtype=float32), array([0.], dtype=float32)]
1245    >>> b = tf.keras.layers.Dense(1,
1246    ...   kernel_initializer=tf.constant_initializer(2.))
1247    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
1248    >>> b.get_weights()
1249    [array([[2.],
1250           [2.],
1251           [2.]], dtype=float32), array([0.], dtype=float32)]
1252    >>> b.set_weights(a.get_weights())
1253    >>> b.get_weights()
1254    [array([[1.],
1255           [1.],
1256           [1.]], dtype=float32), array([0.], dtype=float32)]
1257
1258    Args:
1259        weights: a list of Numpy arrays. The number
1260            of arrays and their shape must match
1261            number of the dimensions of the weights
1262            of the layer (i.e. it should match the
1263            output of `get_weights`).
1264
1265    Raises:
1266        ValueError: If the provided weights list does not match the
1267            layer's specifications.
1268    """
1269    params = self.weights
1270
1271    expected_num_weights = 0
1272    for param in params:
1273      if isinstance(param, base_layer_utils.TrackableWeightHandler):
1274        expected_num_weights += param.num_tensors
1275      else:
1276        expected_num_weights += 1
1277
1278    if expected_num_weights != len(weights):
1279      raise ValueError(
1280          'You called `set_weights(weights)` on layer "%s" '
1281          'with a weight list of length %s, but the layer was '
1282          'expecting %s weights. Provided weights: %s...' %
1283          (self.name, len(weights), expected_num_weights, str(weights)[:50]))
1284
1285    weight_index = 0
1286    weight_value_tuples = []
1287    for param in params:
1288      if isinstance(param, base_layer_utils.TrackableWeightHandler):
1289        num_tensors = param.num_tensors
1290        tensors = weights[weight_index:weight_index + num_tensors]
1291        param.set_weights(tensors)
1292        weight_index += num_tensors
1293      else:
1294        weight = weights[weight_index]
1295        weight_shape = weight.shape if hasattr(weight, 'shape') else ()
1296        ref_shape = param.shape
1297        if not ref_shape.is_compatible_with(weight_shape):
1298          raise ValueError(
1299              'Layer weight shape %s not compatible with provided weight '
1300              'shape %s' % (ref_shape, weight_shape))
1301        weight_value_tuples.append((param, weight))
1302        weight_index += 1
1303
1304    backend.batch_set_value(weight_value_tuples)
1305
1306  def get_weights(self):
1307    """Returns the current weights of the layer.
1308
1309    The weights of a layer represent the state of the layer. This function
1310    returns both trainable and non-trainable weight values associated with this
1311    layer as a list of Numpy arrays, which can in turn be used to load state
1312    into similarly parameterized layers.
1313
1314    For example, a Dense layer returns a list of two values-- per-output
1315    weights and the bias value. These can be used to set the weights of another
1316    Dense layer:
1317
1318    >>> a = tf.keras.layers.Dense(1,
1319    ...   kernel_initializer=tf.constant_initializer(1.))
1320    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
1321    >>> a.get_weights()
1322    [array([[1.],
1323           [1.],
1324           [1.]], dtype=float32), array([0.], dtype=float32)]
1325    >>> b = tf.keras.layers.Dense(1,
1326    ...   kernel_initializer=tf.constant_initializer(2.))
1327    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
1328    >>> b.get_weights()
1329    [array([[2.],
1330           [2.],
1331           [2.]], dtype=float32), array([0.], dtype=float32)]
1332    >>> b.set_weights(a.get_weights())
1333    >>> b.get_weights()
1334    [array([[1.],
1335           [1.],
1336           [1.]], dtype=float32), array([0.], dtype=float32)]
1337
1338    Returns:
1339        Weights values as a list of numpy arrays.
1340    """
1341    weights = self.weights
1342    output_weights = []
1343    for weight in weights:
1344      if isinstance(weight, base_layer_utils.TrackableWeightHandler):
1345        output_weights.extend(weight.get_tensors())
1346      else:
1347        output_weights.append(weight)
1348    return backend.batch_get_value(output_weights)
1349
1350  def get_updates_for(self, inputs):
1351    """Retrieves updates relevant to a specific set of inputs.
1352
1353    Args:
1354      inputs: Input tensor or list/tuple of input tensors.
1355
1356    Returns:
1357      List of update ops of the layer that depend on `inputs`.
1358    """
1359    if inputs is None:
1360      # Requesting unconditional updates.
1361      return [u for u in self.updates if u._unconditional_update]
1362
1363    # Requesting input-conditional updates.
1364    updates = [u for u in self.updates if not u._unconditional_update]
1365    inputs = nest.flatten(inputs)
1366    reachable = tf_utils.get_reachable_from_inputs(inputs, updates)
1367    return [u for u in updates if u in reachable]
1368
1369  def get_losses_for(self, inputs):
1370    """Retrieves losses relevant to a specific set of inputs.
1371
1372    Args:
1373      inputs: Input tensor or list/tuple of input tensors.
1374
1375    Returns:
1376      List of loss tensors of the layer that depend on `inputs`.
1377    """
1378    if inputs is None:
1379      # Requesting unconditional losses.
1380      return [l for l in self.losses if l._unconditional_loss]
1381
1382    # Requesting input-conditional losses.
1383    losses = [l for l in self.losses if not l._unconditional_loss]
1384    inputs = nest.flatten(inputs)
1385    reachable = tf_utils.get_reachable_from_inputs(inputs, losses)
1386    return [l for l in losses if l in reachable]
1387
1388  def get_input_mask_at(self, node_index):
1389    """Retrieves the input mask tensor(s) of a layer at a given node.
1390
1391    Args:
1392        node_index: Integer, index of the node
1393            from which to retrieve the attribute.
1394            E.g. `node_index=0` will correspond to the
1395            first time the layer was called.
1396
1397    Returns:
1398        A mask tensor
1399        (or list of tensors if the layer has multiple inputs).
1400    """
1401    inputs = self.get_input_at(node_index)
1402    if isinstance(inputs, list):
1403      return [getattr(x, '_keras_mask', None) for x in inputs]
1404    else:
1405      return getattr(inputs, '_keras_mask', None)
1406
1407  def get_output_mask_at(self, node_index):
1408    """Retrieves the output mask tensor(s) of a layer at a given node.
1409
1410    Args:
1411        node_index: Integer, index of the node
1412            from which to retrieve the attribute.
1413            E.g. `node_index=0` will correspond to the
1414            first time the layer was called.
1415
1416    Returns:
1417        A mask tensor
1418        (or list of tensors if the layer has multiple outputs).
1419    """
1420    output = self.get_output_at(node_index)
1421    if isinstance(output, list):
1422      return [getattr(x, '_keras_mask', None) for x in output]
1423    else:
1424      return getattr(output, '_keras_mask', None)
1425
1426  @property
1427  def input_mask(self):
1428    """Retrieves the input mask tensor(s) of a layer.
1429
1430    Only applicable if the layer has exactly one inbound node,
1431    i.e. if it is connected to one incoming layer.
1432
1433    Returns:
1434        Input mask tensor (potentially None) or list of input
1435        mask tensors.
1436
1437    Raises:
1438        AttributeError: if the layer is connected to
1439        more than one incoming layers.
1440    """
1441    inputs = self.input
1442    if isinstance(inputs, list):
1443      return [getattr(x, '_keras_mask', None) for x in inputs]
1444    else:
1445      return getattr(inputs, '_keras_mask', None)
1446
1447  @property
1448  def output_mask(self):
1449    """Retrieves the output mask tensor(s) of a layer.
1450
1451    Only applicable if the layer has exactly one inbound node,
1452    i.e. if it is connected to one incoming layer.
1453
1454    Returns:
1455        Output mask tensor (potentially None) or list of output
1456        mask tensors.
1457
1458    Raises:
1459        AttributeError: if the layer is connected to
1460        more than one incoming layers.
1461    """
1462    output = self.output
1463    if isinstance(output, list):
1464      return [getattr(x, '_keras_mask', None) for x in output]
1465    else:
1466      return getattr(output, '_keras_mask', None)
1467
1468  def get_input_shape_at(self, node_index):
1469    """Retrieves the input shape(s) of a layer at a given node.
1470
1471    Args:
1472        node_index: Integer, index of the node
1473            from which to retrieve the attribute.
1474            E.g. `node_index=0` will correspond to the
1475            first time the layer was called.
1476
1477    Returns:
1478        A shape tuple
1479        (or list of shape tuples if the layer has multiple inputs).
1480
1481    Raises:
1482      RuntimeError: If called in Eager mode.
1483    """
1484    return self._get_node_attribute_at_index(node_index, 'input_shapes',
1485                                             'input shape')
1486
1487  def get_output_shape_at(self, node_index):
1488    """Retrieves the output shape(s) of a layer at a given node.
1489
1490    Args:
1491        node_index: Integer, index of the node
1492            from which to retrieve the attribute.
1493            E.g. `node_index=0` will correspond to the
1494            first time the layer was called.
1495
1496    Returns:
1497        A shape tuple
1498        (or list of shape tuples if the layer has multiple outputs).
1499
1500    Raises:
1501      RuntimeError: If called in Eager mode.
1502    """
1503    return self._get_node_attribute_at_index(node_index, 'output_shapes',
1504                                             'output shape')
1505
1506  def get_input_at(self, node_index):
1507    """Retrieves the input tensor(s) of a layer at a given node.
1508
1509    Args:
1510        node_index: Integer, index of the node
1511            from which to retrieve the attribute.
1512            E.g. `node_index=0` will correspond to the
1513            first input node of the layer.
1514
1515    Returns:
1516        A tensor (or list of tensors if the layer has multiple inputs).
1517
1518    Raises:
1519      RuntimeError: If called in Eager mode.
1520    """
1521    return self._get_node_attribute_at_index(node_index, 'input_tensors',
1522                                             'input')
1523
1524  def get_output_at(self, node_index):
1525    """Retrieves the output tensor(s) of a layer at a given node.
1526
1527    Args:
1528        node_index: Integer, index of the node
1529            from which to retrieve the attribute.
1530            E.g. `node_index=0` will correspond to the
1531            first output node of the layer.
1532
1533    Returns:
1534        A tensor (or list of tensors if the layer has multiple outputs).
1535
1536    Raises:
1537      RuntimeError: If called in Eager mode.
1538    """
1539    return self._get_node_attribute_at_index(node_index, 'output_tensors',
1540                                             'output')
1541
1542  @property
1543  def input(self):
1544    """Retrieves the input tensor(s) of a layer.
1545
1546    Only applicable if the layer has exactly one input,
1547    i.e. if it is connected to one incoming layer.
1548
1549    Returns:
1550        Input tensor or list of input tensors.
1551
1552    Raises:
1553      RuntimeError: If called in Eager mode.
1554      AttributeError: If no inbound nodes are found.
1555    """
1556    if not self._inbound_nodes:
1557      raise AttributeError('Layer ' + self.name +
1558                           ' is not connected, no input to return.')
1559    return self._get_node_attribute_at_index(0, 'input_tensors', 'input')
1560
1561  @property
1562  def output(self):
1563    """Retrieves the output tensor(s) of a layer.
1564
1565    Only applicable if the layer has exactly one output,
1566    i.e. if it is connected to one incoming layer.
1567
1568    Returns:
1569      Output tensor or list of output tensors.
1570
1571    Raises:
1572      AttributeError: if the layer is connected to more than one incoming
1573        layers.
1574      RuntimeError: if called in Eager mode.
1575    """
1576    if not self._inbound_nodes:
1577      raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
1578    return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
1579
1580  @property
1581  def input_shape(self):
1582    """Retrieves the input shape(s) of a layer.
1583
1584    Only applicable if the layer has exactly one input,
1585    i.e. if it is connected to one incoming layer, or if all inputs
1586    have the same shape.
1587
1588    Returns:
1589        Input shape, as an integer shape tuple
1590        (or list of shape tuples, one tuple per input tensor).
1591
1592    Raises:
1593        AttributeError: if the layer has no defined input_shape.
1594        RuntimeError: if called in Eager mode.
1595    """
1596    if not self._inbound_nodes:
1597      raise AttributeError('The layer has never been called '
1598                           'and thus has no defined input shape.')
1599    all_input_shapes = set(
1600        [str(node.input_shapes) for node in self._inbound_nodes])
1601    if len(all_input_shapes) == 1:
1602      return self._inbound_nodes[0].input_shapes
1603    else:
1604      raise AttributeError('The layer "' + str(self.name) +
1605                           ' has multiple inbound nodes, '
1606                           'with different input shapes. Hence '
1607                           'the notion of "input shape" is '
1608                           'ill-defined for the layer. '
1609                           'Use `get_input_shape_at(node_index)` '
1610                           'instead.')
1611
1612  def count_params(self):
1613    """Count the total number of scalars composing the weights.
1614
1615    Returns:
1616        An integer count.
1617
1618    Raises:
1619        ValueError: if the layer isn't yet built
1620          (in which case its weights aren't yet defined).
1621    """
1622    if not self.built:
1623      if getattr(self, '_is_graph_network', False):
1624        with tf_utils.maybe_init_scope(self):
1625          self._maybe_build(self.inputs)
1626      else:
1627        raise ValueError('You tried to call `count_params` on ' + self.name +
1628                         ', but the layer isn\'t built. '
1629                         'You can build it manually via: `' + self.name +
1630                         '.build(batch_input_shape)`.')
1631    return layer_utils.count_params(self.weights)
1632
1633  @property
1634  def output_shape(self):
1635    """Retrieves the output shape(s) of a layer.
1636
1637    Only applicable if the layer has one output,
1638    or if all outputs have the same shape.
1639
1640    Returns:
1641        Output shape, as an integer shape tuple
1642        (or list of shape tuples, one tuple per output tensor).
1643
1644    Raises:
1645        AttributeError: if the layer has no defined output shape.
1646        RuntimeError: if called in Eager mode.
1647    """
1648    if not self._inbound_nodes:
1649      raise AttributeError('The layer has never been called '
1650                           'and thus has no defined output shape.')
1651    all_output_shapes = set(
1652        [str(node.output_shapes) for node in self._inbound_nodes])
1653    if len(all_output_shapes) == 1:
1654      return self._inbound_nodes[0].output_shapes
1655    else:
1656      raise AttributeError('The layer "%s"'
1657                           ' has multiple inbound nodes, '
1658                           'with different output shapes. Hence '
1659                           'the notion of "output shape" is '
1660                           'ill-defined for the layer. '
1661                           'Use `get_output_shape_at(node_index)` '
1662                           'instead.' % self.name)
1663
1664  @property
1665  @doc_controls.do_not_doc_inheritable
1666  def inbound_nodes(self):
1667    """Deprecated, do NOT use! Only for compatibility with external Keras."""
1668    return self._inbound_nodes
1669
1670  @property
1671  @doc_controls.do_not_doc_inheritable
1672  def outbound_nodes(self):
1673    """Deprecated, do NOT use! Only for compatibility with external Keras."""
1674    return self._outbound_nodes
1675
1676  ##############################################################################
1677  # Methods & attributes below are public aliases of other methods.            #
1678  ##############################################################################
1679
1680  @doc_controls.do_not_doc_inheritable
1681  def apply(self, inputs, *args, **kwargs):
1682    """Deprecated, do NOT use!
1683
1684    This is an alias of `self.__call__`.
1685
1686    Args:
1687      inputs: Input tensor(s).
1688      *args: additional positional arguments to be passed to `self.call`.
1689      **kwargs: additional keyword arguments to be passed to `self.call`.
1690
1691    Returns:
1692      Output tensor(s).
1693    """
1694    warnings.warn('`layer.apply` is deprecated and '
1695                  'will be removed in a future version. '
1696                  'Please use `layer.__call__` method instead.')
1697    return self.__call__(inputs, *args, **kwargs)
1698
1699  @doc_controls.do_not_doc_inheritable
1700  def add_variable(self, *args, **kwargs):
1701    """Deprecated, do NOT use! Alias for `add_weight`."""
1702    warnings.warn('`layer.add_variable` is deprecated and '
1703                  'will be removed in a future version. '
1704                  'Please use `layer.add_weight` method instead.')
1705    return self.add_weight(*args, **kwargs)
1706
1707  @property
1708  def variables(self):
1709    """Returns the list of all layer variables/weights.
1710
1711    Alias of `self.weights`.
1712
1713    Returns:
1714      A list of variables.
1715    """
1716    return self.weights
1717
1718  @property
1719  def trainable_variables(self):
1720    return self.trainable_weights
1721
1722  @property
1723  def non_trainable_variables(self):
1724    return self.non_trainable_weights
1725
1726  ##############################################################################
1727  # Methods & attributes below are all private and only used by the framework. #
1728  ##############################################################################
1729
1730  @property
1731  def _inbound_nodes(self):
1732    return self._inbound_nodes_value
1733
1734  @_inbound_nodes.setter
1735  @trackable.no_automatic_dependency_tracking
1736  def _inbound_nodes(self, value):
1737    self._inbound_nodes_value = value
1738
1739  @property
1740  def _outbound_nodes(self):
1741    return self._outbound_nodes_value
1742
1743  @_outbound_nodes.setter
1744  @trackable.no_automatic_dependency_tracking
1745  def _outbound_nodes(self, value):
1746    self._outbound_nodes_value = value
1747
1748  def _set_dtype_policy(self, dtype):
1749    """Sets self._dtype_policy."""
1750    if isinstance(dtype, policy.Policy):
1751      self._dtype_policy = dtype
1752    elif isinstance(dtype, dict):
1753      self._dtype_policy = policy.deserialize(dtype)
1754    elif isinstance(dtype, str) and dtype in ('mixed_float16',
1755                                              'mixed_bfloat16'):
1756      # The isinstance check is required since np.dtype raises an error if
1757      # compared to a non-dtype string.
1758      self._dtype_policy = policy.Policy(dtype)
1759    elif dtype:
1760      self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
1761    else:
1762      self._dtype_policy = policy.global_policy()
1763    if (self._dtype_policy.name == 'mixed_float16' and
1764        not loss_scale_optimizer.strategy_supports_loss_scaling()):
1765      # Although only loss scaling doesn't support certain strategies, to avoid
1766      # confusion, we disallow the 'mixed_float16' policy with unsupported
1767      # strategies. This is because 'mixed_float16' requires loss scaling for
1768      # numeric stability.
1769      strategy = ds_context.get_strategy()
1770      raise ValueError('Mixed precision is not supported with the '
1771                       'tf.distribute.Strategy: %s. Either stop using mixed '
1772                       'precision by removing the use of the "%s" policy or '
1773                       'use a different Strategy, e.g. a MirroredStrategy.' %
1774                       (strategy.__class__.__name__, self._dtype_policy.name))
1775
1776    # Performance optimization: cache the compute dtype as a Dtype object or
1777    # None, so that str to Dtype conversion doesn't happen in Layer.__call__.
1778    if self._dtype_policy.compute_dtype:
1779      self._compute_dtype_object = dtypes.as_dtype(
1780          self._dtype_policy.compute_dtype)
1781    else:
1782      self._compute_dtype_object = None
1783
1784  # TODO(reedwm): Expose this property?
1785  @property
1786  def _compute_dtype(self):
1787    """The layer's compute dtype.
1788
1789    Unless mixed-precision is used, this is the same as `Layer.dtype`.
1790
1791    If self._autocast is True, layer's will cast floating-point inputs to this.
1792
1793    Returns:
1794      The layer's compute dtype.
1795    """
1796    return self._dtype_policy.compute_dtype
1797
1798  def _maybe_cast_inputs(self, inputs):
1799    """Maybe casts the inputs to the compute dtype.
1800
1801    If self._compute_dtype is floating-point, and self_autocast is True,
1802    floating-point inputs are casted to self._compute_dtype.
1803
1804    Args:
1805      inputs: Input tensor, or structure of input tensors.
1806
1807    Returns:
1808      `inputs`, but tensors may have been casted to self._compute_dtype
1809    """
1810    compute_dtype = self._compute_dtype
1811    if (self._autocast and compute_dtype and
1812        dtypes.as_dtype(compute_dtype).is_floating):
1813      def f(x):
1814        """Cast a single Tensor or TensorSpec to the compute dtype."""
1815        cast_types = (ops.Tensor, sparse_tensor.SparseTensor,
1816                      ragged_tensor.RaggedTensor)
1817        if (isinstance(x, cast_types) and x.dtype.is_floating and
1818            x.dtype.base_dtype.name != compute_dtype):
1819          return math_ops.cast(x, compute_dtype)
1820        elif isinstance(x, tensor_spec.TensorSpec) and x.dtype.is_floating:
1821          # Inputs may be TensorSpecs when this function is called from
1822          # model._set_inputs.
1823          return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name)
1824        else:
1825          return x
1826      return nest.map_structure(f, inputs)
1827    else:
1828      return inputs
1829
1830  # _dtype used to be an attribute set in the constructor. We still expose it
1831  # because some clients still use it.
1832  # TODO(reedwm): Deprecate, then remove the _dtype property.
1833  @property
1834  def _dtype(self):
1835    # This is equivalent to returning self.dtype . We do not return self.dtype
1836    # as it would cause infinite recursion in a few subclasses, which override
1837    # "dtype" to return self._dtype.
1838    return self._dtype_policy.variable_dtype
1839
1840  @_dtype.setter
1841  def _dtype(self, value):
1842    value = dtypes.as_dtype(value).name
1843    self._set_dtype_policy(policy.Policy(value))
1844
1845  def _name_scope(self):  # pylint: disable=method-hidden
1846    return self.name
1847
1848  def _init_set_name(self, name, zero_based=True):
1849    if not name:
1850      self._name = backend.unique_object_name(
1851          generic_utils.to_snake_case(self.__class__.__name__),
1852          zero_based=zero_based)
1853    else:
1854      self._name = name
1855
1856  def _get_existing_metric(self, name=None):
1857    match = [m for m in self._metrics if m.name == name]
1858    if not match:
1859      return
1860    if len(match) > 1:
1861      raise ValueError(
1862          'Please provide different names for the metrics you have added. '
1863          'We found {} metrics with the name: "{}"'.format(len(match), name))
1864    return match[0]
1865
1866  def _symbolic_add_metric(self, value, aggregation=None, name=None):
1867    base_layer_utils.check_graph_consistency(value, method='add_metric')
1868    match = self._get_existing_metric(name)
1869    if aggregation is None:
1870      # Iterate over the metrics and check if the given metric exists already.
1871      # This can happen when a metric instance is created in subclassed model
1872      # layer `__init__` and we have tracked that instance already in
1873      # model.__setattr__.
1874      if match:
1875        result_tensor = value
1876        metric_obj = match
1877      elif hasattr(value, '_metric_obj'):
1878        # We track the instance using the metadata on the result tensor.
1879        result_tensor = value
1880        metric_obj = result_tensor._metric_obj
1881        self._metrics.append(metric_obj)
1882      else:
1883        raise ValueError(
1884            'We do not support adding an aggregated metric result tensor that '
1885            'is not the output of a `tf.keras.metrics.Metric` metric instance. '
1886            'Without having access to the metric instance we cannot reset the '
1887            'state of a metric after every epoch during training. You can '
1888            'create a `tf.keras.metrics.Metric` instance and pass the result '
1889            'here or pass an un-aggregated result with `aggregation` parameter '
1890            'set as `mean`. For example: `self.add_metric(tf.reduce_sum(inputs)'
1891            ', name=\'mean_activation\', aggregation=\'mean\')`')
1892    else:
1893      # If a non-aggregated tensor is given as input (ie. `aggregation` is
1894      # explicitly set to `mean`), we wrap the tensor in `Mean` metric.
1895      if match:
1896        result_tensor = match(value)
1897        metric_obj = match
1898      else:
1899        metric_obj, result_tensor = base_layer_utils.create_mean_metric(
1900            value, name)
1901        self._metrics.append(metric_obj)
1902
1903  def _handle_weight_regularization(self, name, variable, regularizer):
1904    """Create lambdas which compute regularization losses."""
1905
1906    def _loss_for_variable(v):
1907      """Creates a regularization loss `Tensor` for variable `v`."""
1908      with backend.name_scope(name + '/Regularizer'):
1909        regularization = regularizer(v)
1910      return regularization
1911
1912    if base_layer_utils.is_split_variable(variable):
1913      for v in variable:
1914        self.add_loss(functools.partial(_loss_for_variable, v))
1915    else:
1916      self.add_loss(functools.partial(_loss_for_variable, variable))
1917
1918  def _handle_activity_regularization(self, inputs, outputs):
1919    # Apply activity regularization.
1920    # Note that it should be applied every time the layer creates a new
1921    # output, since it is output-specific.
1922    if self._activity_regularizer:
1923      output_list = nest.flatten(outputs)
1924      with backend.name_scope('ActivityRegularizer'):
1925        for output in output_list:
1926          activity_loss = self._activity_regularizer(output)
1927          batch_size = math_ops.cast(
1928              array_ops.shape(output)[0], activity_loss.dtype)
1929          # Make activity regularization strength batch-agnostic.
1930          mean_activity_loss = activity_loss / batch_size
1931          base_layer_utils.check_graph_consistency(
1932              mean_activity_loss, method='activity_regularizer')
1933          self.add_loss(mean_activity_loss, inputs=inputs)
1934
1935  def _set_mask_metadata(self, inputs, outputs, previous_mask):
1936    flat_outputs = nest.flatten(outputs)
1937
1938    mask_already_computed = (
1939        getattr(self, '_compute_output_and_mask_jointly', False) or
1940        all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs))
1941
1942    # Only compute the mask if the Layer explicitly supports masking or has
1943    # overridden `compute_mask`.
1944    should_compute_mask = (
1945        hasattr(self, 'compute_mask') and
1946        (self.supports_masking or
1947         not getattr(self.compute_mask, '_is_default', False)))
1948
1949    if mask_already_computed:
1950      flat_masks = [getattr(x, '_keras_mask', None) for x in flat_outputs]
1951    elif not should_compute_mask:
1952      flat_masks = [None for _ in flat_outputs]
1953    else:
1954      output_masks = self.compute_mask(inputs, previous_mask)
1955      # `compute_mask` can return a single `None` even when a Layer
1956      # has multiple outputs.
1957      if output_masks is None:
1958        flat_masks = [None for _ in flat_outputs]
1959      else:
1960        flat_masks = nest.flatten(output_masks)
1961
1962    for output, mask in zip(flat_outputs, flat_masks):
1963      try:
1964        output._keras_mask = mask
1965      except AttributeError:
1966        # C Type such as np.ndarray.
1967        pass
1968
1969    if tf_utils.are_all_symbolic_tensors(flat_outputs):
1970      for output in flat_outputs:
1971        if getattr(output, '_keras_mask', None) is not None:
1972          # Do not track masks for `TensorFlowOpLayer` construction.
1973          output._keras_mask._keras_history_checked = True
1974
1975  def _collect_input_masks(self, inputs, args, kwargs):
1976    """Checks if `mask` argument was passed, else gathers mask from inputs."""
1977    if self._call_arg_was_passed('mask', args, kwargs):
1978      return self._get_call_arg_value('mask', args, kwargs)
1979
1980    if not self._should_compute_mask:
1981      return None
1982
1983    input_masks = nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
1984                                     inputs)
1985    if generic_utils.is_all_none(input_masks):
1986      return None
1987    return input_masks
1988
1989  def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
1990    if arg_name in kwargs:
1991      return True
1992    call_fn_args = self._call_fn_args
1993    if not inputs_in_args:
1994      # Ignore `inputs` arg.
1995      call_fn_args = call_fn_args[1:]
1996    if arg_name in dict(zip(call_fn_args, args)):
1997      return True
1998    return False
1999
2000  def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
2001    if arg_name in kwargs:
2002      return kwargs[arg_name]
2003    call_fn_args = self._call_fn_args
2004    if not inputs_in_args:
2005      # Ignore `inputs` arg.
2006      call_fn_args = call_fn_args[1:]
2007    args_dict = dict(zip(call_fn_args, args))
2008    return args_dict[arg_name]
2009
2010  def _set_call_arg_value(
2011      self, arg_name, new_value, args,
2012      kwargs, inputs_in_args=False, pop_kwarg_if_none=False):
2013    arg_pos = self._call_fn_arg_positions.get(arg_name, None)
2014    if arg_pos is not None:
2015      if not inputs_in_args:
2016        # Ignore `inputs` arg.
2017        arg_pos = arg_pos - 1
2018      if len(args) > arg_pos:
2019        args = list(args)
2020        args[arg_pos] = new_value
2021        return args, kwargs
2022    if new_value is None and pop_kwarg_if_none:
2023      kwargs.pop(arg_name, None)
2024    else:
2025      kwargs[arg_name] = new_value
2026    return args, kwargs
2027
2028  def _get_node_attribute_at_index(self, node_index, attr, attr_name):
2029    """Private utility to retrieves an attribute (e.g. inputs) from a node.
2030
2031    This is used to implement the methods:
2032        - get_input_shape_at
2033        - get_output_shape_at
2034        - get_input_at
2035        etc...
2036
2037    Args:
2038        node_index: Integer index of the node from which
2039            to retrieve the attribute.
2040        attr: Exact node attribute name.
2041        attr_name: Human-readable attribute name, for error messages.
2042
2043    Returns:
2044        The layer's attribute `attr` at the node of index `node_index`.
2045
2046    Raises:
2047        RuntimeError: If the layer has no inbound nodes, or if called in Eager
2048        mode.
2049        ValueError: If the index provided does not match any node.
2050    """
2051    if not self._inbound_nodes:
2052      raise RuntimeError('The layer has never been called '
2053                         'and thus has no defined ' + attr_name + '.')
2054    if not len(self._inbound_nodes) > node_index:
2055      raise ValueError('Asked to get ' + attr_name + ' at node ' +
2056                       str(node_index) + ', but the layer has only ' +
2057                       str(len(self._inbound_nodes)) + ' inbound nodes.')
2058    values = getattr(self._inbound_nodes[node_index], attr)
2059    if isinstance(values, list) and len(values) == 1:
2060      return values[0]
2061    else:
2062      return values
2063
2064  def _maybe_build(self, inputs):
2065    # Check input assumptions set before layer building, e.g. input rank.
2066    if not self.built:
2067      input_spec.assert_input_compatibility(
2068          self.input_spec, inputs, self.name)
2069      input_list = nest.flatten(inputs)
2070      if input_list and self._dtype_policy.compute_dtype is None:
2071        try:
2072          dtype = input_list[0].dtype.base_dtype.name
2073        except AttributeError:
2074          pass
2075        else:
2076          self._set_dtype_policy(policy.Policy(dtype))
2077      input_shapes = None
2078      if all(hasattr(x, 'shape') for x in input_list):
2079        input_shapes = nest.map_structure(lambda x: x.shape, inputs)
2080      # Only call `build` if the user has manually overridden the build method.
2081      if not hasattr(self.build, '_is_default'):
2082        # Any setup work performed only once should happen in an `init_scope`
2083        # to avoid creating symbolic Tensors that will later pollute any eager
2084        # operations.
2085        with tf_utils.maybe_init_scope(self):
2086          self.build(input_shapes)
2087      # We must set also ensure that the layer is marked as built, and the build
2088      # shape is stored since user defined build functions may not be calling
2089      # `super.build()`
2090      Layer.build(self, input_shapes)
2091
2092    # Optionally load weight values specified at layer instantiation.
2093    if self._initial_weights is not None:
2094      self.set_weights(self._initial_weights)
2095      self._initial_weights = None
2096
2097  def _symbolic_call(self, inputs):
2098    input_shapes = nest.map_structure(lambda x: x.shape, inputs)
2099    output_shapes = self.compute_output_shape(input_shapes)
2100
2101    def _make_placeholder_like(shape):
2102      ph = backend.placeholder(shape=shape, dtype=self.dtype)
2103      ph._keras_mask = None
2104      return ph
2105
2106    return nest.map_structure(_make_placeholder_like, output_shapes)
2107
2108  def _get_trainable_state(self):
2109    """Get the `trainable` state of each sublayer.
2110
2111    Returns:
2112      A dict mapping all sublayers to their `trainable` value.
2113    """
2114    layers = self._flatten_layers(include_self=False, recursive=False)
2115    trainable_state = {self: self.trainable}
2116    for l in layers:
2117      trainable_state.update(l._get_trainable_state())
2118    return trainable_state
2119
2120  def _set_trainable_state(self, trainable_state):
2121    """Set `trainable` state for each sublayer."""
2122    if self in trainable_state:
2123      self.trainable = trainable_state[self]
2124    layers = self._flatten_layers(include_self=False, recursive=False)
2125    for l in layers:
2126      if l in trainable_state:
2127        l._set_trainable_state(trainable_state)
2128
2129  @property
2130  def _obj_reference_counts(self):
2131    """A dictionary counting the number of attributes referencing an object."""
2132    self._maybe_create_attribute('_obj_reference_counts_dict',
2133                                 object_identity.ObjectIdentityDictionary())
2134    return self._obj_reference_counts_dict
2135
2136  @trackable.no_automatic_dependency_tracking
2137  def _maybe_create_attribute(self, name, default_value):
2138    """Create the attribute with the default value if it hasn't been created.
2139
2140    This is useful for fields that is used for tracking purpose,
2141    _trainable_weights, or _layers. Note that user could create a layer subclass
2142    and assign an internal field before invoking the Layer.__init__(), the
2143    __setattr__() need to create the tracking fields and __init__() need to not
2144    override them.
2145
2146    Args:
2147      name: String, the name of the attribute.
2148      default_value: Object, the default value of the attribute.
2149    """
2150    if not hasattr(self, name):
2151      self.__setattr__(name, default_value)
2152
2153  def __delattr__(self, name):
2154    # For any super.__delattr__() call, we will directly use the implementation
2155    # in Trackable and skip the behavior in AutoTrackable. The Layer was
2156    # originally use Trackable as base class, the change of using Module as base
2157    # class forced us to have AutoTrackable in the class hierarchy.
2158    #
2159    # TODO(b/180760306) Keeping the status quo of skipping _delattr__ and
2160    # __setattr__ in AutoTrackable may be unsustainable.
2161    existing_value = getattr(self, name, None)
2162
2163    # If this value is replacing an existing object assigned to an attribute, we
2164    # should clean it out to avoid leaking memory. First we check if there are
2165    # other attributes referencing it.
2166    reference_counts = self._obj_reference_counts
2167    if existing_value not in reference_counts:
2168      super(autotrackable.AutoTrackable, self).__delattr__(name)  # pylint: disable=bad-super-call
2169      return
2170
2171    reference_count = reference_counts[existing_value]
2172    if reference_count > 1:
2173      # There are other remaining references. We can't remove this object from
2174      # _layers etc.
2175      reference_counts[existing_value] = reference_count - 1
2176      super(autotrackable.AutoTrackable, self).__delattr__(name)  # pylint: disable=bad-super-call
2177      return
2178    else:
2179      # This is the last remaining reference.
2180      del reference_counts[existing_value]
2181
2182    super(autotrackable.AutoTrackable, self).__delattr__(name)  # pylint: disable=bad-super-call
2183
2184    if (isinstance(existing_value, Layer)
2185        or base_layer_utils.has_weights(existing_value)):
2186      super(autotrackable.AutoTrackable, self).__setattr__(  # pylint: disable=bad-super-call
2187          '_self_tracked_trackables',
2188          [l for l in self._self_tracked_trackables if l is not existing_value])
2189    if isinstance(existing_value, tf_variables.Variable):
2190      super(autotrackable.AutoTrackable, self).__setattr__(  # pylint: disable=bad-super-call
2191          '_trainable_weights',
2192          [w for w in self._trainable_weights if w is not existing_value])
2193      super(autotrackable.AutoTrackable, self).__setattr__(  # pylint: disable=bad-super-call
2194          '_non_trainable_weights',
2195          [w for w in self._non_trainable_weights if w is not existing_value])
2196
2197  def __setattr__(self, name, value):
2198    if (name == '_self_setattr_tracking' or
2199        not getattr(self, '_self_setattr_tracking', True) or
2200        # Exclude @property.setters from tracking
2201        hasattr(self.__class__, name)):
2202      try:
2203        super(autotrackable.AutoTrackable, self).__setattr__(name, value)  # pylint: disable=bad-super-call
2204      except AttributeError:
2205        raise AttributeError(
2206            ('Can\'t set the attribute "{}", likely because it conflicts with '
2207             'an existing read-only @property of the object. Please choose a '
2208             'different name.').format(name))
2209      return
2210
2211    # Keep track of trackable objects, for the needs of `Network.save_weights`.
2212    value = data_structures.sticky_attribute_assignment(
2213        trackable=self, value=value, name=name)
2214
2215    reference_counts = self._obj_reference_counts
2216    reference_counts[value] = reference_counts.get(value, 0) + 1
2217
2218    # Clean out the old attribute, which clears _layers and _trainable_weights
2219    # if necessary.
2220    try:
2221      self.__delattr__(name)
2222    except AttributeError:
2223      pass
2224
2225    # Keep track of metric instance created in subclassed layer.
2226    from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
2227    for val in nest.flatten(value):
2228      if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'):
2229        self._metrics.append(val)
2230
2231    # TODO(scottzhu): Need to track Module object as well for weight tracking.
2232    # Be careful about metric if it becomes a Module in future.
2233    # Append value to self._layers if relevant
2234    if (getattr(self, '_auto_track_sub_layers', True) and
2235        (isinstance(value, Layer) or base_layer_utils.has_weights(value))):
2236      self._maybe_create_attribute('_self_tracked_trackables', [])
2237      # We need to check object identity to avoid de-duplicating empty
2238      # container types which compare equal.
2239      if not any((layer is value for layer in self._self_tracked_trackables)):
2240        self._self_tracked_trackables.append(value)
2241        if hasattr(value, '_use_resource_variables'):
2242          # Legacy layers (V1 tf.layers) must always use
2243          # resource variables.
2244          value._use_resource_variables = True
2245
2246    # Append value to list of trainable / non-trainable weights if relevant
2247    # TODO(b/125122625): This won't pick up on any variables added to a
2248    # list/dict after creation.
2249    for val in nest.flatten(value):
2250      if not isinstance(val, tf_variables.Variable):
2251        continue
2252
2253      # Users may add extra weights/variables
2254      # simply by assigning them to attributes (invalid for graph networks)
2255      self._maybe_create_attribute('_trainable_weights', [])
2256      self._maybe_create_attribute('_non_trainable_weights', [])
2257      if val.trainable:
2258        if any(val is w for w in self._trainable_weights):
2259          continue
2260        self._trainable_weights.append(val)
2261      else:
2262        if any(val is w for w in self._non_trainable_weights):
2263          continue
2264        self._non_trainable_weights.append(val)
2265
2266      backend.track_variable(val)
2267
2268    # TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
2269    # quo. See the comment at __delattr__.
2270    super(autotrackable.AutoTrackable, self).__setattr__(name, value)  # pylint: disable=bad-super-call
2271
2272  # This is a hack so that the is_layer (within
2273  # training/trackable/layer_utils.py) check doesn't get the weights attr.
2274  # TODO(b/110718070): Remove when fixed.
2275  def _is_layer(self):
2276    return True
2277
2278  def _init_call_fn_args(self, expects_training_arg=None):
2279    # Clear cached call function arguments.
2280    self.__class__._call_full_argspec.fget.cache.pop(self, None)
2281    self.__class__._call_fn_args.fget.cache.pop(self, None)
2282    self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)
2283
2284    call_fn_args = self._call_fn_args
2285    if expects_training_arg is None:
2286      self._expects_training_arg = ('training' in call_fn_args or
2287                                    self._call_accepts_kwargs)
2288    else:
2289      # Use value encoded into the metadata when loading from the SavedModel.
2290      self._expects_training_arg = expects_training_arg
2291    self._expects_mask_arg = ('mask' in call_fn_args or
2292                              self._call_accepts_kwargs)
2293
2294  @property
2295  @layer_utils.cached_per_instance
2296  def _call_full_argspec(self):
2297    # Argspec inspection is expensive and the call spec is used often, so it
2298    # makes sense to cache the result.
2299    return tf_inspect.getfullargspec(self.call)
2300
2301  @property
2302  @layer_utils.cached_per_instance
2303  def _call_fn_args(self):
2304    all_args = self._call_full_argspec.args
2305    # Scrub `self` that appears if a decorator was applied.
2306    if all_args and all_args[0] == 'self':
2307      return all_args[1:]
2308    return all_args
2309
2310  @property
2311  @layer_utils.cached_per_instance
2312  def _call_fn_arg_positions(self):
2313    call_fn_arg_positions = dict()
2314    for pos, arg in enumerate(self._call_fn_args):
2315      call_fn_arg_positions[arg] = pos
2316    return call_fn_arg_positions
2317
2318  @property
2319  @layer_utils.cached_per_instance
2320  def _call_accepts_kwargs(self):
2321    return self._call_full_argspec.varkw is not None
2322
2323  @property
2324  @layer_utils.cached_per_instance
2325  def _should_compute_mask(self):
2326    return ('mask' in self._call_fn_args or
2327            getattr(self, 'compute_mask', None) is not None)
2328
2329  def _dedup_weights(self, weights):
2330    """Dedupe weights while maintaining order as much as possible."""
2331    output, seen_ids = [], set()
2332    for w in weights:
2333      if id(w) not in seen_ids:
2334        output.append(w)
2335        # Track the Variable's identity to avoid __eq__ issues.
2336        seen_ids.add(id(w))
2337
2338    return output
2339
2340  # SavedModel properties. Please see keras/saving/saved_model for details.
2341
2342  @property
2343  def _trackable_saved_model_saver(self):
2344    return layer_serialization.LayerSavedModelSaver(self)
2345
2346  @property
2347  def _object_identifier(self):
2348    return self._trackable_saved_model_saver.object_identifier
2349
2350  @property
2351  def _tracking_metadata(self):
2352    return self._trackable_saved_model_saver.tracking_metadata
2353
2354  def _trackable_children(self, save_type='checkpoint', **kwargs):
2355    if save_type == 'savedmodel':
2356      cache = kwargs['cache']
2357      # TODO(b/213628533): This must be called before super() to ensure
2358      # that any input shape changes are applied before getting the config of
2359      # the model.
2360      children = self._trackable_saved_model_saver.trackable_children(cache)
2361    else:
2362      children = {}
2363    children.update(super()._trackable_children(save_type, **kwargs))
2364    return children
2365
2366  def __getstate__(self):
2367    # Override to support `copy.deepcopy` and pickling.
2368    # Thread-local objects cannot be copied in Python 3, so pop these.
2369    # Thread-local objects are used to cache losses in MirroredStrategy, and
2370    # so shouldn't be copied.
2371    state = self.__dict__.copy()
2372    state.pop('_thread_local', None)
2373    return state
2374
2375  def __setstate__(self, state):
2376    state['_thread_local'] = threading.local()
2377    # Bypass Trackable logic as `__dict__` already contains this info.
2378    object.__setattr__(self, '__dict__', state)
2379
2380
2381class KerasHistory(
2382    collections.namedtuple('KerasHistory',
2383                           ['layer', 'node_index', 'tensor_index'])):
2384  """Tracks the Layer call that created a Tensor, for Keras Graph Networks.
2385
2386  During construction of Keras Graph Networks, this metadata is added to
2387  each Tensor produced as the output of a Layer, starting with an
2388  `InputLayer`. This allows Keras to track how each Tensor was produced, and
2389  this information is later retraced by the `keras.engine.Network` class to
2390  reconstruct the Keras Graph Network.
2391
2392  Attributes:
2393    layer: The Layer that produced the Tensor.
2394    node_index: The specific call to the Layer that produced this Tensor. Layers
2395      can be called multiple times in order to share weights. A new node is
2396      created every time a Tensor is called.
2397    tensor_index: The output index for this Tensor. Always zero if the Layer
2398      that produced this Tensor only has one output. Nested structures of
2399      Tensors are deterministically assigned an index via `nest.flatten`.
2400  """
2401  # Added to maintain memory and performance characteristics of `namedtuple`
2402  # while subclassing.
2403  __slots__ = ()
2404
2405
2406# Avoid breaking users who directly import this symbol from this file.
2407# TODO(fchollet): remove this.
2408InputSpec = input_spec.InputSpec  # pylint:disable=invalid-name
2409