xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/base_layer.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Contains the base Layer class, from which all layers inherit."""
17
18import collections
19import copy
20import functools
21import itertools
22import threading
23import warnings
24import weakref
25
26import numpy as np
27
28from google.protobuf import json_format
29from tensorflow.core.framework import node_def_pb2
30from tensorflow.python import tf2
31from tensorflow.python.autograph.core import ag_ctx
32from tensorflow.python.autograph.impl import api as autograph
33from tensorflow.python.distribute import distribution_strategy_context as ds_context
34from tensorflow.python.eager import backprop
35from tensorflow.python.eager import context
36from tensorflow.python.eager import def_function
37from tensorflow.python.framework import constant_op
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import func_graph
40from tensorflow.python.framework import ops
41from tensorflow.python.framework import sparse_tensor
42from tensorflow.python.framework import tensor_spec
43from tensorflow.python.framework import tensor_util
44from tensorflow.python.keras import backend
45from tensorflow.python.keras import constraints
46from tensorflow.python.keras import initializers
47from tensorflow.python.keras import regularizers
48from tensorflow.python.keras.engine import base_layer_utils
49from tensorflow.python.keras.engine import input_spec
50from tensorflow.python.keras.engine import keras_tensor
51from tensorflow.python.keras.engine import node as node_module
52from tensorflow.python.keras.mixed_precision import autocast_variable
53from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
54from tensorflow.python.keras.mixed_precision import policy
55from tensorflow.python.keras.saving.saved_model import layer_serialization
56from tensorflow.python.keras.utils import generic_utils
57from tensorflow.python.keras.utils import layer_utils
58from tensorflow.python.keras.utils import object_identity
59from tensorflow.python.keras.utils import tf_inspect
60from tensorflow.python.keras.utils import tf_utils
61from tensorflow.python.keras.utils import version_utils
62# Modules that only depend on `keras.layers` import these from here.
63from tensorflow.python.keras.utils.generic_utils import to_snake_case  # pylint: disable=unused-import
64from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list  # pylint: disable=unused-import
65
66from tensorflow.python.module import module
67from tensorflow.python.ops import array_ops
68from tensorflow.python.ops import math_ops
69from tensorflow.python.ops import variables as tf_variables
70from tensorflow.python.ops.numpy_ops import np_arrays
71from tensorflow.python.ops.ragged import ragged_tensor
72from tensorflow.python.platform import tf_logging
73from tensorflow.python.trackable import autotrackable
74from tensorflow.python.trackable import base as trackable
75from tensorflow.python.trackable import data_structures
76from tensorflow.python.util import compat
77from tensorflow.python.util import nest
78from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
79from tensorflow.python.util.tf_export import keras_export
80from tensorflow.tools.docs import doc_controls
81
82# pylint: disable=g-inconsistent-quotes
83metrics_mod = generic_utils.LazyLoader(
84    "metrics_mod", globals(),
85    "tensorflow.python.keras.metrics")
86# pylint: enable=g-inconsistent-quotes
87
88# Prefix that is added to the TF op layer names.
89_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_'
90
91# TODO(mdan): Should we have a single generic type for types that can be passed
92# to tf.cast?
93_AUTOCAST_TYPES = (ops.Tensor, sparse_tensor.SparseTensor,
94                   ragged_tensor.RaggedTensor)
95
96
97@keras_export('keras.layers.Layer')
98class Layer(module.Module, version_utils.LayerVersionSelector):
99  """This is the class from which all layers inherit.
100
101  A layer is a callable object that takes as input one or more tensors and
102  that outputs one or more tensors. It involves *computation*, defined
103  in the `call()` method, and a *state* (weight variables), defined
104  either in the constructor `__init__()` or in the `build()` method.
105
106  Users will just instantiate a layer and then treat it as a callable.
107
108  Args:
109    trainable: Boolean, whether the layer's variables should be trainable.
110    name: String name of the layer.
111    dtype: The dtype of the layer's computations and weights. Can also be a
112      `tf.keras.mixed_precision.Policy`, which allows the computation and weight
113      dtype to differ. Default of `None` means to use
114      `tf.keras.mixed_precision.global_policy()`, which is a float32 policy
115      unless set to a different value.
116    dynamic: Set this to `True` if your layer should only be run eagerly, and
117      should not be used to generate a static computation graph.
118      This would be the case for a Tree-RNN or a recursive network,
119      for example, or generally for any layer that manipulates tensors
120      using Python control flow. If `False`, we assume that the layer can
121      safely be used to generate a static computation graph.
122
123  Attributes:
124    name: The name of the layer (string).
125    dtype: The dtype of the layer's weights.
126    variable_dtype: Alias of `dtype`.
127    compute_dtype: The dtype of the layer's computations. Layers automatically
128      cast inputs to this dtype, which causes the computations and output to also
129      be in this dtype. When mixed precision is used with a
130      `tf.keras.mixed_precision.Policy`, this will be different from
131      `variable_dtype`.
132    dtype_policy: The layer's dtype policy. See the
133      `tf.keras.mixed_precision.Policy` documentation for details.
134    trainable_weights: List of variables to be included in backprop.
135    non_trainable_weights: List of variables that should not be
136      included in backprop.
137    weights: The concatenation of the lists trainable_weights and
138      non_trainable_weights (in this order).
139    trainable: Whether the layer should be trained (boolean), i.e. whether
140      its potentially-trainable weights should be returned as part of
141      `layer.trainable_weights`.
142    input_spec: Optional (list of) `InputSpec` object(s) specifying the
143      constraints on inputs that can be accepted by the layer.
144
145  We recommend that descendants of `Layer` implement the following methods:
146
147  * `__init__()`: Defines custom layer attributes, and creates layer state
148    variables that do not depend on input shapes, using `add_weight()`.
149  * `build(self, input_shape)`: This method can be used to create weights that
150    depend on the shape(s) of the input(s), using `add_weight()`. `__call__()`
151    will automatically build the layer (if it has not been built yet) by
152    calling `build()`.
153  * `call(self, inputs, *args, **kwargs)`: Called in `__call__` after making
154    sure `build()` has been called. `call()` performs the logic of applying the
155    layer to the input tensors (which should be passed in as argument).
156    Two reserved keyword arguments you can optionally use in `call()` are:
157      - `training` (boolean, whether the call is in inference mode or training
158        mode). See more details in [the layer/model subclassing guide](
159        https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_training_argument_in_the_call_method)
160      - `mask` (boolean tensor encoding masked timesteps in the input, used
161        in RNN layers). See more details in [the layer/model subclassing guide](
162        https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_mask_argument_in_the_call_method)
163    A typical signature for this method is `call(self, inputs)`, and users can
164    optionally add `training` and `mask` if the layer needs them. `*args` and
165    `**kwargs` are only useful for future extension, when more input
166    parameters are planned to be added.
167  * `get_config(self)`: Returns a dictionary containing the configuration used
168    to initialize this layer. If the keys differ from the arguments
169    in `__init__`, then override `from_config()` as well.
170    This method is used when saving
171    the layer or a model that contains this layer.
172
173  Examples:
174
175  Here's a basic example: a layer with two variables, `w` and `b`,
176  that returns `y = w . x + b`.
177  It shows how to implement `build()` and `call()`.
178  Variables set as attributes of a layer are tracked as weights
179  of the layers (in `layer.weights`).
180
181  ```python
182  class SimpleDense(Layer):
183
184    def __init__(self, units=32):
185        super(SimpleDense, self).__init__()
186        self.units = units
187
188    def build(self, input_shape):  # Create the state of the layer (weights)
189      w_init = tf.random_normal_initializer()
190      self.w = tf.Variable(
191          initial_value=w_init(shape=(input_shape[-1], self.units),
192                               dtype='float32'),
193          trainable=True)
194      b_init = tf.zeros_initializer()
195      self.b = tf.Variable(
196          initial_value=b_init(shape=(self.units,), dtype='float32'),
197          trainable=True)
198
199    def call(self, inputs):  # Defines the computation from inputs to outputs
200        return tf.matmul(inputs, self.w) + self.b
201
202  # Instantiates the layer.
203  linear_layer = SimpleDense(4)
204
205  # This will also call `build(input_shape)` and create the weights.
206  y = linear_layer(tf.ones((2, 2)))
207  assert len(linear_layer.weights) == 2
208
209  # These weights are trainable, so they're listed in `trainable_weights`:
210  assert len(linear_layer.trainable_weights) == 2
211  ```
212
213  Note that the method `add_weight()` offers a shortcut to create weights:
214
215  ```python
216  class SimpleDense(Layer):
217
218    def __init__(self, units=32):
219        super(SimpleDense, self).__init__()
220        self.units = units
221
222    def build(self, input_shape):
223        self.w = self.add_weight(shape=(input_shape[-1], self.units),
224                                 initializer='random_normal',
225                                 trainable=True)
226        self.b = self.add_weight(shape=(self.units,),
227                                 initializer='random_normal',
228                                 trainable=True)
229
230    def call(self, inputs):
231        return tf.matmul(inputs, self.w) + self.b
232  ```
233
234  Besides trainable weights, updated via backpropagation during training,
235  layers can also have non-trainable weights. These weights are meant to
236  be updated manually during `call()`. Here's an example layer that computes
237  the running sum of its inputs:
238
239  ```python
240  class ComputeSum(Layer):
241
242    def __init__(self, input_dim):
243        super(ComputeSum, self).__init__()
244        # Create a non-trainable weight.
245        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)),
246                                 trainable=False)
247
248    def call(self, inputs):
249        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
250        return self.total
251
252  my_sum = ComputeSum(2)
253  x = tf.ones((2, 2))
254
255  y = my_sum(x)
256  print(y.numpy())  # [2. 2.]
257
258  y = my_sum(x)
259  print(y.numpy())  # [4. 4.]
260
261  assert my_sum.weights == [my_sum.total]
262  assert my_sum.non_trainable_weights == [my_sum.total]
263  assert my_sum.trainable_weights == []
264  ```
265
266  For more information about creating layers, see the guide
267  [Making new Layers and Models via subclassing](
268    https://www.tensorflow.org/guide/keras/custom_layers_and_models)
269  """
270
271  # See tf.Module for the usage of this property.
272  # The key for _obj_reference_counts_dict is a Trackable, which could be a
273  # variable or layer etc. tf.Module._flatten will fail to flatten the key
274  # since it is trying to convert Trackable to a string. This attribute can be
275  # ignored even after the fix of nest lib, since the trackable object should
276  # already be available as individual attributes. _obj_reference_counts_dict
277  # just contains a copy of them.
278  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
279      ('_obj_reference_counts_dict',),
280      module.Module._TF_MODULE_IGNORED_PROPERTIES
281  ))
282
283  # When loading from a SavedModel, Layers typically can be revived into a
284  # generic Layer wrapper. Sometimes, however, layers may implement methods
285  # that go beyond this wrapper, as in the case of PreprocessingLayers'
286  # `adapt` method. When this is the case, layer implementers can override
287  # must_restore_from_config to return True; layers with this property must
288  # be restored into their actual objects (and will fail if the object is
289  # not available to the restoration code).
290  _must_restore_from_config = False
291
292  def _get_cell_name(self):
293    canonical_name = get_canonical_name_for_symbol(
294        self.__class__, api_name='keras', add_prefix_to_v1_names=True)
295    if canonical_name is not None:
296      return 'tf.{}'.format(canonical_name)
297    return self.__class__.__module__ + '.' + self.__class__.__name__
298
299  def _instrument_layer_creation(self):
300    self._instrumented_keras_api = False
301    self._instrumented_keras_layer_class = False
302    self._instrumented_keras_model_class = False
303    if not getattr(self, '_disable_keras_instrumentation', False):
304      self._instrumented_keras_api = True
305      if getattr(self, '_is_model_for_instrumentation', False):
306        self._instrumented_keras_model_class = True
307      else:
308        self._instrumented_keras_layer_class = True
309
310  @trackable.no_automatic_dependency_tracking
311  def __init__(self,
312               trainable=True,
313               name=None,
314               dtype=None,
315               dynamic=False,
316               **kwargs):
317    self._instrument_layer_creation()
318
319    # These properties should be set by the user via keyword arguments.
320    # note that 'dtype', 'input_shape' and 'batch_input_shape'
321    # are only applicable to input layers: do not pass these keywords
322    # to non-input layers.
323    allowed_kwargs = {
324        'input_dim',
325        'input_shape',
326        'batch_input_shape',
327        'batch_size',
328        'weights',
329        'activity_regularizer',
330        'autocast',
331        'implementation',
332    }
333    # Validate optional keyword arguments.
334    generic_utils.validate_kwargs(kwargs, allowed_kwargs)
335
336    # Mutable properties
337    # Indicates whether the layer's weights are updated during training
338    # and whether the layer's updates are run during training.
339    self._trainable = trainable
340    # A stateful layer is a layer whose updates are run during inference too,
341    # for instance stateful RNNs.
342    self._stateful = False
343    # Indicates whether `build` needs to be called upon layer call, to create
344    # the layer's weights.
345    self.built = False
346    # Provides information about which inputs are compatible with the layer.
347    self._input_spec = None
348
349    # SavedModel-related attributes.
350    # Record the build input shape for loading purposes.
351    # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is
352    # submitted.
353    self._build_input_shape = None
354    self._saved_model_inputs_spec = None
355
356    # `Layer.compute_mask` will be called at the end of `Layer.__call__` if
357    # `Layer.compute_mask` is overridden, or if the `Layer` subclass sets
358    # `self.supports_masking=True`.
359    self._supports_masking = not generic_utils.is_default(self.compute_mask)
360
361    self._init_set_name(name)
362    self._activity_regularizer = regularizers.get(
363        kwargs.pop('activity_regularizer', None))
364    self._maybe_create_attribute('_trainable_weights', [])
365    self._maybe_create_attribute('_non_trainable_weights', [])
366    self._updates = []
367    # Object to store all thread local layer properties.
368    self._thread_local = threading.local()
369    # A list of zero-argument lambdas which return Tensors, used for variable
370    # regularizers.
371    self._callable_losses = []
372    # A list of symbolic Tensors containing activity regularizers and losses
373    # manually added through `add_loss` in graph-building mode.
374    self._losses = []
375    # A list of metric instances corresponding to the symbolic metric tensors
376    # added using the `add_metric` API.
377    self._metrics = []
378    # Ensures the same metric is not added multiple times in `MirroredStrategy`.
379    self._metrics_lock = threading.Lock()
380
381    # Both graph and subclassed networks have a dtype policy. For graph
382    # networks, the policy's compute and variable dtypes are ignored. Such
383    # networks only use the policy if it is a PolicyV1, in which case it uses
384    # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
385    # subclassed networks, the compute and variable dtypes are used as in any
386    # ordinary layer.
387    self._set_dtype_policy(dtype)
388    # Boolean indicating whether the layer automatically casts its inputs to the
389    # layer's compute_dtype.
390    self._autocast = kwargs.get('autocast',
391                                base_layer_utils.v2_dtype_behavior_enabled())
392
393    # Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s.
394    # Ordered by when the object was assigned as an attr.
395    # Entries are unique.
396    self._maybe_create_attribute('_self_tracked_trackables', [])
397
398    # These lists will be filled via successive calls
399    # to self._add_inbound_node().
400    # Used in symbolic mode only, only in conjunction with graph-networks
401    self._inbound_nodes_value = []
402    self._outbound_nodes_value = []
403
404    self._init_call_fn_args()
405
406    # Whether the `call` method can be used to build a TF graph without issues.
407    # This attribute has no effect if the model is created using the Functional
408    # API. Instead, `model.dynamic` is determined based on the internal layers.
409    self._dynamic = dynamic
410
411    # Manage input shape information if passed.
412    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
413      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
414      kwargs['input_shape'] = (kwargs['input_dim'],)
415    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
416      # In this case we will later create an input layer
417      # to insert before the current layer
418      if 'batch_input_shape' in kwargs:
419        batch_input_shape = tuple(kwargs['batch_input_shape'])
420      elif 'input_shape' in kwargs:
421        if 'batch_size' in kwargs:
422          batch_size = kwargs['batch_size']
423        else:
424          batch_size = None
425        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
426      self._batch_input_shape = batch_input_shape
427
428    # Manage initial weight values if passed.
429    self._initial_weights = kwargs.get('weights', None)
430
431    # Whether the layer will track any layers that are set as attributes on
432    # itself as sub-layers; the weights from the sub-layers will be included
433    # in the parent layer's variables() as well.
434    # Defaults to True, which means auto tracking is turned on. Certain
435    # subclasses might want to turn it off, like the Sequential model.
436    self._auto_track_sub_layers = True
437
438    # For backwards compat reasons, most built-in layers do not guarantee
439    # that they will 100% preserve the structure of input args when saving
440    # / loading configs. E.g. they may un-nest an arg that is
441    # a list with one element.
442    self._preserve_input_structure_in_config = False
443
444  @trackable.no_automatic_dependency_tracking
445  @generic_utils.default
446  def build(self, input_shape):
447    """Creates the variables of the layer (optional, for subclass implementers).
448
449    This is a method that implementers of subclasses of `Layer` or `Model`
450    can override if they need a state-creation step in-between
451    layer instantiation and layer call.
452
453    This is typically used to create the weights of `Layer` subclasses.
454
455    Args:
456      input_shape: Instance of `TensorShape`, or list of instances of
457        `TensorShape` if the layer expects a list of inputs
458        (one instance per input).
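
    Example (a minimal sketch; `MyLayer` and its `kernel` weight are
    hypothetical names used for illustration):

    ```python
    class MyLayer(tf.keras.layers.Layer):

      def build(self, input_shape):
        # Create a kernel whose shape depends on the last input dimension.
        self.kernel = self.add_weight(
            'kernel', shape=(int(input_shape[-1]), 8))
        super(MyLayer, self).build(input_shape)
    ```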
459    """
460    # Only record the build input shapes of overridden build methods.
461    if not hasattr(self.build, '_is_default'):
462      self._build_input_shape = input_shape
463    self.built = True
464
465  @doc_controls.for_subclass_implementers
466  def call(self, inputs, *args, **kwargs):  # pylint: disable=unused-argument
467    """This is where the layer's logic lives.
468
469    Note here that the `call()` method in `tf.keras` is a little bit different
470    from the `keras` API. In the `keras` API, you could pass masking support
471    for layers as additional arguments, whereas `tf.keras` has a
472    `compute_mask()` method to support masking.
473
474    Args:
475      inputs: Input tensor, or dict/list/tuple of input tensors.
476        The first positional `inputs` argument is subject to special rules:
477        - `inputs` must be explicitly passed. A layer cannot have zero
478          arguments, and `inputs` cannot be provided via the default value
479          of a keyword argument.
480        - NumPy array or Python scalar values in `inputs` get cast as tensors.
481        - Keras mask metadata is only collected from `inputs`.
482        - Layers are built (`build(input_shape)` method)
483          using shape info from `inputs` only.
484        - `input_spec` compatibility is only checked against `inputs`.
485        - Mixed precision input casting is only applied to `inputs`.
486          If a layer has tensor arguments in `*args` or `**kwargs`, their
487          casting behavior in mixed precision should be handled manually.
488        - The SavedModel input specification is generated using `inputs` only.
489        - Integration with various ecosystem packages like TFMOT, TFLite,
490          TF.js, etc is only supported for `inputs` and not for tensors in
491          positional and keyword arguments.
492      *args: Additional positional arguments. May contain tensors, although
493        this is not recommended, for the reasons above.
494      **kwargs: Additional keyword arguments. May contain tensors, although
495        this is not recommended, for the reasons above.
496        The following optional keyword arguments are reserved:
497        - `training`: Boolean scalar tensor or Python boolean indicating
498          whether the `call` is meant for training or inference.
499        - `mask`: Boolean input mask. If the layer's `call()` method takes a
500          `mask` argument, its default value will be set to the mask generated
501          for `inputs` by the previous layer (if `inputs` did come from a layer
502          that generated a corresponding mask, i.e. if it came from a Keras
503          layer with masking support).
504
505    Returns:
506      A tensor or list/tuple of tensors.
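
    Example (a minimal sketch showing the reserved `training` argument;
    `MyDropoutLike` is a hypothetical layer name):

    ```python
    class MyDropoutLike(tf.keras.layers.Layer):

      def call(self, inputs, training=None):
        if training:
          return tf.nn.dropout(inputs, rate=0.5)
        return inputs
    ```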
507    """
508    return inputs
509
510  @doc_controls.for_subclass_implementers
511  def _add_trackable(self, trackable_object, trainable):
512    """Adds a Trackable object to this layer's state.
513
514    Args:
515      trackable_object: The tf.tracking.Trackable object to add.
516      trainable: Boolean, whether the variable should be part of the layer's
517        "trainable_variables" (e.g. variables, biases) or
518        "non_trainable_variables" (e.g. BatchNorm mean and variance).
519
520    Returns:
521      The TrackableWeightHandler used to track this object.
522    """
523    if isinstance(trackable_object, base_layer_utils.TrackableWeightHandler):
524      handler = trackable_object
525    else:
526      handler = base_layer_utils.TrackableWeightHandler(trackable_object)
527    if trainable:
528      self._trainable_weights.append(handler)
529    else:
530      self._non_trainable_weights.append(handler)
531    return handler
532
533  @doc_controls.for_subclass_implementers
534  def add_weight(self,
535                 name=None,
536                 shape=None,
537                 dtype=None,
538                 initializer=None,
539                 regularizer=None,
540                 trainable=None,
541                 constraint=None,
542                 use_resource=None,
543                 synchronization=tf_variables.VariableSynchronization.AUTO,
544                 aggregation=tf_variables.VariableAggregation.NONE,
545                 **kwargs):
546    """Adds a new variable to the layer.
547
548    Args:
549      name: Variable name.
550      shape: Variable shape. Defaults to scalar if unspecified.
551      dtype: The type of the variable. Defaults to `self.dtype`.
552      initializer: Initializer instance (callable).
553      regularizer: Regularizer instance (callable).
554      trainable: Boolean, whether the variable should be part of the layer's
555        "trainable_variables" (e.g. variables, biases)
556        or "non_trainable_variables" (e.g. BatchNorm mean and variance).
557        Note that `trainable` cannot be `True` if `synchronization`
558        is set to `ON_READ`.
559      constraint: Constraint instance (callable).
560      use_resource: Whether to use `ResourceVariable`.
561      synchronization: Indicates when a distributed variable will be
562        aggregated. Accepted values are constants defined in the class
563        `tf.VariableSynchronization`. By default the synchronization is set to
564        `AUTO` and the current `DistributionStrategy` chooses
565        when to synchronize. If `synchronization` is set to `ON_READ`,
566        `trainable` must not be set to `True`.
567      aggregation: Indicates how a distributed variable will be aggregated.
568        Accepted values are constants defined in the class
569        `tf.VariableAggregation`.
570      **kwargs: Additional keyword arguments. Accepted values are `getter`,
571        `collections`, `experimental_autocast` and `caching_device`.
572
573    Returns:
574      The variable created.
575
576    Raises:
577      ValueError: When giving an unsupported dtype and no initializer, or when
578        trainable has been set to True with synchronization set to `ON_READ`.
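
    Example (a minimal sketch; `ScaledLayer` and its `scale` weight are
    hypothetical names used for illustration):

    ```python
    class ScaledLayer(tf.keras.layers.Layer):

      def build(self, input_shape):
        # A trainable per-feature scale with an L2 penalty and a
        # non-negativity constraint.
        self.scale = self.add_weight(
            name='scale',
            shape=(int(input_shape[-1]),),
            initializer='ones',
            regularizer=tf.keras.regularizers.l2(1e-4),
            constraint=tf.keras.constraints.NonNeg(),
            trainable=True)

      def call(self, inputs):
        return inputs * self.scale
    ```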
579    """
580    if shape is None:
581      shape = ()
582    kwargs.pop('partitioner', None)  # Ignored.
583    # Validate optional keyword arguments.
584    for kwarg in kwargs:
585      if kwarg not in ['collections', 'experimental_autocast',
586                       'caching_device', 'getter']:
587        raise TypeError('Unknown keyword argument:', kwarg)
588    collections_arg = kwargs.pop('collections', None)
589    # 'experimental_autocast' can be set to False by the caller to indicate an
590    # AutoCastVariable should never be created.
591    autocast = kwargs.pop('experimental_autocast', True)
592    # See the docstring for tf.Variable about the details for caching_device.
593    caching_device = kwargs.pop('caching_device', None)
594
595    if dtype is None:
596      dtype = self.dtype or backend.floatx()
597    dtype = dtypes.as_dtype(dtype)
598    if self._dtype_policy.variable_dtype is None:
599      # The policy is "_infer", so we infer the policy from the variable dtype.
600      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
601    initializer = initializers.get(initializer)
602    regularizer = regularizers.get(regularizer)
603    constraint = constraints.get(constraint)
604
605    if synchronization == tf_variables.VariableSynchronization.ON_READ:
606      if trainable:
607        raise ValueError(
608            'Synchronization value can be set to '
609            'VariableSynchronization.ON_READ only for non-trainable variables. '
610            'You have specified trainable=True and '
611            'synchronization=VariableSynchronization.ON_READ.')
612      else:
613        # Set trainable to be false when variable is to be synced on read.
614        trainable = False
615    elif trainable is None:
616      trainable = True
617
618    # Initialize variable when no initializer provided
619    if initializer is None:
620      # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
621      if dtype.is_floating:
622        initializer = initializers.get('glorot_uniform')
623      # If dtype is DT_INT/DT_UINT, provide a default value `zero`
624      # If dtype is DT_BOOL, provide a default value `FALSE`
625      elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
626        initializer = initializers.get('zeros')
627      # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
628      elif 'getter' not in kwargs:
629        # When `getter` is specified, it's possibly fine for `initializer` to be
630        # None since it's up to the custom `getter` to raise error in case it
631        # indeed needs `initializer`.
632        raise ValueError('An initializer for variable %s of type %s is required'
633                         ' for layer %s' % (name, dtype.base_dtype, self.name))
634
635    getter = kwargs.pop('getter', base_layer_utils.make_variable)
636    if (autocast and
637        self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
638        and dtype.is_floating):
639      old_getter = getter
640      # Wrap variable constructor to return an AutoCastVariable.
641      def getter(*args, **kwargs):  # pylint: disable=function-redefined
642        variable = old_getter(*args, **kwargs)
643        return autocast_variable.create_autocast_variable(variable)
644      # Also the caching_device does not work with the mixed precision API,
645      # disable it if it is specified.
646      # TODO(b/142020079): Reenable it once the bug is fixed.
647      if caching_device is not None:
648        tf_logging.warning(
649            '`caching_device` does not work with mixed precision API. Ignoring '
650            'user specified `caching_device`.')
651        caching_device = None
652
653    variable = self._add_variable_with_custom_getter(
654        name=name,
655        shape=shape,
656        # TODO(allenl): a `make_variable` equivalent should be added as a
657        # `Trackable` method.
658        getter=getter,
659        # Manage errors in Layer rather than Trackable.
660        overwrite=True,
661        initializer=initializer,
662        dtype=dtype,
663        constraint=constraint,
664        trainable=trainable,
665        use_resource=use_resource,
666        collections=collections_arg,
667        synchronization=synchronization,
668        aggregation=aggregation,
669        caching_device=caching_device)
670    if regularizer is not None:
671      # TODO(fchollet): in the future, this should be handled at the
672      # level of variable creation, and weight regularization losses
673      # should be variable attributes.
674      name_in_scope = variable.name[:variable.name.find(':')]
675      self._handle_weight_regularization(name_in_scope,
676                                         variable,
677                                         regularizer)
678    if base_layer_utils.is_split_variable(variable):
679      for v in variable:
680        backend.track_variable(v)
681        if trainable:
682          self._trainable_weights.append(v)
683        else:
684          self._non_trainable_weights.append(v)
685    else:
686      backend.track_variable(variable)
687      if trainable:
688        self._trainable_weights.append(variable)
689      else:
690        self._non_trainable_weights.append(variable)
691    return variable
692
693  @generic_utils.default
694  def get_config(self):
695    """Returns the config of the layer.
696
697    A layer config is a Python dictionary (serializable)
698    containing the configuration of a layer.
699    The same layer can be reinstantiated later
700    (without its trained weights) from this configuration.
701
702    The config of a layer does not include connectivity
703    information, nor the layer class name. These are handled
704    by `Network` (one layer of abstraction above).
705
706    Note that `get_config()` is not guaranteed to return a fresh copy of the
707    dict every time it is called. Callers should make a copy of the returned
708    dict if they want to modify it.
709
710    Returns:
711        Python dictionary.
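
    Example (a minimal sketch of a typical override; `MyLayer` and its `units`
    argument are hypothetical):

    ```python
    class MyLayer(tf.keras.layers.Layer):

      def __init__(self, units=32, **kwargs):
        super(MyLayer, self).__init__(**kwargs)
        self.units = units

      def get_config(self):
        # Include the base config plus this layer's constructor arguments.
        config = super(MyLayer, self).get_config()
        config.update({'units': self.units})
        return config
    ```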
712    """
713    all_args = tf_inspect.getfullargspec(self.__init__).args
714    config = {
715        'name': self.name,
716        'trainable': self.trainable,
717    }
718    if hasattr(self, '_batch_input_shape'):
719      config['batch_input_shape'] = self._batch_input_shape
720    config['dtype'] = policy.serialize(self._dtype_policy)
721    if hasattr(self, 'dynamic'):
722      # Only include `dynamic` in the `config` if it is `True`
723      if self.dynamic:
724        config['dynamic'] = self.dynamic
725      elif 'dynamic' in all_args:
726        all_args.remove('dynamic')
727    expected_args = config.keys()
728    # Finds all arguments in the `__init__` that are not in the config:
729    extra_args = [arg for arg in all_args if arg not in expected_args]
730    # Check that either the only argument in the `__init__` is  `self`,
731    # or that `get_config` has been overridden:
732    if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
733      raise NotImplementedError('Layer %s has arguments in `__init__` and '
734                                'therefore must override `get_config`.' %
735                                self.__class__.__name__)
736    return config
737
738  @classmethod
739  def from_config(cls, config):
740    """Creates a layer from its config.
741
742    This method is the reverse of `get_config`,
743    capable of instantiating the same layer from the config
744    dictionary. It does not handle layer connectivity
745    (handled by Network), nor weights (handled by `set_weights`).
746
747    Args:
748        config: A Python dictionary, typically the
749            output of get_config.
750
751    Returns:
752        A layer instance.
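
    Example (a minimal round-trip sketch, assuming `layer` is an existing
    layer instance):

    ```python
    config = layer.get_config()
    new_layer = layer.__class__.from_config(config)
    ```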
753    """
754    return cls(**config)
755
756  def compute_output_shape(self, input_shape):
757    """Computes the output shape of the layer.
758
759    If the layer has not been built, this method will call `build` on the
760    layer. This assumes that the layer will later be used with inputs that
761    match the input shape provided here.
762
763    Args:
764        input_shape: Shape tuple (tuple of integers)
765            or list of shape tuples (one per input tensor of the layer).
766            Shape tuples can include None for free dimensions,
767            instead of an integer.
768
769    Returns:
770        An output shape tuple.
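
    Example (a minimal sketch; `TinyDense` and its `units` argument are
    hypothetical):

    ```python
    class TinyDense(tf.keras.layers.Layer):

      def __init__(self, units=8, **kwargs):
        super(TinyDense, self).__init__(**kwargs)
        self.units = units

      def compute_output_shape(self, input_shape):
        # Replace the last dimension with the number of units.
        input_shape = tf.TensorShape(input_shape).as_list()
        return tf.TensorShape(input_shape[:-1] + [self.units])
    ```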
771    """
772    if context.executing_eagerly():
773      # In this case we build the model first in order to do shape inference.
774      # This is acceptable because the framework only calls
775      # `compute_output_shape` on shape values that the layer would later be
776      # built for. It would however cause issues in case a user attempts to
777      # use `compute_output_shape` manually with shapes that are incompatible
778      # with the shape the Layer will be called on (these users will have to
779      # implement `compute_output_shape` themselves).
780      self._maybe_build(input_shape)
781      with func_graph.FuncGraph(str(self.name) + '_scratch_graph').as_default():
782        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
783        def _make_placeholder_like(shape):
784          ph = backend.placeholder(shape=shape, dtype=self.dtype)
785          ph._keras_mask = None
786          return ph
787        inputs = nest.map_structure(_make_placeholder_like, input_shape)
788        try:
789          outputs = self(inputs, training=False)
790        except TypeError as e:
791          raise NotImplementedError(
792              'We could not automatically infer the static shape of the '
793              'layer\'s output. Please implement the '
794              '`compute_output_shape` method on your layer (%s).' %
795              self.__class__.__name__) from e
796      return nest.map_structure(lambda t: t.shape, outputs)
797    raise NotImplementedError(
798        'Please run in eager mode or implement the `compute_output_shape` '
799        'method on your layer (%s).' % self.__class__.__name__)
800
801  @doc_controls.for_subclass_implementers
802  def compute_output_signature(self, input_signature):
803    """Compute the output tensor signature of the layer based on the inputs.
804
805    Unlike a TensorShape object, a TensorSpec object contains both shape
806    and dtype information for a tensor. This method allows layers to provide
807    output dtype information if it is different from the input dtype.
808    For any layer that doesn't implement this function,
809    the framework will fall back to use `compute_output_shape`, and will
810    assume that the output dtype matches the input dtype.
811
812    Args:
813      input_signature: Single TensorSpec or nested structure of TensorSpec
814        objects, describing a candidate input for the layer.
815
816    Returns:
817      Single TensorSpec or nested structure of TensorSpec objects, describing
818        how the layer would transform the provided input.
819
820    Raises:
821      TypeError: If input_signature contains a non-TensorSpec object.
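
    Example (a minimal sketch; the `Dense` layer and shapes are illustrative):

    ```python
    layer = tf.keras.layers.Dense(4)
    spec = tf.TensorSpec(shape=(None, 3), dtype=tf.float32)
    output_spec = layer.compute_output_signature(spec)
    # `output_spec` is a tf.TensorSpec with shape (None, 4) and dtype float32,
    # derived via the `compute_output_shape` fallback described above.
    ```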
822    """
823    def check_type_return_shape(s):
824      if not isinstance(s, tensor_spec.TensorSpec):
825        raise TypeError('Only TensorSpec signature types are supported, '
826                        'but saw signature entry: {}.'.format(s))
827      return s.shape
828    input_shape = nest.map_structure(check_type_return_shape, input_signature)
829    output_shape = self.compute_output_shape(input_shape)
830    dtype = self._compute_dtype
831    if dtype is None:
832      input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
833      # Default behavior when self.dtype is None is to use the first input's
834      # dtype.
835      dtype = input_dtypes[0]
836    return nest.map_structure(
837        lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
838        output_shape)
839
840  def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs):
841    if self.dynamic:
842      # We will use static shape inference to return symbolic tensors
843      # matching the specifications of the layer outputs.
844      # Since `self.dynamic` is True, we will never attempt to
845      # run the underlying TF graph (which is disconnected).
846      # TODO(fchollet): consider py_func as an alternative, which
847      # would enable us to run the underlying graph if needed.
848      input_signature = nest.map_structure(
849          lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype),
850          inputs)
851      output_signature = self.compute_output_signature(input_signature)
852      return nest.map_structure(keras_tensor.KerasTensor, output_signature)
853    else:
854      return self._infer_output_signature(inputs, args, kwargs, input_masks)
855
856  def _infer_output_signature(self, inputs, args, kwargs, input_masks):
857    """Calls the layer on placeholder inputs to infer the output signature."""
858
859    call_fn = self.call
860    # Wrapping `call` function in autograph to allow for dynamic control
861    # flow and control dependencies in call. We are limiting this to
862    # subclassed layers as autograph is strictly needed only for
863    # subclassed layers and models.
864    # tf_convert will respect the value of autograph setting in the
865    # enclosing tf.function, if any.
866    if (base_layer_utils.is_subclassed(self) and
867        not base_layer_utils.from_saved_model(self)):
868      call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
869
870    # We enter a scratch graph and build placeholder inputs inside of it that
871    # match the input args.
872    # We then call the layer inside of the scratch graph to identify the
873    # output signatures, then we build KerasTensors corresponding to those
874    # outputs.
875    scratch_graph = func_graph.FuncGraph(str(self.name) + '_scratch_graph')
876    with scratch_graph.as_default():
877      inputs = nest.map_structure(
878          keras_tensor.keras_tensor_to_placeholder, inputs)
879      args = nest.map_structure(
880          keras_tensor.keras_tensor_to_placeholder, args)
881      kwargs = nest.map_structure(
882          keras_tensor.keras_tensor_to_placeholder, kwargs)
883      input_masks = nest.map_structure(
884          keras_tensor.keras_tensor_to_placeholder, input_masks)
885
886      with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
887        with autocast_variable.enable_auto_cast_variables(
888            self._compute_dtype_object):
889          # Build layer if applicable (if the `build` method has been
890          # overridden).
891          # TODO(kaftan): do we maybe_build here, or have we already done it?
892          self._maybe_build(inputs)
893          inputs = self._maybe_cast_inputs(inputs)
894          outputs = call_fn(inputs, *args, **kwargs)
895
896        self._handle_activity_regularization(inputs, outputs)
897      self._set_mask_metadata(inputs, outputs, input_masks,
898                              build_graph=False)
899      outputs = nest.map_structure(
900          keras_tensor.keras_tensor_from_tensor, outputs)
901
902    if hasattr(self, '_set_inputs') and not self.inputs:
903      # TODO(kaftan): figure out if we need to do this at all
904      # Subclassed network: explicitly set metadata normally set by
905      # a call to self._set_inputs().
906      self._set_inputs(inputs, outputs)
907    del scratch_graph
908    return outputs
909
910  @generic_utils.default
911  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
912    """Computes an output mask tensor.
913
914    Args:
915        inputs: Tensor or list of tensors.
916        mask: Tensor or list of tensors.
917
918    Returns:
919        None or a tensor (or list of tensors,
920            one per output tensor of the layer).
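
    Example (a minimal sketch; `ZeroMasking` is a hypothetical layer that
    masks timesteps whose features are all zero):

    ```python
    class ZeroMasking(tf.keras.layers.Layer):

      def __init__(self, **kwargs):
        super(ZeroMasking, self).__init__(**kwargs)
        self.supports_masking = True

      def compute_mask(self, inputs, mask=None):
        # True for timesteps that have at least one non-zero feature.
        return tf.reduce_any(tf.not_equal(inputs, 0.), axis=-1)

      def call(self, inputs):
        return inputs
    ```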
921    """
922    if not self._supports_masking:
923      if any(m is not None for m in nest.flatten(mask)):
924        raise TypeError('Layer ' + self.name + ' does not support masking, '
925                        'but was passed an input_mask: ' + str(mask))
926      # masking not explicitly supported: return None as mask.
927      return None
928    # if masking is explicitly supported, by default
929    # carry over the input mask
930    return mask
931
932  def __call__(self, *args, **kwargs):
933    """Wraps `call`, applying pre- and post-processing steps.
934
935    Args:
936      *args: Positional arguments to be passed to `self.call`.
937      **kwargs: Keyword arguments to be passed to `self.call`.
938
939    Returns:
940      Output tensor(s).
941
942    Note:
943      - The following optional keyword arguments are reserved for specific uses:
944        * `training`: Boolean scalar tensor or Python boolean indicating
945          whether the `call` is meant for training or inference.
946        * `mask`: Boolean input mask.
947      - If the layer's `call` method takes a `mask` argument (as some Keras
948        layers do), its default value will be set to the mask generated
949        for `inputs` by the previous layer (if `inputs` did come from
950        a layer that generated a corresponding mask, i.e. if it came from
951        a Keras layer with masking support).
952      - If the layer is not built, the method will call `build`.
953
954    Raises:
955      ValueError: if the layer's `call` method returns None (an invalid value).
956      RuntimeError: if `super().__init__()` was not called in the constructor.
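
    Example (a minimal usage sketch; the layer and shapes are illustrative):

    ```python
    layer = tf.keras.layers.Dropout(0.5)
    # `__call__` builds the layer on first use and forwards `training` to
    # `call` (or uses it as the default mode for inner layers).
    outputs = layer(tf.ones((2, 4)), training=True)
    ```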
957    """
958    if not hasattr(self, '_thread_local'):
959      raise RuntimeError(
960          'You must call `super().__init__()` in the layer constructor.')
961
962    # `inputs` (the first arg in the method spec) is special cased in
963    # layer call due to historical reasons.
964    # This special casing currently takes the form of:
965    # - 'inputs' must be explicitly passed. A layer cannot have zero arguments,
966    #   and inputs cannot have been provided via the default value of a kwarg.
967    # - numpy/scalar values in `inputs` get converted to tensors
968    # - implicit masks / mask metadata are only collected from 'inputs'
969    # - Layers are built using shape info from 'inputs' only
970    # - input_spec compatibility is only checked against `inputs`
971    # - mixed precision casting (autocast) is only applied to `inputs`,
972    #   not to any other argument.
973    # - setting the SavedModel saving spec.
974    inputs, args, kwargs = self._split_out_first_arg(args, kwargs)
975    input_list = nest.flatten(inputs)
976
977    # Functional Model construction mode is invoked when `Layer`s are called on
978    # symbolic `KerasTensor`s, i.e.:
979    # >> inputs = tf.keras.Input(10)
980    # >> outputs = MyLayer()(inputs)  # Functional construction mode.
981    # >> model = tf.keras.Model(inputs, outputs)
982    if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
983      return self._functional_construction_call(inputs, args, kwargs,
984                                                input_list)
985
986    # Maintains info about the `Layer.call` stack.
987    call_context = base_layer_utils.call_context()
988
989    # Accept NumPy and scalar inputs by converting to Tensors.
990    if any(isinstance(x, (
991        np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
992      inputs = nest.map_structure(_convert_numpy_or_python_types, inputs)
993      input_list = nest.flatten(inputs)
994
995    # Handle `mask` propagation from previous layer to current layer. Masks can
996    # be propagated explicitly via the `mask` argument, or implicitly via
997    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
998    # explicitly take priority.
999    input_masks, mask_is_implicit = self._get_input_masks(
1000        inputs, input_list, args, kwargs)
1001    if self._expects_mask_arg and mask_is_implicit:
1002      kwargs['mask'] = input_masks
1003
1004    # Training mode for `Layer.call` is set via (in order of priority):
1005    # (1) The `training` argument passed to this `Layer.call`, if it is not None
1006    # (2) The training mode of an outer `Layer.call`.
1007    # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set)
1008    # (4) Any non-None default value for `training` specified in the call
1009    #  signature
1010    # (5) False (treating the layer as if it's in inference)
1011    args, kwargs, training_mode = self._set_training_mode(
1012        args, kwargs, call_context)
1013
1014    # Losses are cleared for all sublayers on the outermost `Layer.call`.
1015    # Losses are not cleared on inner `Layer.call`s, because sublayers can be
1016    # called multiple times.
1017    if not call_context.in_call:
1018      self._clear_losses()
1019
1020    eager = context.executing_eagerly()
1021    with call_context.enter(
1022        layer=self,
1023        inputs=inputs,
1024        build_graph=not eager,
1025        training=training_mode):
1026
1027      input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
1028      if eager:
1029        call_fn = self.call
1030        name_scope = self._name
1031      else:
1032        name_scope = self._name_scope()  # Avoid autoincrementing.  # pylint: disable=not-callable
1033        call_fn = self._autographed_call()
1034
1035      with ops.name_scope_v2(name_scope):
1036        if not self.built:
1037          self._maybe_build(inputs)
1038
1039        if self._autocast:
1040          inputs = self._maybe_cast_inputs(inputs, input_list)
1041
1042        with autocast_variable.enable_auto_cast_variables(
1043            self._compute_dtype_object):
1044          outputs = call_fn(inputs, *args, **kwargs)
1045
1046        if self._activity_regularizer:
1047          self._handle_activity_regularization(inputs, outputs)
1048        if self._supports_masking:
1049          self._set_mask_metadata(inputs, outputs, input_masks, not eager)
1050        if self._saved_model_inputs_spec is None:
1051          self._set_save_spec(inputs)
1052
1053        return outputs
1054
1055  def _functional_construction_call(self, inputs, args, kwargs, input_list):
1056    call_context = base_layer_utils.call_context()
1057
1058    # Accept NumPy and scalar inputs by converting to Tensors.
1059    if any(isinstance(x, (
1060        np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
1061
1062      def _convert_non_tensor(x):
1063        # Don't call `ops.convert_to_tensor` on all `inputs` because
1064        # `SparseTensors` can't be converted to `Tensor`.
1065        if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
1066          return ops.convert_to_tensor_v2_with_dispatch(x)
1067        return x
1068
1069      inputs = nest.map_structure(_convert_non_tensor, inputs)
1070      input_list = nest.flatten(inputs)
1071
1072    # Handle `mask` propagation from previous layer to current layer. Masks can
1073    # be propagated explicitly via the `mask` argument, or implicitly via
1074    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
1075    # explicitly take priority.
1076    mask_arg_passed_by_framework = False
1077    input_masks, mask_is_implicit = self._get_input_masks(
1078        inputs, input_list, args, kwargs)
1079    if self._expects_mask_arg and mask_is_implicit:
1080      kwargs['mask'] = input_masks
1081      mask_arg_passed_by_framework = True
1082
1083    # If `training` argument is None or not explicitly passed,
1084    # propagate `training` value from this layer's calling layer.
1085    training_value = None
1086    training_arg_passed_by_framework = False
1087    # Priority 1: `training` was explicitly passed a non-None value.
1088    if self._call_arg_was_passed('training', args, kwargs):
1089      training_value = self._get_call_arg_value('training', args, kwargs)
1090      if not self._expects_training_arg:
1091        kwargs.pop('training')
1092
1093    if training_value is None:
1094      # Priority 2: `training` was passed to a parent layer.
1095      if call_context.training is not None:
1096        training_value = call_context.training
1097      # Priority 3: `learning_phase()` has been set.
1098      elif backend.global_learning_phase_is_set():
1099        training_value = backend.learning_phase()
1100        # Force the training_value to be bool type which matches the contract
1101        # for layer/model call args.
1102        if tensor_util.is_tf_type(training_value):
1103          training_value = math_ops.cast(training_value, dtypes.bool)
1104        else:
1105          training_value = bool(training_value)
1106      # Priority 4: trace layer with the default training argument specified
1107      # in the `call` signature (or in inference mode if the `call` signature
1108      # specifies no non-None default).
1109      else:
1110        training_value = self._default_training_arg
1111      # In cases (2), (3), (4) the training argument is passed automatically
1112      # by the framework, and will not be hard-coded into the model.
1113      if self._expects_training_arg:
1114        args, kwargs = self._set_call_arg_value('training', training_value,
1115                                                args, kwargs)
1116        training_arg_passed_by_framework = True
1117
1118    with call_context.enter(
1119        layer=self, inputs=inputs, build_graph=True, training=training_value):
1120      # Check input assumptions set after layer building, e.g. input shape.
1121      outputs = self._keras_tensor_symbolic_call(
1122          inputs, input_masks, args, kwargs)
1123
1124      if outputs is None:
1125        raise ValueError('A layer\'s `call` method should return a '
1126                         'Tensor or a list of Tensors, not None '
1127                         '(layer: ' + self.name + ').')
1128      if training_arg_passed_by_framework:
1129        args, kwargs = self._set_call_arg_value(
1130            'training', None, args, kwargs, pop_kwarg_if_none=True)
1131      if mask_arg_passed_by_framework:
1132        kwargs.pop('mask')
1133      # Node connectivity does not special-case the first argument.
1134      outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
1135                                                outputs)
1136      return outputs
1137
1138  def _set_training_mode(self, args, kwargs, call_context):
1139    training_mode = None
1140    if self._expects_training_arg:
1141      # (1) `training` was passed to this `Layer.call`.
1142      if self._call_arg_was_passed('training', args, kwargs):
1143        training_mode = self._get_call_arg_value('training', args, kwargs)
1144      # If no `training` arg was passed, or `None` was explicitly passed,
1145      # the framework will decide what the training mode is.
1146      if training_mode is None:
1147        call_ctx_training = call_context.training
1148        # (2) `training` mode is inferred from an outer `Layer.call`.
1149        if call_ctx_training is not None:
1150          training_mode = call_ctx_training
1151        # (3) User set `tf.keras.backend.set_learning_phase`.
1152        elif backend.global_learning_phase_is_set():
1153          training_mode = backend.learning_phase()
1154          # Ensure value is a `bool` or `tf.bool`.
1155          if isinstance(training_mode, bool):
1156            pass
1157          elif tensor_util.is_tf_type(training_mode):
1158            training_mode = math_ops.cast(training_mode, dtypes.bool)
1159          else:
1160            training_mode = bool(training_mode)
1161        # (4) We default to using `call`'s default value for `training`,
1162        # or treating the layer as if it is in inference if no non-None default
1163        # is specified in the `call` signature.
1164        else:
1165          training_mode = self._default_training_arg
1166
1167        # For case (2), (3), (4) `training` arg is passed by framework.
1168        args, kwargs = self._set_call_arg_value('training', training_mode, args,
1169                                                kwargs)
1170    else:
1171      if 'training' in kwargs:
1172        # `training` was passed to this `Layer` but is not needed for
1173        # `Layer.call`. It will set the default mode for inner `Layer.call`s.
1174        training_mode = kwargs.pop('training')
1175      else:
1176        # Grab the current `training` mode from any outer `Layer.call`.
1177        training_mode = call_context.training
1178
1179    return args, kwargs, training_mode
1180
1181  def _autographed_call(self):
1182    # Wrapping `call` function in autograph to allow for dynamic control
1183    # flow and control dependencies in call. We are limiting this to
1184    # subclassed layers as autograph is strictly needed only for
1185    # subclassed layers and models.
1186    # tf_convert will respect the value of autograph setting in the
1187    # enclosing tf.function, if any.
1188    if (base_layer_utils.is_subclassed(self) and
1189        not base_layer_utils.from_saved_model(self)):
1190      return autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
1191    else:
1192      return self.call
1193
1194  @property
1195  def dtype(self):
1196    """The dtype of the layer weights.
1197
1198    This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless
1199    mixed precision is used, this is the same as `Layer.compute_dtype`, the
1200    dtype of the layer's computations.
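
    For illustration, a minimal sketch (assuming the `mixed_float16` policy):

    ```python
    layer = tf.keras.layers.Dense(1, dtype='mixed_float16')
    layer.dtype          # 'float32': the dtype of the weights
    layer.compute_dtype  # 'float16': the dtype of the computations
    ```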
1201    """
1202    return self._dtype_policy.variable_dtype
1203
1204  @property
1205  def name(self):
1206    """Name of the layer (string), set in the constructor."""
1207    return self._name
1208
1209  @property
1210  def supports_masking(self):
1211    """Whether this layer supports computing a mask using `compute_mask`."""
1212    return self._supports_masking
1213
1214  @supports_masking.setter
1215  def supports_masking(self, value):
1216    self._supports_masking = value
1217
1218  @property
1219  def dynamic(self):
1220    """Whether the layer is dynamic (eager-only); set in the constructor."""
1221    return any(layer._dynamic for layer in self._flatten_layers())
1222
1223  @property
1224  @doc_controls.do_not_doc_inheritable
1225  def stateful(self):
1226    return any(layer._stateful for layer in self._flatten_layers())
1227
1228  @stateful.setter
1229  def stateful(self, value):
1230    self._stateful = value
1231
1232  @property
1233  def trainable(self):
1234    return self._trainable
1235
1236  @trainable.setter
1237  def trainable(self, value):
1238    for layer in self._flatten_layers():
1239      layer._trainable = value
1240
1241  @property
1242  def activity_regularizer(self):
1243    """Optional regularizer function for the output of this layer."""
1244    return self._activity_regularizer
1245
1246  @activity_regularizer.setter
1247  def activity_regularizer(self, regularizer):
1248    """Optional regularizer function for the output of this layer."""
1249    self._activity_regularizer = regularizer
1250
1251  @property
1252  def input_spec(self):
1253    """`InputSpec` instance(s) describing the input format for this layer.
1254
1255    When you create a layer subclass, you can set `self.input_spec` to enable
1256    the layer to run input compatibility checks when it is called.
1257    Consider a `Conv2D` layer: it can only be called on a single input tensor
1258    of rank 4. As such, you can set, in `__init__()`:
1259
1260    ```python
1261    self.input_spec = tf.keras.layers.InputSpec(ndim=4)
1262    ```
1263
1264    Now, if you try to call the layer on an input that isn't rank 4
1265    (for instance, an input of shape `(2,)`), it will raise a nicely-formatted
1266    error:
1267
1268    ```
1269    ValueError: Input 0 of layer conv2d is incompatible with the layer:
1270    expected ndim=4, found ndim=1. Full shape received: [2]
1271    ```
1272
1273    Input checks that can be specified via `input_spec` include:
1274    - Structure (e.g. a single input, a list of 2 inputs, etc)
1275    - Shape
1276    - Rank (ndim)
1277    - Dtype
1278
1279    For more information, see `tf.keras.layers.InputSpec`.
1280
1281    Returns:
1282      A `tf.keras.layers.InputSpec` instance, or nested structure thereof.
1283    """
1284    return self._input_spec
1285
1286  @input_spec.setter
1287  # Must be decorated to prevent tracking, since the input_spec can be nested
1288  # InputSpec objects.
1289  @trackable.no_automatic_dependency_tracking
1290  def input_spec(self, value):
1291    for v in nest.flatten(value):
1292      if v is not None and not isinstance(v, InputSpec):
1293        raise TypeError('Layer input_spec must be an instance of InputSpec. '
1294                        'Got: {}'.format(v))
1295    self._input_spec = value
1296
1297  @property
1298  def trainable_weights(self):
1299    """List of all trainable weights tracked by this layer.
1300
1301    Trainable weights are updated via gradient descent during training.
1302
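    A minimal doctest-style sketch (assuming a freshly built `Dense` layer):

    >>> layer = tf.keras.layers.Dense(2)
    >>> layer.build((None, 4))
    >>> len(layer.trainable_weights)
    2
    >>> layer.trainable = False
    >>> len(layer.trainable_weights)
    0
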
1303    Returns:
1304      A list of trainable variables.
1305    """
1306    if self.trainable:
1307      children_weights = self._gather_children_attribute('trainable_variables')
1308      return self._dedup_weights(self._trainable_weights + children_weights)
1309    else:
1310      return []
1311
1312  @property
1313  def non_trainable_weights(self):
1314    """List of all non-trainable weights tracked by this layer.
1315
1316    Non-trainable weights are *not* updated during training. They are expected
1317    to be updated manually in `call()`.
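
    A minimal doctest-style sketch (assuming a `BatchNormalization` layer, whose
    moving mean and moving variance are non-trainable):

    >>> layer = tf.keras.layers.BatchNormalization()
    >>> layer.build((None, 4))
    >>> len(layer.non_trainable_weights)
    2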
1318
1319    Returns:
1320      A list of non-trainable variables.
1321    """
1322    if self.trainable:
1323      children_weights = self._gather_children_attribute(
1324          'non_trainable_variables')
1325      non_trainable_weights = self._non_trainable_weights + children_weights
1326    else:
1327      children_weights = self._gather_children_attribute('variables')
1328      non_trainable_weights = (
1329          self._trainable_weights + self._non_trainable_weights +
1330          children_weights)
1331    return self._dedup_weights(non_trainable_weights)
1332
1333  @property
1334  def weights(self):
1335    """Returns the list of all layer variables/weights.
1336
1337    Returns:
1338      A list of variables.
1339    """
1340    return self.trainable_weights + self.non_trainable_weights
1341
1342  @property
1343  @doc_controls.do_not_generate_docs
1344  def updates(self):
1345    warnings.warn('`layer.updates` will be removed in a future version. '
1346                  'This property should not be used in TensorFlow 2.0, '
1347                  'as `updates` are applied automatically.')
1348    return []
1349
1350  @property
1351  def losses(self):
1352    """List of losses added using the `add_loss()` API.
1353
1354    Variable regularization tensors are created when this property is accessed,
1355    so it is eager safe: accessing `losses` under a `tf.GradientTape` will
1356    propagate gradients back to the corresponding variables.
1357
1358    Examples:
1359
1360    >>> class MyLayer(tf.keras.layers.Layer):
1361    ...   def call(self, inputs):
1362    ...     self.add_loss(tf.abs(tf.reduce_mean(inputs)))
1363    ...     return inputs
1364    >>> l = MyLayer()
1365    >>> l(np.ones((10, 1)))
1366    >>> l.losses
1367    [1.0]
1368
1369    >>> inputs = tf.keras.Input(shape=(10,))
1370    >>> x = tf.keras.layers.Dense(10)(inputs)
1371    >>> outputs = tf.keras.layers.Dense(1)(x)
1372    >>> model = tf.keras.Model(inputs, outputs)
1373    >>> # Activity regularization.
1374    >>> len(model.losses)
1375    0
1376    >>> model.add_loss(tf.abs(tf.reduce_mean(x)))
1377    >>> len(model.losses)
1378    1
1379
1380    >>> inputs = tf.keras.Input(shape=(10,))
1381    >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')
1382    >>> x = d(inputs)
1383    >>> outputs = tf.keras.layers.Dense(1)(x)
1384    >>> model = tf.keras.Model(inputs, outputs)
1385    >>> # Weight regularization.
1386    >>> model.add_loss(lambda: tf.reduce_mean(d.kernel))
1387    >>> model.losses
1388    [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
1389
1390    Returns:
1391      A list of tensors.
1392    """
1393    collected_losses = []
1394    for layer in self._flatten_layers():
1395      # If any eager losses are present, we assume the model to be part of an
1396      # eager training loop (either a custom one or the one used when
1397      # `run_eagerly=True`) and so we always return just the eager losses.
1398      if layer._eager_losses:
1399        # Filter placeholder losses that may have been added by revived layers.
1400        # (see base_layer_utils for details).
1401        if (layer._eager_losses[0] is
1402            not base_layer_utils.REVIVED_LOSS_PLACEHOLDER):
1403          collected_losses.extend(layer._eager_losses)
1404      else:
1405        collected_losses.extend(layer._losses)
1406      for regularizer in layer._callable_losses:
1407        loss_tensor = regularizer()
1408        if loss_tensor is not None:
1409          collected_losses.append(loss_tensor)
1410    return collected_losses
1411
1412  def add_loss(self, losses, **kwargs):
1413    """Add loss tensor(s), potentially dependent on layer inputs.
1414
1415    Some losses (for instance, activity regularization losses) may be dependent
1416    on the inputs passed when calling a layer. Hence, when reusing the same
1417    layer on different inputs `a` and `b`, some entries in `layer.losses` may
1418    be dependent on `a` and some on `b`. This method automatically keeps track
1419    of dependencies.
1420
1421    This method can be used inside a subclassed layer or model's `call`
1422    function, in which case `losses` should be a Tensor or list of Tensors.
1423
1424    Example:
1425
1426    ```python
1427    class MyLayer(tf.keras.layers.Layer):
1428      def call(self, inputs):
1429        self.add_loss(tf.abs(tf.reduce_mean(inputs)))
1430        return inputs
1431    ```
1432
1433    This method can also be called directly on a Functional Model during
1434    construction. In this case, any loss Tensors passed to this Model must
1435    be symbolic and be able to be traced back to the model's `Input`s. These
1436    losses become part of the model's topology and are tracked in `get_config`.
1437
1438    Example:
1439
1440    ```python
1441    inputs = tf.keras.Input(shape=(10,))
1442    x = tf.keras.layers.Dense(10)(inputs)
1443    outputs = tf.keras.layers.Dense(1)(x)
1444    model = tf.keras.Model(inputs, outputs)
1445    # Activity regularization.
1446    model.add_loss(tf.abs(tf.reduce_mean(x)))
1447    ```
1448
1449    If this is not the case for your loss (if, for example, your loss references
1450    a `Variable` of one of the model's layers), you can wrap your loss in a
1451    zero-argument lambda. These losses are not tracked as part of the model's
1452    topology since they can't be serialized.
1453
1454    Example:
1455
1456    ```python
1457    inputs = tf.keras.Input(shape=(10,))
1458    d = tf.keras.layers.Dense(10)
1459    x = d(inputs)
1460    outputs = tf.keras.layers.Dense(1)(x)
1461    model = tf.keras.Model(inputs, outputs)
1462    # Weight regularization.
1463    model.add_loss(lambda: tf.reduce_mean(d.kernel))
1464    ```
1465
1466    Args:
1467      losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
1468        may also be zero-argument callables which create a loss tensor.
1469      **kwargs: Additional keyword arguments for backward compatibility.
1470        Accepted values:
1471          inputs - Deprecated, will be automatically inferred.
1472    """
1473    kwargs.pop('inputs', None)
1474    if kwargs:
1475      raise TypeError('Unknown keyword arguments: %s' % (kwargs.keys(),))
1476
1477    def _tag_callable(loss):
1478      """Tags callable loss tensor as `_unconditional_loss`."""
1479      if callable(loss):
1480        # We run the loss without autocasting, as regularizers are often
1481        # numerically unstable in float16.
1482        with autocast_variable.enable_auto_cast_variables(None):
1483          loss = loss()
1484      if loss is None:
1485        return None  # Will be filtered out when computing the .losses property
1486      if not tensor_util.is_tf_type(loss):
1487        loss = ops.convert_to_tensor_v2_with_dispatch(
1488            loss, dtype=backend.floatx())
1489      loss._unconditional_loss = True  # pylint: disable=protected-access
1490      return loss
1491
1492    losses = nest.flatten(losses)
1493
1494    callable_losses = []
1495    eager_losses = []
1496    symbolic_losses = []
1497    for loss in losses:
1498      if callable(loss):
1499        callable_losses.append(functools.partial(_tag_callable, loss))
1500        continue
1501      if loss is None:
1502        continue
1503      if not tensor_util.is_tf_type(loss) and not isinstance(
1504          loss, keras_tensor.KerasTensor):
1505        loss = ops.convert_to_tensor_v2_with_dispatch(
1506            loss, dtype=backend.floatx())
1507      # TF Functions should take the eager path.
1508      if ((tf_utils.is_symbolic_tensor(loss) or
1509           isinstance(loss, keras_tensor.KerasTensor)) and
1510          not base_layer_utils.is_in_tf_function()):
1511        symbolic_losses.append(loss)
1512      elif tensor_util.is_tf_type(loss):
1513        eager_losses.append(loss)
1514
1515    self._callable_losses.extend(callable_losses)
1516
1517    in_call_context = base_layer_utils.call_context().in_call
1518    if eager_losses and not in_call_context:
1519      raise ValueError(
1520          'Expected a symbolic Tensor or a callable for the loss value. '
1521          'Please wrap your loss computation in a zero argument `lambda`.')
1522
1523    self._eager_losses.extend(eager_losses)
1524
1525    for symbolic_loss in symbolic_losses:
1526      if getattr(self, '_is_graph_network', False):
1527        self._graph_network_add_loss(symbolic_loss)
1528      else:
1529        # Possibly a loss was added in a Layer's `build`.
1530        self._losses.append(symbolic_loss)
1531
1532  def _clear_losses(self):
1533    """Used every step in eager to reset losses."""
1534    # Set to thread local directly to avoid Layer.__setattr__ overhead.
1535    if not getattr(self, '_self_tracked_trackables',
1536                   None):  # Fast path for single Layer.
1537      self._thread_local._eager_losses = []
1538    else:
1539      for layer in self._flatten_layers():
1540        layer._thread_local._eager_losses = []
1541
1542  @property
1543  def metrics(self):
1544    """List of metrics added using the `add_metric()` API.
1545
1546    Example:
1547
1548    >>> input = tf.keras.layers.Input(shape=(3,))
1549    >>> d = tf.keras.layers.Dense(2)
1550    >>> output = d(input)
1551    >>> d.add_metric(tf.reduce_max(output), name='max')
1552    >>> d.add_metric(tf.reduce_min(output), name='min')
1553    >>> [m.name for m in d.metrics]
1554    ['max', 'min']
1555
1556    Returns:
1557      A list of `Metric` objects.
1558    """
1559    collected_metrics = []
1560    for layer in self._flatten_layers():
1561      with layer._metrics_lock:
1562        collected_metrics.extend(layer._metrics)
1563    return collected_metrics
1564
1565  def add_metric(self, value, name=None, **kwargs):
1566    """Adds metric tensor to the layer.
1567
1568    This method can be used inside the `call()` method of a subclassed layer
1569    or model.
1570
1571    ```python
1572    class MyMetricLayer(tf.keras.layers.Layer):
1573      def __init__(self):
1574        super(MyMetricLayer, self).__init__(name='my_metric_layer')
1575        self.mean = tf.keras.metrics.Mean(name='metric_1')
1576
1577      def call(self, inputs):
1578        self.add_metric(self.mean(inputs))
1579        self.add_metric(tf.reduce_sum(inputs), name='metric_2')
1580        return inputs
1581    ```
1582
1583    This method can also be called directly on a Functional Model during
1584    construction. In this case, any tensor passed to this Model must
1585    be symbolic and be able to be traced back to the model's `Input`s. These
1586    metrics become part of the model's topology and are tracked when you
1587    save the model via `save()`.
1588
1589    ```python
1590    inputs = tf.keras.Input(shape=(10,))
1591    x = tf.keras.layers.Dense(10)(inputs)
1592    outputs = tf.keras.layers.Dense(1)(x)
1593    model = tf.keras.Model(inputs, outputs)
1594    model.add_metric(math_ops.reduce_sum(x), name='metric_1')
1595    ```
1596
1597    Note: Calling `add_metric()` with the result of a metric object on a
1598    Functional Model, as shown in the example below, is not supported. This is
1599    because we cannot trace the metric result tensor back to the model's inputs.
1600
1601    ```python
1602    inputs = tf.keras.Input(shape=(10,))
1603    x = tf.keras.layers.Dense(10)(inputs)
1604    outputs = tf.keras.layers.Dense(1)(x)
1605    model = tf.keras.Model(inputs, outputs)
1606    model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1')
1607    ```
1608
1609    Args:
1610      value: Metric tensor.
1611      name: String metric name.
1612      **kwargs: Additional keyword arguments for backward compatibility.
1613        Accepted values:
1614        `aggregation` - When the `value` tensor provided is not the result of
1615        calling a `keras.Metric` instance, it will be aggregated by default
1616        using a `keras.metrics.Mean`.
1617    """
1618    kwargs_keys = list(kwargs.keys())
1619    if (len(kwargs_keys) > 1 or
1620        (len(kwargs_keys) == 1 and kwargs_keys[0] != 'aggregation')):
1621      raise TypeError('Unknown keyword arguments: %s' % (kwargs.keys(),))
1622
1623    from_metric_obj = hasattr(value, '_metric_obj')
1624    is_symbolic = isinstance(value, keras_tensor.KerasTensor)
1625    in_call_context = base_layer_utils.call_context().in_call
1626
1627    if name is None and not from_metric_obj:
1628      # Eg. `self.add_metric(math_ops.reduce_sum(x))`
1629      # In eager mode, we use the metric name to look up a metric. Without a name,
1630      # a new Mean metric wrapper will be created on every model/layer call.
1631      # So, we raise an error when no name is provided.
1632      # We will do the same for symbolic mode for consistency although a name
1633      # will be generated if no name is provided.
1634
1635      # We will not raise this error in the following use case for the sake of
1636      # consistency, as the name is provided in the metric constructor.
1637      # mean = metrics.Mean(name='my_metric')
1638      # model.add_metric(mean(outputs))
1639      raise ValueError('Please provide a name for your metric like '
1640                       '`self.add_metric(tf.reduce_sum(inputs), '
1641                       'name=\'mean_activation\')`')
1642    elif from_metric_obj:
1643      name = value._metric_obj.name
1644
1645    if not in_call_context and not is_symbolic:
1646      raise ValueError('Expected a symbolic Tensor for the metric value, '
1647                       'received: ' + str(value))
1648
1649    # If a metric was added in a Layer's `call` or `build`.
1650    if in_call_context or not getattr(self, '_is_graph_network', False):
1651      # TF Function path should take the eager path.
1652
1653      # If the given metric is available in `metrics` list we just update state
1654      # on it, otherwise we create a new metric instance and
1655      # add it to the `metrics` list.
1656      metric_obj = getattr(value, '_metric_obj', None)
1657      # Tensors that come from a Metric object have already updated the Metric state.
1658      should_update_state = not metric_obj
1659      name = metric_obj.name if metric_obj else name
1660
1661      with self._metrics_lock:
1662        match = self._get_existing_metric(name)
1663        if match:
1664          metric_obj = match
1665        elif metric_obj:
1666          self._metrics.append(metric_obj)
1667        else:
1668          # Build the metric object with the value's dtype if it defines one
1669          metric_obj = metrics_mod.Mean(
1670              name=name, dtype=getattr(value, 'dtype', None))
1671          self._metrics.append(metric_obj)
1672
1673      if should_update_state:
1674        metric_obj(value)
1675    else:
1676      if from_metric_obj:
1677        raise ValueError('Using the result of calling a `Metric` object '
1678                         'when calling `add_metric` on a Functional '
1679                         'Model is not supported. Please pass the '
1680                         'Tensor to monitor directly.')
1681
1682      # Insert layers into the Keras Graph Network.
1683      aggregation = None if from_metric_obj else 'mean'
1684      self._graph_network_add_metric(value, aggregation, name)
1685
1686  @doc_controls.do_not_doc_inheritable
1687  def add_update(self, updates, inputs=None):
1688    """Add update op(s), potentially dependent on layer inputs.
1689
1690    Weight updates (for instance, the updates of the moving mean and variance
1691    in a BatchNormalization layer) may be dependent on the inputs passed
1692    when calling a layer. Hence, when reusing the same layer on
1693    different inputs `a` and `b`, some entries in `layer.updates` may be
1694    dependent on `a` and some on `b`. This method automatically keeps track
1695    of dependencies.
1696
1697    This call is ignored when eager execution is enabled (in that case, variable
1698    updates are run on the fly and thus do not need to be tracked for later
1699    execution).
1700
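
    A minimal sketch (the `MovingAverage` layer below is hypothetical; it only
    illustrates the zero-arg-callable form):

    ```python
    class MovingAverage(tf.keras.layers.Layer):
      def build(self, input_shape):
        self.moving_mean = self.add_weight(
            'moving_mean', shape=input_shape[-1:], initializer='zeros',
            trainable=False)

      def call(self, inputs):
        # Wrapping the update in a zero-arg callable lets the framework skip
        # it when the layer is frozen (`trainable=False`).
        self.add_update(lambda: self.moving_mean.assign(
            0.99 * self.moving_mean + 0.01 * tf.reduce_mean(inputs, axis=0)))
        return inputs
    ```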
1701    Args:
1702      updates: Update op, or list/tuple of update ops, or zero-arg callable
1703        that returns an update op. A zero-arg callable should be passed so
1704        that the updates can be disabled by setting `trainable=False` on
1705        this Layer, when executing in Eager mode.
1706      inputs: Deprecated, will be automatically inferred.
1707    """
1708    if inputs is not None:
1709      tf_logging.warning(
1710          '`add_update` `inputs` kwarg has been deprecated. You no longer need '
1711          'to pass a value to `inputs` as it is being automatically inferred.')
1712    call_context = base_layer_utils.call_context()
1713    # No need to run updates during Functional API construction.
1714    if call_context.in_keras_graph:
1715      return
1716
1717    # Callable updates are disabled by setting `trainable=False`.
1718    if not call_context.frozen:
1719      for update in nest.flatten(updates):
1720        if callable(update):
1721          update()  # pylint: disable=not-callable
1722
1723  def set_weights(self, weights):
1724    """Sets the weights of the layer, from NumPy arrays.
1725
1726    The weights of a layer represent the state of the layer. This function
1727    sets the weight values from numpy arrays. The weight values should be
1728    passed in the order they are created by the layer. Note that the layer's
1729    weights must be instantiated before calling this function, by calling
1730    the layer.
1731
1732    For example, a `Dense` layer returns a list of two values: the kernel matrix
1733    and the bias vector. These can be used to set the weights of another
1734    `Dense` layer:
1735
1736    >>> layer_a = tf.keras.layers.Dense(1,
1737    ...   kernel_initializer=tf.constant_initializer(1.))
1738    >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
1739    >>> layer_a.get_weights()
1740    [array([[1.],
1741           [1.],
1742           [1.]], dtype=float32), array([0.], dtype=float32)]
1743    >>> layer_b = tf.keras.layers.Dense(1,
1744    ...   kernel_initializer=tf.constant_initializer(2.))
1745    >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
1746    >>> layer_b.get_weights()
1747    [array([[2.],
1748           [2.],
1749           [2.]], dtype=float32), array([0.], dtype=float32)]
1750    >>> layer_b.set_weights(layer_a.get_weights())
1751    >>> layer_b.get_weights()
1752    [array([[1.],
1753           [1.],
1754           [1.]], dtype=float32), array([0.], dtype=float32)]
1755
1756    Args:
1757      weights: a list of NumPy arrays. The number
1758        of arrays and their shapes must match the
1759        number and shapes of the weights of the
1760        layer (i.e. it should match the output of
1761        `get_weights`).
1762
1763    Raises:
1764      ValueError: If the provided weights list does not match the
1765        layer's specifications.
1766    """
1767    params = self.weights
1768
1769    expected_num_weights = 0
1770    for param in params:
1771      if isinstance(param, base_layer_utils.TrackableWeightHandler):
1772        expected_num_weights += param.num_tensors
1773      else:
1774        expected_num_weights += 1
1775
1776    if expected_num_weights != len(weights):
1777      raise ValueError(
1778          'You called `set_weights(weights)` on layer "%s" '
1779          'with a weight list of length %s, but the layer was '
1780          'expecting %s weights. Provided weights: %s...' %
1781          (self.name, len(weights), expected_num_weights, str(weights)[:50]))
1782
1783    weight_index = 0
1784    weight_value_tuples = []
1785    for param in params:
1786      if isinstance(param, base_layer_utils.TrackableWeightHandler):
1787        num_tensors = param.num_tensors
1788        tensors = weights[weight_index:weight_index + num_tensors]
1789        param.set_weights(tensors)
1790        weight_index += num_tensors
1791      else:
1792        weight = weights[weight_index]
1793        weight_shape = weight.shape if hasattr(weight, 'shape') else ()
1794        ref_shape = param.shape
1795        if not ref_shape.is_compatible_with(weight_shape):
1796          raise ValueError(
1797              'Layer weight shape %s not compatible with provided weight '
1798              'shape %s' % (ref_shape, weight_shape))
1799        weight_value_tuples.append((param, weight))
1800        weight_index += 1
1801
1802    backend.batch_set_value(weight_value_tuples)
1803
1804    # Perform any layer defined finalization of the layer state.
1805    for layer in self._flatten_layers():
1806      layer.finalize_state()
1807
1808  def get_weights(self):
1809    """Returns the current weights of the layer, as NumPy arrays.
1810
1811    The weights of a layer represent the state of the layer. This function
1812    returns both trainable and non-trainable weight values associated with this
1813    layer as a list of NumPy arrays, which can in turn be used to load state
1814    into similarly parameterized layers.
1815
1816    For example, a `Dense` layer returns a list of two values: the kernel matrix
1817    and the bias vector. These can be used to set the weights of another
1818    `Dense` layer:
1819
1820    >>> layer_a = tf.keras.layers.Dense(1,
1821    ...   kernel_initializer=tf.constant_initializer(1.))
1822    >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
1823    >>> layer_a.get_weights()
1824    [array([[1.],
1825           [1.],
1826           [1.]], dtype=float32), array([0.], dtype=float32)]
1827    >>> layer_b = tf.keras.layers.Dense(1,
1828    ...   kernel_initializer=tf.constant_initializer(2.))
1829    >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
1830    >>> layer_b.get_weights()
1831    [array([[2.],
1832           [2.],
1833           [2.]], dtype=float32), array([0.], dtype=float32)]
1834    >>> layer_b.set_weights(layer_a.get_weights())
1835    >>> layer_b.get_weights()
1836    [array([[1.],
1837           [1.],
1838           [1.]], dtype=float32), array([0.], dtype=float32)]
1839
1840    Returns:
1841        Weights values as a list of NumPy arrays.
1842    """
1843    weights = self.weights
1844    output_weights = []
1845    for weight in weights:
1846      if isinstance(weight, base_layer_utils.TrackableWeightHandler):
1847        output_weights.extend(weight.get_tensors())
1848      else:
1849        output_weights.append(weight)
1850    return backend.batch_get_value(output_weights)
1851
1852  @doc_controls.do_not_generate_docs
1853  def finalize_state(self):
1854    """Finalizes the layer's state after updating layer weights.
1855
1856    This function can be overridden in a subclassed layer and will be called
1857    after updating the layer's weights. It can be used to finalize any
1858    additional layer state after a weight update.
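
    A minimal sketch (the layer below is hypothetical; it only shows where this
    hook can be overridden):

    ```python
    class MyLookupLayer(tf.keras.layers.Layer):
      def finalize_state(self):
        # Rebuild any derived state (e.g. an internal lookup structure) from
        # the weight values that were just set via `set_weights`.
        pass
    ```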
1859    """
1860    pass
1861
1862  @doc_controls.do_not_generate_docs
1863  def get_updates_for(self, inputs):
1864    """Deprecated, do NOT use!
1865
1866    Retrieves updates relevant to a specific set of inputs.
1867
1868    Args:
1869      inputs: Input tensor or list/tuple of input tensors.
1870
1871    Returns:
1872      List of update ops of the layer that depend on `inputs`.
1873    """
1874    warnings.warn('`layer.get_updates_for` is deprecated and '
1875                  'will be removed in a future version. '
1876                  'Please use the `layer.updates` property instead.')
1877    return self.updates
1878
1879  @doc_controls.do_not_generate_docs
1880  def get_losses_for(self, inputs):
1881    """Deprecated, do NOT use!
1882
1883    Retrieves losses relevant to a specific set of inputs.
1884
1885    Args:
1886      inputs: Input tensor or list/tuple of input tensors.
1887
1888    Returns:
1889      List of loss tensors of the layer that depend on `inputs`.
1890    """
1891    warnings.warn('`layer.get_losses_for` is deprecated and '
1892                  'will be removed in a future version. '
1893                  'Please use `layer.losses` instead.')
1894    return self.losses
1895
1896  @doc_controls.do_not_doc_inheritable
1897  def get_input_mask_at(self, node_index):
1898    """Retrieves the input mask tensor(s) of a layer at a given node.
1899
1900    Args:
1901        node_index: Integer, index of the node
1902            from which to retrieve the attribute.
1903            E.g. `node_index=0` will correspond to the
1904            first time the layer was called.
1905
1906    Returns:
1907        A mask tensor
1908        (or list of tensors if the layer has multiple inputs).
1909    """
1910    inputs = self.get_input_at(node_index)
1911    if isinstance(inputs, list):
1912      return [getattr(x, '_keras_mask', None) for x in inputs]
1913    else:
1914      return getattr(inputs, '_keras_mask', None)
1915
1916  @doc_controls.do_not_doc_inheritable
1917  def get_output_mask_at(self, node_index):
1918    """Retrieves the output mask tensor(s) of a layer at a given node.
1919
1920    Args:
1921        node_index: Integer, index of the node
1922            from which to retrieve the attribute.
1923            E.g. `node_index=0` will correspond to the
1924            first time the layer was called.
1925
1926    Returns:
1927        A mask tensor
1928        (or list of tensors if the layer has multiple outputs).
1929    """
1930    output = self.get_output_at(node_index)
1931    if isinstance(output, list):
1932      return [getattr(x, '_keras_mask', None) for x in output]
1933    else:
1934      return getattr(output, '_keras_mask', None)
1935
1936  @property
1937  @doc_controls.do_not_doc_inheritable
1938  def input_mask(self):
1939    """Retrieves the input mask tensor(s) of a layer.
1940
1941    Only applicable if the layer has exactly one inbound node,
1942    i.e. if it is connected to one incoming layer.
1943
1944    Returns:
1945        Input mask tensor (potentially None) or list of input
1946        mask tensors.
1947
1948    Raises:
1949        AttributeError: if the layer is connected to
1950        more than one incoming layer.
1951    """
1952    inputs = self.input
1953    if isinstance(inputs, list):
1954      return [getattr(x, '_keras_mask', None) for x in inputs]
1955    else:
1956      return getattr(inputs, '_keras_mask', None)
1957
1958  @property
1959  @doc_controls.do_not_doc_inheritable
1960  def output_mask(self):
1961    """Retrieves the output mask tensor(s) of a layer.
1962
1963    Only applicable if the layer has exactly one inbound node,
1964    i.e. if it is connected to one incoming layer.
1965
1966    Returns:
1967        Output mask tensor (potentially None) or list of output
1968        mask tensors.
1969
1970    Raises:
1971        AttributeError: if the layer is connected to
1972        more than one incoming layer.
1973    """
1974    output = self.output
1975    if isinstance(output, list):
1976      return [getattr(x, '_keras_mask', None) for x in output]
1977    else:
1978      return getattr(output, '_keras_mask', None)
1979
1980  @doc_controls.do_not_doc_inheritable
1981  def get_input_shape_at(self, node_index):
1982    """Retrieves the input shape(s) of a layer at a given node.
1983
1984    Args:
1985        node_index: Integer, index of the node
1986            from which to retrieve the attribute.
1987            E.g. `node_index=0` will correspond to the
1988            first time the layer was called.
1989
1990    Returns:
1991        A shape tuple
1992        (or list of shape tuples if the layer has multiple inputs).
1993
1994    Raises:
1995      RuntimeError: If called in Eager mode.
1996    """
1997    return self._get_node_attribute_at_index(node_index, 'input_shapes',
1998                                             'input shape')
1999
2000  @doc_controls.do_not_doc_inheritable
2001  def get_output_shape_at(self, node_index):
2002    """Retrieves the output shape(s) of a layer at a given node.
2003
2004    Args:
2005        node_index: Integer, index of the node
2006            from which to retrieve the attribute.
2007            E.g. `node_index=0` will correspond to the
2008            first time the layer was called.
2009
2010    Returns:
2011        A shape tuple
2012        (or list of shape tuples if the layer has multiple outputs).
2013
2014    Raises:
2015      RuntimeError: If called in Eager mode.
2016    """
2017    return self._get_node_attribute_at_index(node_index, 'output_shapes',
2018                                             'output shape')
2019
2020  @doc_controls.do_not_doc_inheritable
2021  def get_input_at(self, node_index):
2022    """Retrieves the input tensor(s) of a layer at a given node.
2023
2024    Args:
2025        node_index: Integer, index of the node
2026            from which to retrieve the attribute.
2027            E.g. `node_index=0` will correspond to the
2028            first input node of the layer.
2029
2030    Returns:
2031        A tensor (or list of tensors if the layer has multiple inputs).
2032
2033    Raises:
2034      RuntimeError: If called in Eager mode.
2035    """
2036    return self._get_node_attribute_at_index(node_index, 'input_tensors',
2037                                             'input')
2038
2039  @doc_controls.do_not_doc_inheritable
2040  def get_output_at(self, node_index):
2041    """Retrieves the output tensor(s) of a layer at a given node.
2042
2043    Args:
2044        node_index: Integer, index of the node
2045            from which to retrieve the attribute.
2046            E.g. `node_index=0` will correspond to the
2047            first output node of the layer.
2048
2049    Returns:
2050        A tensor (or list of tensors if the layer has multiple outputs).
2051
2052    Raises:
2053      RuntimeError: If called in Eager mode.
2054    """
2055    return self._get_node_attribute_at_index(node_index, 'output_tensors',
2056                                             'output')
2057
2058  @property
2059  def input(self):
2060    """Retrieves the input tensor(s) of a layer.
2061
2062    Only applicable if the layer has exactly one input,
2063    i.e. if it is connected to one incoming layer.
2064
2065    Returns:
2066        Input tensor or list of input tensors.
2067
2068    Raises:
2069      RuntimeError: If called in Eager mode.
2070      AttributeError: If no inbound nodes are found.
2071    """
2072    if not self._inbound_nodes:
2073      raise AttributeError('Layer ' + self.name +
2074                           ' is not connected, no input to return.')
2075    return self._get_node_attribute_at_index(0, 'input_tensors', 'input')
2076
2077  @property
2078  def output(self):
2079    """Retrieves the output tensor(s) of a layer.
2080
2081    Only applicable if the layer has exactly one output,
2082    i.e. if it is connected to one incoming layer.
2083
2084    Returns:
2085      Output tensor or list of output tensors.
2086
2087    Raises:
2088      AttributeError: if the layer is connected to more than one incoming
2089        layer.
2090      RuntimeError: if called in Eager mode.
2091    """
2092    if not self._inbound_nodes:
2093      raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
2094    return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
2095
2096  @property
2097  @doc_controls.do_not_doc_inheritable
2098  def input_shape(self):
2099    """Retrieves the input shape(s) of a layer.
2100
2101    Only applicable if the layer has exactly one input,
2102    i.e. if it is connected to one incoming layer, or if all inputs
2103    have the same shape.
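
    For illustration, a minimal sketch (assuming a single-input Functional
    connection):

    ```python
    inputs = tf.keras.Input(shape=(4,))
    layer = tf.keras.layers.Dense(2)
    outputs = layer(inputs)
    layer.input_shape  # (None, 4)
    ```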
2104
2105    Returns:
2106        Input shape, as an integer shape tuple
2107        (or list of shape tuples, one tuple per input tensor).
2108
2109    Raises:
2110        AttributeError: if the layer has no defined input_shape.
2111        RuntimeError: if called in Eager mode.
2112    """
2113    if not self._inbound_nodes:
2114      raise AttributeError('The layer has never been called '
2115                           'and thus has no defined input shape.')
2116    all_input_shapes = set(
2117        [str(node.input_shapes) for node in self._inbound_nodes])
2118    if len(all_input_shapes) == 1:
2119      return self._inbound_nodes[0].input_shapes
2120    else:
2121      raise AttributeError('The layer "' + str(self.name) +
2122                           '" has multiple inbound nodes, '
2123                           'with different input shapes. Hence '
2124                           'the notion of "input shape" is '
2125                           'ill-defined for the layer. '
2126                           'Use `get_input_shape_at(node_index)` '
2127                           'instead.')
2128
2129  def count_params(self):
2130    """Count the total number of scalars composing the weights.
2131
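    A minimal doctest-style sketch (assuming the layer has already been built):

    >>> layer = tf.keras.layers.Dense(3)
    >>> layer.build((None, 4))
    >>> layer.count_params()
    15
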
2132    Returns:
2133        An integer count.
2134
2135    Raises:
2136        ValueError: if the layer isn't yet built
2137          (in which case its weights aren't yet defined).
2138    """
2139    if not self.built:
2140      if getattr(self, '_is_graph_network', False):
2141        with tf_utils.maybe_init_scope(self):
2142          self._maybe_build(self.inputs)
2143      else:
2144        raise ValueError('You tried to call `count_params` on ' + self.name +
2145                         ', but the layer isn\'t built. '
2146                         'You can build it manually via: `' + self.name +
2147                         '.build(batch_input_shape)`.')
2148    return layer_utils.count_params(self.weights)
2149
2150  @property
2151  @doc_controls.do_not_doc_inheritable
2152  def output_shape(self):
2153    """Retrieves the output shape(s) of a layer.
2154
2155    Only applicable if the layer has one output,
2156    or if all outputs have the same shape.
2157
2158    Returns:
2159        Output shape, as an integer shape tuple
2160        (or list of shape tuples, one tuple per output tensor).
2161
2162    Raises:
2163        AttributeError: if the layer has no defined output shape.
2164        RuntimeError: if called in Eager mode.
2165    """
2166    if not self._inbound_nodes:
2167      raise AttributeError('The layer has never been called '
2168                           'and thus has no defined output shape.')
2169    all_output_shapes = set(
2170        [str(node.output_shapes) for node in self._inbound_nodes])
2171    if len(all_output_shapes) == 1:
2172      return self._inbound_nodes[0].output_shapes
2173    else:
2174      raise AttributeError('The layer "%s"'
2175                           ' has multiple inbound nodes, '
2176                           'with different output shapes. Hence '
2177                           'the notion of "output shape" is '
2178                           'ill-defined for the layer. '
2179                           'Use `get_output_shape_at(node_index)` '
2180                           'instead.' % self.name)
2181
2182  @property
2183  @doc_controls.do_not_doc_inheritable
2184  def inbound_nodes(self):
2185    """Deprecated, do NOT use! Only for compatibility with external Keras."""
2186    return self._inbound_nodes
2187
2188  @property
2189  @doc_controls.do_not_doc_inheritable
2190  def outbound_nodes(self):
2191    """Deprecated, do NOT use! Only for compatibility with external Keras."""
2192    return self._outbound_nodes
2193
2194  ##############################################################################
2195  # Methods & attributes below are public aliases of other methods.            #
2196  ##############################################################################
2197
2198  @doc_controls.do_not_doc_inheritable
2199  def apply(self, inputs, *args, **kwargs):
2200    """Deprecated, do NOT use!
2201
2202    This is an alias of `self.__call__`.
2203
2204    Args:
2205      inputs: Input tensor(s).
2206      *args: additional positional arguments to be passed to `self.call`.
2207      **kwargs: additional keyword arguments to be passed to `self.call`.
2208
2209    Returns:
2210      Output tensor(s).
2211    """
2212    warnings.warn('`layer.apply` is deprecated and '
2213                  'will be removed in a future version. '
2214                  'Please use `layer.__call__` method instead.')
2215    return self.__call__(inputs, *args, **kwargs)
2216
2217  @doc_controls.do_not_doc_inheritable
2218  def add_variable(self, *args, **kwargs):
2219    """Deprecated, do NOT use! Alias for `add_weight`."""
2220    warnings.warn('`layer.add_variable` is deprecated and '
2221                  'will be removed in a future version. '
2222                  'Please use `layer.add_weight` method instead.')
2223    return self.add_weight(*args, **kwargs)
2224
2225  @property
2226  @doc_controls.do_not_generate_docs
2227  def variables(self):
2228    """Returns the list of all layer variables/weights.
2229
2230    Alias of `self.weights`.
2231
2232    Note: This will not track the weights of nested `tf.Modules` that are not
2233    themselves Keras layers.
2234
2235    Returns:
2236      A list of variables.
2237    """
2238    return self.weights
2239
2240  @property
2241  @doc_controls.do_not_generate_docs
2242  def trainable_variables(self):
2243    return self.trainable_weights
2244
2245  @property
2246  @doc_controls.do_not_generate_docs
2247  def non_trainable_variables(self):
2248    return self.non_trainable_weights
2249
2250  ##############################################################################
2251  # Methods & attributes below are all private and only used by the framework. #
2252  ##############################################################################
2253
2254  @property
2255  def _inbound_nodes(self):
2256    return self._inbound_nodes_value
2257
2258  @_inbound_nodes.setter
2259  @trackable.no_automatic_dependency_tracking
2260  def _inbound_nodes(self, value):
2261    self._inbound_nodes_value = value
2262
2263  @property
2264  def _outbound_nodes(self):
2265    return self._outbound_nodes_value
2266
2267  @_outbound_nodes.setter
2268  @trackable.no_automatic_dependency_tracking
2269  def _outbound_nodes(self, value):
2270    self._outbound_nodes_value = value
2271
2272  def _set_dtype_policy(self, dtype):
2273    """Sets self._dtype_policy."""
2274    if isinstance(dtype, policy.Policy):
2275      self._dtype_policy = dtype
2276    elif isinstance(dtype, dict):
2277      self._dtype_policy = policy.deserialize(dtype)
2278    elif isinstance(dtype, str) and dtype in ('mixed_float16',
2279                                              'mixed_bfloat16'):
2280      # The isinstance check is required since np.dtype raises an error if
2281      # compared to a non-dtype string.
2282      self._dtype_policy = policy.Policy(dtype)
2283    elif dtype:
2284      self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
2285    else:
2286      self._dtype_policy = policy.global_policy()
2287    if (self._dtype_policy.name == 'mixed_float16' and
2288        not loss_scale_optimizer.strategy_supports_loss_scaling()):
2289      # Although it is only loss scaling that doesn't support certain strategies,
2290      # to avoid confusion we disallow the 'mixed_float16' policy with unsupported
2291      # strategies. This is because 'mixed_float16' requires loss scaling for
2292      # numeric stability.
2293      strategy = ds_context.get_strategy()
2294      raise ValueError('Mixed precision is not supported with the '
2295                       'tf.distribute.Strategy: %s. Either stop using mixed '
2296                       'precision by removing the use of the "%s" policy or '
2297                       'use a different Strategy, e.g. a MirroredStrategy.' %
2298                       (strategy.__class__.__name__, self._dtype_policy.name))
2299
2300    # Performance optimization: cache the compute dtype as a Dtype object or
2301    # None, so that str to Dtype conversion doesn't happen in Layer.__call__.
2302    # TODO(b/157486353): Investigate returning DTypes in Policy.
2303    if self._dtype_policy.compute_dtype:
2304      self._compute_dtype_object = dtypes.as_dtype(
2305          self._dtype_policy.compute_dtype)
2306    else:
2307      self._compute_dtype_object = None
2308
2309  @property
2310  def dtype_policy(self):
2311    """The dtype policy associated with this layer.
2312
2313    This is an instance of a `tf.keras.mixed_precision.Policy`.
2314    """
2315    return self._dtype_policy
2316
2317  @property
2318  def compute_dtype(self):
2319    """The dtype of the layer's computations.
2320
2321    This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless
2322    mixed precision is used, this is the same as `Layer.dtype`, the dtype of
2323    the weights.
2324
2325    Layers automatically cast their inputs to the compute dtype, which causes
2326    computations and the output to be in the compute dtype as well. This is done
2327    by the base Layer class in `Layer.__call__`, so you do not have to insert
2328    these casts if implementing your own layer.
2329
2330    Layers often perform certain internal computations in higher precision when
2331    `compute_dtype` is float16 or bfloat16 for numeric stability. The output
2332    will still typically be float16 or bfloat16 in such cases.
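
    For illustration, a minimal sketch (assuming the `mixed_float16` policy):

    ```python
    layer = tf.keras.layers.Dense(1, dtype='mixed_float16')
    y = layer(tf.ones((1, 1)))  # the float32 input is autocast to float16
    y.dtype                     # tf.float16
    ```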
2333
2334    Returns:
2335      The layer's compute dtype.
2336    """
2337    return self._dtype_policy.compute_dtype
2338
2339  @property
2340  def _compute_dtype(self):
2341    """Deprecated alias of `compute_dtype`."""
2342    return self._dtype_policy.compute_dtype
2343
2344  @property
2345  def variable_dtype(self):
2346    """Alias of `Layer.dtype`, the dtype of the weights."""
2347    return self.dtype
2348
2349  def _maybe_cast_inputs(self, inputs, input_list=None):
2350    """Maybe casts the inputs to the compute dtype.
2351
2352    If self._compute_dtype is floating-point and self._autocast is True,
2353    floating-point inputs are cast to self._compute_dtype.
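
    A sketch for illustration only (this is a private helper; the example
    assumes the `mixed_float16` policy and the default `autocast=True`):

    ```python
    layer = tf.keras.layers.Dense(1, dtype='mixed_float16')
    layer._maybe_cast_inputs(tf.ones((2, 2)))            # cast to float16
    layer._maybe_cast_inputs(tf.ones((2, 2), tf.int32))  # returned unchanged
    ```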
2354
2355    Args:
2356      inputs: Input tensor, or structure of input tensors.
2357      input_list: Flat list of input tensors.
2358
2359    Returns:
2360      `inputs`, but tensors may have been cast to self._compute_dtype.
2361    """
2362    if not input_list:
2363      input_list = nest.flatten(inputs)
2364
2365    compute_dtype_object = self._compute_dtype_object
2366    should_autocast = (
2367        self._autocast and compute_dtype_object and
2368        compute_dtype_object.is_floating)
2369
2370    if (should_autocast and
2371        any(map(self._should_cast_single_input, input_list))):
2372      # Only perform expensive `nest` operation when needed.
2373      return nest.map_structure(self._cast_single_input, inputs)
2374    else:
2375      return inputs
2376
2377  def _should_cast_single_input(self, x):
2378    if isinstance(x, _AUTOCAST_TYPES):
2379      return (self._compute_dtype_object and
2380              x.dtype != self._compute_dtype_object and x.dtype.is_floating)
2381    return False
2382
2383  def _cast_single_input(self, x):
2384    """Cast a single Tensor or TensorSpec to the compute dtype."""
2385    if self._should_cast_single_input(x):
2386      return math_ops.cast(x, self._compute_dtype_object)
2387    else:
2388      return x
2389
2390  # _dtype used to be an attribute set in the constructor. We still expose it
2391  # because some clients still use it.
2392  # TODO(reedwm): Deprecate, then remove the _dtype property.
2393  @property
2394  def _dtype(self):
2395    # This is equivalent to returning self.dtype. We do not return self.dtype
2396    # as it would cause infinite recursion in a few subclasses, which override
2397    # "dtype" to return self._dtype.
2398    return self._dtype_policy.variable_dtype
2399
2400  @_dtype.setter
2401  def _dtype(self, value):
2402    value = dtypes.as_dtype(value).name
2403    self._set_dtype_policy(policy.Policy(value))
2404
2405  def _name_scope(self):  # pylint: disable=method-hidden
2406    if not tf2.enabled():
2407      return self.name
2408    name_scope = self.name
2409    current_name_scope = ops.get_name_scope()
2410    if current_name_scope:
2411      name_scope = current_name_scope + '/' + name_scope
2412    if name_scope:
2413      # Note that the trailing `/` prevents autogenerated
2414      # numerical suffixes from being appended. It will also fully reset
2415      # nested name scope (i.e. the outer name scope has no effect).
2416      name_scope += '/'
2417    return name_scope
2418
2419  def _init_set_name(self, name, zero_based=True):
2420    if not name:
2421      self._name = backend.unique_object_name(
2422          generic_utils.to_snake_case(self.__class__.__name__),
2423          zero_based=zero_based)
2424    else:
2425      backend.observe_object_name(name)
2426      self._name = name
2427
2428  def _get_existing_metric(self, name=None):
2429    match = [m for m in self._metrics if m.name == name]
2430    if not match:
2431      return
2432    if len(match) > 1:
2433      raise ValueError(
2434          'Please provide different names for the metrics you have added. '
2435          'We found {} metrics with the name: "{}"'.format(len(match), name))
2436    return match[0]
2437
2438  def _handle_weight_regularization(self, name, variable, regularizer):
2439    """Create lambdas which compute regularization losses."""
2440
2441    def _loss_for_variable(v):
2442      """Creates a regularization loss `Tensor` for variable `v`."""
2443      with backend.name_scope(name + '/Regularizer'):
2444        regularization = regularizer(v)
2445      return regularization
2446
2447    if base_layer_utils.is_split_variable(variable):
2448      for v in variable:
2449        self.add_loss(functools.partial(_loss_for_variable, v))
2450    else:
2451      self.add_loss(functools.partial(_loss_for_variable, variable))
2452
2453  def _handle_activity_regularization(self, inputs, outputs):
2454    # Apply activity regularization.
2455    # Note that it should be applied every time the layer creates a new
2456    # output, since it is output-specific.
2457    if self._activity_regularizer:
2458      output_list = nest.flatten(outputs)
2459      with backend.name_scope('ActivityRegularizer'):
2460        for output in output_list:
2461          activity_loss = self._activity_regularizer(output)
2462          batch_size = math_ops.cast(
2463              array_ops.shape(output)[0], activity_loss.dtype)
2464          # Make activity regularization strength batch-agnostic.
2465          mean_activity_loss = activity_loss / batch_size
2466          self.add_loss(mean_activity_loss)
2467
2468  def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph):
2469    # Many `Layer`s don't need to call `compute_mask`.
2470    # This method is optimized to do as little work as needed for the common
2471    # case.
2472    if not self._supports_masking:
2473      return
2474
2475    flat_outputs = nest.flatten(outputs)
2476
2477    mask_already_computed = (
2478        getattr(self, '_compute_output_and_mask_jointly', False) or
2479        all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs))
2480    if mask_already_computed:
2481      if build_graph:
2482        self._set_mask_keras_history_checked(flat_outputs)
2483      return
2484
2485    output_masks = self.compute_mask(inputs, previous_mask)
2486    if output_masks is None:
2487      return
2488
2489    flat_masks = nest.flatten(output_masks)
2490    for tensor, mask in zip(flat_outputs, flat_masks):
2491      try:
2492        tensor._keras_mask = mask
2493      except AttributeError:
2494        # C Type such as np.ndarray.
2495        pass
2496
2497    if build_graph:
2498      self._set_mask_keras_history_checked(flat_outputs)
2499
2500  def _set_mask_keras_history_checked(self, flat_outputs):
2501    for output in flat_outputs:
2502      if getattr(output, '_keras_mask', None) is not None:
2503        # Do not track masks for `TensorFlowOpLayer` construction.
2504        output._keras_mask._keras_history_checked = True
2505
2506  def _get_input_masks(self, inputs, input_list, args, kwargs):
2507    if not self._supports_masking and not self._expects_mask_arg:
2508      # Input masks only need to be retrieved if they are needed for `call`
2509      # or `compute_mask`.
2510      input_masks = None
2511      implicit_mask = False
2512    elif self._call_arg_was_passed('mask', args, kwargs):
2513      input_masks = self._get_call_arg_value('mask', args, kwargs)
2514      implicit_mask = False
2515    else:
2516      input_masks = [getattr(t, '_keras_mask', None) for t in input_list]
2517      if all(mask is None for mask in input_masks):
2518        input_masks = None
2519        implicit_mask = False
2520      else:
2521        # Only do expensive `nest` op when masking is actually being used.
2522        input_masks = nest.pack_sequence_as(inputs, input_masks)
2523        implicit_mask = True
2524    return input_masks, implicit_mask
2525
2526  def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
2527    # Performance optimization: do no work in the most common case.
2528    if not args and not kwargs:
2529      return False
2530
2531    if arg_name in kwargs:
2532      return True
2533    call_fn_args = self._call_fn_args
2534    if not inputs_in_args:
2535      # Ignore `inputs` arg.
2536      call_fn_args = call_fn_args[1:]
2537    return arg_name in dict(zip(call_fn_args, args))
2538
2539  def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
2540    if arg_name in kwargs:
2541      return kwargs[arg_name]
2542    call_fn_args = self._call_fn_args
2543    if not inputs_in_args:
2544      # Ignore `inputs` arg.
2545      call_fn_args = call_fn_args[1:]
2546    args_dict = dict(zip(call_fn_args, args))
2547    return args_dict[arg_name]
2548
2549  def _set_call_arg_value(
2550      self, arg_name, new_value, args,
2551      kwargs, inputs_in_args=False, pop_kwarg_if_none=False):
2552    arg_pos = self._call_fn_arg_positions.get(arg_name, None)
2553    if arg_pos is not None:
2554      if not inputs_in_args:
2555        # Ignore `inputs` arg.
2556        arg_pos = arg_pos - 1
2557      if len(args) > arg_pos:
2558        args = list(args)
2559        args[arg_pos] = new_value
2560        return tuple(args), kwargs
2561    if new_value is None and pop_kwarg_if_none:
2562      kwargs.pop(arg_name, None)
2563    else:
2564      kwargs[arg_name] = new_value
2565    return args, kwargs
2566
2567  def _set_connectivity_metadata(self, args, kwargs, outputs):
2568    # If the layer returns tensors from its inputs unmodified,
2569    # we copy them to avoid loss of KerasHistory metadata.
2570    flat_outputs = nest.flatten(outputs)
2571    flat_inputs = nest.flatten((args, kwargs))
2572    input_ids_set = {id(i) for i in flat_inputs}
2573    outputs_copy = []
2574    for x in flat_outputs:
2575      if id(x) in input_ids_set:
2576        with backend.name_scope(self.name):
2577          x = array_ops.identity(x)
2578      outputs_copy.append(x)
2579    outputs = nest.pack_sequence_as(outputs, outputs_copy)
2580
2581    # Create node, Node wires itself to inbound and outbound layers.
2582    # The Node constructor actually updates this layer's self._inbound_nodes,
2583    # sets _keras_history on the outputs, and adds itself to the
2584    # `_outbound_nodes` of the layers that produced the inputs to this
2585    # layer call.
2586    node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs)
2587    return outputs
2588
2589  def _get_node_attribute_at_index(self, node_index, attr, attr_name):
2590    """Private utility to retrieve an attribute (e.g. inputs) from a node.
2591
2592    This is used to implement the methods:
2593        - get_input_shape_at
2594        - get_output_shape_at
2595        - get_input_at
2596        etc...
2597
2598    Args:
2599        node_index: Integer index of the node from which
2600            to retrieve the attribute.
2601        attr: Exact node attribute name.
2602        attr_name: Human-readable attribute name, for error messages.
2603
2604    Returns:
2605        The layer's attribute `attr` at the node of index `node_index`.
2606
2607    Raises:
2608        RuntimeError: If the layer has no inbound nodes, or if called in Eager
2609        mode.
2610        ValueError: If the index provided does not match any node.
2611    """
2612    if not self._inbound_nodes:
2613      raise RuntimeError('The layer has never been called '
2614                         'and thus has no defined ' + attr_name + '.')
2615    if not len(self._inbound_nodes) > node_index:
2616      raise ValueError('Asked to get ' + attr_name + ' at node ' +
2617                       str(node_index) + ', but the layer has only ' +
2618                       str(len(self._inbound_nodes)) + ' inbound nodes.')
2619    values = getattr(self._inbound_nodes[node_index], attr)
2620    if isinstance(values, list) and len(values) == 1:
2621      return values[0]
2622    else:
2623      return values
2624
2625  def _maybe_build(self, inputs):
2626    # Check input assumptions set before layer building, e.g. input rank.
2627    if not self.built:
2628      input_spec.assert_input_compatibility(
2629          self.input_spec, inputs, self.name)
2630      input_list = nest.flatten(inputs)
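      # If the dtype policy has no compute dtype yet, infer one from the dtype
      # of the first input (when the inputs expose a dtype at all).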
2631      if input_list and self._dtype_policy.compute_dtype is None:
2632        try:
2633          dtype = input_list[0].dtype.base_dtype.name
2634        except AttributeError:
2635          pass
2636        else:
2637          self._set_dtype_policy(policy.Policy(dtype))
2638      input_shapes = None
2639      # Converts Tensors / CompositeTensors to TensorShapes.
2640      if all(hasattr(x, 'shape') for x in input_list):
2641        input_shapes = tf_utils.get_shapes(inputs)
2642      else:
2643        # Converts input shape to TensorShapes.
2644        try:
2645          input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False)
2646        except ValueError:
2647          pass
2648      # Only call `build` if the user has manually overridden the build method.
2649      if not hasattr(self.build, '_is_default'):
2650        # Any setup work performed only once should happen in an `init_scope`
2651        # to avoid creating symbolic Tensors that will later pollute any eager
2652        # operations.
2653        with tf_utils.maybe_init_scope(self):
2654          self.build(input_shapes)  # pylint:disable=not-callable
2655      # We must also ensure that the layer is marked as built and that the build
2656      # shape is stored, since user-defined build functions may not call
2657      # `super().build()`.
2658      Layer.build(self, input_shapes)
2659
2660    # Optionally load weight values specified at layer instantiation.
2661    if self._initial_weights is not None:
2662      with ops.init_scope():
2663        # Using `init_scope` since we want variable assignment in
2664        # `set_weights` to be treated like variable initialization.
2665        self.set_weights(self._initial_weights)
2666      self._initial_weights = None
2667
2668  def _symbolic_call(self, inputs):
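    # Build symbolic outputs from shapes alone: infer the output shapes via
    # `compute_output_shape` and return matching placeholders, without running
    # the layer's `call`.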
2669    input_shapes = nest.map_structure(lambda x: x.shape, inputs)
2670    output_shapes = self.compute_output_shape(input_shapes)
2671    # Convert to TensorShape so that nest.map_structure will not map into
2672    # individual dim of the shape.
2673    # individual dims of the shape.
2674
2675    def _make_placeholder_like(shape):
2676      ph = backend.placeholder(shape=shape, dtype=self.dtype)
2677      ph._keras_mask = None
2678      return ph
2679    return nest.map_structure(_make_placeholder_like, output_shapes)
2680
2681  def _get_trainable_state(self):
2682    """Get the `trainable` state of each sublayer.
2683
2684    Returns:
2685      A dict mapping all sublayers to their `trainable` value.
2686    """
2687    trainable_state = weakref.WeakKeyDictionary()
2688    for layer in self._flatten_layers():
2689      trainable_state[layer] = layer.trainable
2690    return trainable_state
2691
2692  def _set_trainable_state(self, trainable_state):
2693    """Set `trainable` state for each sublayer."""
2694    for layer in self._flatten_layers():
2695      if layer in trainable_state:
2696        layer.trainable = trainable_state[layer]
2697
2698  @property
2699  def _obj_reference_counts(self):
2700    """A dictionary counting the number of attributes referencing an object."""
2701    self._maybe_create_attribute('_obj_reference_counts_dict',
2702                                 object_identity.ObjectIdentityDictionary())
2703    return self._obj_reference_counts_dict
2704
2705  @trackable.no_automatic_dependency_tracking
2706  def _maybe_create_attribute(self, name, default_value):
2707    """Create the attribute with the default value if it hasn't been created.
2708
2709    This is useful for fields that are used for tracking purposes, such as
2710    _trainable_weights or _layers. Note that a user could create a layer
2711    subclass and assign an internal field before invoking Layer.__init__();
2712    in that case __setattr__() needs to create the tracking fields and
2713    __init__() must not override them.
2714
2715    Args:
2716      name: String, the name of the attribute.
2717      default_value: Object, the default value of the attribute.
2718    """
2719    if not hasattr(self, name):
2720      self.__setattr__(name, default_value)
2721
2722  def __delattr__(self, name):
2723    # For any super.__delattr__() call, we will directly use the implementation
2724    # in Trackable and skip the behavior in AutoTrackable. Layer originally used
2725    # Trackable as its base class; the change to using Module as the base class
2726    # forced us to have AutoTrackable in the class hierarchy.
2727    #
2728    # TODO(b/180760306) Keeping the status quo of skipping __delattr__ and
2729    # __setattr__ in AutoTrackable may be unsustainable.
2730    existing_value = getattr(self, name, None)
2731
2732    # The attribute being deleted may hold the last reference to a tracked
2733    # object; in that case we should clean it out to avoid leaking memory. First
2734    # we check whether there are other attributes referencing it.
2735    reference_counts = self._obj_reference_counts
2736    if existing_value not in reference_counts:
2737      super(autotrackable.AutoTrackable, self).__delattr__(name)  # pylint: disable=bad-super-call
2738      return
2739
2740    reference_count = reference_counts[existing_value]
2741    if reference_count > 1:
2742      # There are other remaining references. We can't remove this object from
2743      # _layers etc.
2744      reference_counts[existing_value] = reference_count - 1
2745      super(autotrackable.AutoTrackable, self).__delattr__(name)  # pylint: disable=bad-super-call
2746      return
2747    else:
2748      # This is the last remaining reference.
2749      del reference_counts[existing_value]
2750
2751    super(autotrackable.AutoTrackable, self).__delattr__(name)  # pylint: disable=bad-super-call
2752
2753    if (isinstance(existing_value, Layer)
2754        or base_layer_utils.has_weights(existing_value)):
2755      super(autotrackable.AutoTrackable, self).__setattr__(  # pylint: disable=bad-super-call
2756          '_self_tracked_trackables',
2757          [l for l in self._self_tracked_trackables if l is not existing_value])
2758    if isinstance(existing_value, tf_variables.Variable):
2759      super(autotrackable.AutoTrackable, self).__setattr__(  # pylint: disable=bad-super-call
2760          '_trainable_weights',
2761          [w for w in self._trainable_weights if w is not existing_value])
2762      super(autotrackable.AutoTrackable, self).__setattr__(  # pylint: disable=bad-super-call
2763          '_non_trainable_weights',
2764          [w for w in self._non_trainable_weights if w is not existing_value])
2765
2766  def __setattr__(self, name, value):
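    # In addition to the plain attribute assignment, this does the Layer's
    # bookkeeping: values are wrapped for dependency tracking, sub-layers and
    # other weight-holding objects are appended to `_self_tracked_trackables`,
    # Metric instances are recorded in `_metrics`, and any `tf.Variable`s found
    # in `value` are appended to the trainable / non-trainable weight lists.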
2767    if (name == '_self_setattr_tracking' or
2768        not getattr(self, '_self_setattr_tracking', True) or
2769        # Exclude @property.setters from tracking
2770        hasattr(self.__class__, name)):
2771      try:
2772        super(autotrackable.AutoTrackable, self).__setattr__(name, value)  # pylint: disable=bad-super-call
2773      except AttributeError:
2774        raise AttributeError(
2775            ('Can\'t set the attribute "{}", likely because it conflicts with '
2776             'an existing read-only @property of the object. Please choose a '
2777             'different name.').format(name))
2778      return
2779
2780    # Wraps data structures in `Trackable`, unwraps `NoDependency` objects.
2781    value = data_structures.sticky_attribute_assignment(
2782        trackable=self, value=value, name=name)
2783
2784    reference_counts = self._obj_reference_counts
2785    reference_counts[value] = reference_counts.get(value, 0) + 1
2786
2787    # Clean out the old attribute, which clears _layers and _trainable_weights
2788    # if necessary.
2789    try:
2790      self.__delattr__(name)
2791    except AttributeError:
2792      pass
2793
2794    # Keep track of metric instance created in subclassed layer.
2795    for val in nest.flatten(value):
2796      if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'):
2797        self._metrics.append(val)
2798
2799    # Append value to self._self_tracked_trackables if relevant
2800    if (getattr(self, '_auto_track_sub_layers', True) and
2801        (isinstance(value, module.Module) or
2802         base_layer_utils.has_weights(value))):
2803      self._maybe_create_attribute('_self_tracked_trackables', [])
2804      # We need to check object identity to avoid de-duplicating empty
2805      # container types which compare equal.
2806      if not any((layer is value for layer in self._self_tracked_trackables)):
2807        self._self_tracked_trackables.append(value)
2808        if hasattr(value, '_use_resource_variables'):
2809          # Legacy layers (V1 tf.layers) must always use
2810          # resource variables.
2811          value._use_resource_variables = True
2812
2813    # Append value to list of trainable / non-trainable weights if relevant
2814    # TODO(b/125122625): This won't pick up on any variables added to a
2815    # list/dict after creation.
2816    for val in nest.flatten(value, expand_composites=True):
2817      if not isinstance(val, tf_variables.Variable):
2818        continue
2819
2820      # Users may add extra weights/variables
2821      # simply by assigning them to attributes (invalid for graph networks)
2822      self._maybe_create_attribute('_trainable_weights', [])
2823      self._maybe_create_attribute('_non_trainable_weights', [])
2824      if val.trainable:
2825        if any(val is w for w in self._trainable_weights):
2826          continue
2827        self._trainable_weights.append(val)
2828      else:
2829        if any(val is w for w in self._non_trainable_weights):
2830          continue
2831        self._non_trainable_weights.append(val)
2832
2833      backend.track_variable(val)
2834
2835    # TODO(b/180760306) Skip the auto trackable from tf.Module to keep status
2836    # quo. See the comment at __delattr__.
2837    super(autotrackable.AutoTrackable, self).__setattr__(name, value)  # pylint: disable=bad-super-call
2838
2839  def _gather_children_attribute(self, attribute):
2840    assert attribute in {
2841        'variables', 'trainable_variables', 'non_trainable_variables'
2842    }
2843    if hasattr(self, '_self_tracked_trackables'):
2844      nested_layers = self._flatten_modules(include_self=False, recursive=False)
2845      return list(
2846          itertools.chain.from_iterable(
2847              getattr(layer, attribute) for layer in nested_layers))
2848    return []
2849
2850  def _flatten_layers(self, recursive=True, include_self=True):
2851    for m in self._flatten_modules(
2852        recursive=recursive, include_self=include_self):
2853      if isinstance(m, Layer):
2854        yield m
2855
2856  def _flatten_modules(self, recursive=True, include_self=True):
2857    """Flattens `tf.Module` instances (excluding `Metrics`).
2858
2859    Args:
2860      recursive: Whether to recursively flatten through submodules.
2861      include_self: Whether to include this `Layer` instance.
2862
2863    Yields:
2864      `tf.Module` instance tracked by this `Layer`.
2865    """
2866    if include_self:
2867      yield self
2868
2869    # Only instantiate set and deque if needed.
2870    trackables = getattr(self, '_self_tracked_trackables', None)
2871    if trackables:
2872      seen_object_ids = set()
2873      deque = collections.deque(trackables)
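      # Traversal is depth-first and pre-order: children are pushed onto the
      # front of the deque (in original order) so they are visited before any
      # remaining siblings.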
2874      while deque:
2875        trackable_obj = deque.popleft()
2876        trackable_id = id(trackable_obj)
2877        if trackable_id in seen_object_ids:
2878          continue
2879        seen_object_ids.add(trackable_id)
2880
2881        # Metrics are not considered part of the Layer's topology.
2882        if (isinstance(trackable_obj, module.Module) and
2883            not isinstance(trackable_obj, metrics_mod.Metric)):
2884          yield trackable_obj
2885          # Introspect recursively through sublayers.
2886          if recursive:
2887            subtrackables = getattr(trackable_obj, '_self_tracked_trackables',
2888                                    None)
2889            if subtrackables:
2890              deque.extendleft(reversed(subtrackables))
2891        elif isinstance(trackable_obj, data_structures.TrackableDataStructure):
2892          # Data structures are introspected even with `recursive=False`.
2893          tracked_values = trackable_obj._values
2894          if tracked_values:
2895            deque.extendleft(reversed(tracked_values))
2896
2897  # This is a hack so that the is_layer check (within
2898  # training/trackable/layer_utils.py) does not need to access the weights attr.
2899  # TODO(b/110718070): Remove when fixed.
2900  def _is_layer(self):
2901    return True
2902
2903  def _init_call_fn_args(self, expects_training_arg=None):
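    # Recompute and cache metadata derived from the signature of `self.call`:
    # whether it accepts `training` and `mask` arguments and what the default
    # `training` value is.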
2904    # Clear cached call function arguments.
2905    self.__class__._call_full_argspec.fget.cache.pop(self, None)
2906    self.__class__._call_fn_args.fget.cache.pop(self, None)
2907    self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)
2908
2909    call_fn_args = self._call_fn_args
2910    call_fn_args += self._call_full_argspec.kwonlyargs or []
2911    if expects_training_arg is None:
2912      self._expects_training_arg = ('training' in call_fn_args or
2913                                    self._call_accepts_kwargs)
2914    else:
2915      # Use value encoded into the metadata when loading from the SavedModel.
2916      self._expects_training_arg = expects_training_arg
2917    # The default training arg will be any (non-None) default specified in the
2918    # method signature, or None if no value is specified.
2919    call_fn_arg_defaults = self._call_fn_arg_defaults.copy()
2920    call_fn_arg_defaults.update(self._call_full_argspec.kwonlydefaults or {})
2921    self._default_training_arg = call_fn_arg_defaults.get('training')
2922
2923    self._expects_mask_arg = ('mask' in call_fn_args or
2924                              self._call_accepts_kwargs)
2925
2926  @property
2927  @layer_utils.cached_per_instance
2928  def _call_full_argspec(self):
2929    # Argspec inspection is expensive and the call spec is used often, so it
2930    # makes sense to cache the result.
2931    return tf_inspect.getfullargspec(self.call)
2932
2933  @property
2934  @layer_utils.cached_per_instance
2935  def _call_fn_args(self):
2936    all_args = self._call_full_argspec.args
2937    # Scrub `self` that appears if a decorator was applied.
2938    if all_args and all_args[0] == 'self':
2939      return all_args[1:]
2940    return all_args
2941
2942  @property
2943  @layer_utils.cached_per_instance
2944  def _call_fn_arg_defaults(self):
2945    call_fn_args = self._call_fn_args
2946    call_fn_defaults = self._call_full_argspec.defaults or []
2947    defaults = dict()
2948
2949    # The call arg defaults are an n-tuple of the last n elements of the args
2950    # list. (n = # of elements that have a default argument)
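    # Illustrative example: for `def call(self, inputs, training=None,
    # mask=None)`, `call_fn_args` is ['inputs', 'training', 'mask'] and
    # `call_fn_defaults` is (None, None), producing
    # {'training': None, 'mask': None}.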
2951    for i in range(-1 * len(call_fn_defaults), 0):
2952      defaults[call_fn_args[i]] = call_fn_defaults[i]
2953    return defaults
2954
2955  @property
2956  @layer_utils.cached_per_instance
2957  def _call_fn_arg_positions(self):
2958    call_fn_arg_positions = dict()
2959    for pos, arg in enumerate(self._call_fn_args):
2960      call_fn_arg_positions[arg] = pos
2961    return call_fn_arg_positions
2962
2963  @property
2964  @layer_utils.cached_per_instance
2965  def _call_accepts_kwargs(self):
2966    return self._call_full_argspec.varkw is not None
2967
2968  @property
2969  def _eager_losses(self):
2970    # A list of loss values containing activity regularizers and losses
2971    # manually added through `add_loss` during eager execution. It is cleared
2972    # after every batch.
2973    # Because we eventually plan to allow the same model instance to be trained
2974    # in either eager mode or graph mode, we need to keep track of
2975    # eager losses and symbolic losses via separate attributes.
2976    if not hasattr(self._thread_local, '_eager_losses'):
2977      self._thread_local._eager_losses = []
2978    return self._thread_local._eager_losses
2979
2980  @_eager_losses.setter
2981  def _eager_losses(self, losses):
2982    self._thread_local._eager_losses = losses
2983
2984  def _dedup_weights(self, weights):
2985    """Dedupe weights while maintaining order as much as possible."""
2986    output, seen_ids = [], set()
2987    for w in weights:
2988      if id(w) not in seen_ids:
2989        output.append(w)
2990        # Track the Variable's identity to avoid __eq__ issues.
2991        seen_ids.add(id(w))
2992
2993    return output
2994
2995  def _split_out_first_arg(self, args, kwargs):
2996    # Grab the argument corresponding to the first argument in the
2997    # layer's `call` method spec. This will either be the first positional
2998    # argument, or it will be provided as a keyword argument.
2999    if args:
3000      inputs = args[0]
3001      args = args[1:]
3002    elif self._call_fn_args[0] in kwargs:
3003      kwargs = copy.copy(kwargs)
3004      inputs = kwargs.pop(self._call_fn_args[0])
3005    else:
3006      raise ValueError(
3007          'The first argument to `Layer.call` must always be passed.')
3008    return inputs, args, kwargs
3009
3010  # SavedModel properties. Please see keras/saving/saved_model for details.
3011
3012  @trackable.no_automatic_dependency_tracking
3013  def _set_save_spec(self, inputs):
3014    if self._saved_model_inputs_spec is not None:
3015      return  # Already set.
3016
3017    self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec,
3018                                                       inputs)
3019
3020  def _get_save_spec(self, dynamic_batch=True):
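    # Returns the TensorSpec structure recorded by `_set_save_spec`. With
    # `dynamic_batch=True` the batch dimension of each spec is left dynamic
    # (assumed behavior of `tf_utils.get_tensor_spec`); pass
    # `dynamic_batch=False` to keep the recorded static batch size.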
3021    if self._saved_model_inputs_spec is None:
3022      return None
3023
3024    return nest.map_structure(
3025        lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
3026        self._saved_model_inputs_spec)
3027
3028  @property
3029  def _trackable_saved_model_saver(self):
3030    return layer_serialization.LayerSavedModelSaver(self)
3031
3032  @property
3033  def _object_identifier(self):
3034    return self._trackable_saved_model_saver.object_identifier
3035
3036  @property
3037  def _tracking_metadata(self):
3038    """Info about this layer to be saved into the SavedModel."""
3039    return self._trackable_saved_model_saver.tracking_metadata
3040
3041  def _trackable_children(self, save_type='checkpoint', **kwargs):
3042    if save_type == 'savedmodel':
3043      cache = kwargs['cache']
3044      # TODO(b/213628533): This must be called before super() to ensure
3045      # that any input shape changes are applied before getting the config of
3046      # the model.
3047      children = self._trackable_saved_model_saver.trackable_children(cache)
3048    else:
3049      children = {}
3050    children.update(super()._trackable_children(save_type, **kwargs))
3051    return children
3052
3053  @property
3054  def _use_input_spec_as_call_signature(self):
3055    # Whether the input spec can be used as the call signature when tracing the
3056    # Layer for SavedModel. By default, this is set to `True` for layers
3057    # exported from the Keras library, because those layers define the
3058    # `input_spec` property more rigidly (many custom layers only set `ndims`).
3059    return get_canonical_name_for_symbol(type(self),
3060                                         api_name='keras') is not None
3061
3062  def __getstate__(self):
3063    # Override to support `copy.deepcopy` and pickling.
3064    # Thread-local objects cannot be copied in Python 3, so pop these.
3065    # Thread-local objects are used to cache losses in MirroredStrategy, and
3066    # so shouldn't be copied.
3067    state = self.__dict__.copy()
3068    state.pop('_thread_local', None)
3069    state.pop('_metrics_lock', None)
3070    return state
3071
3072  def __setstate__(self, state):
3073    state['_thread_local'] = threading.local()
3074    state['_metrics_lock'] = threading.Lock()
3075    # Bypass Trackable logic as `__dict__` already contains this info.
3076    object.__setattr__(self, '__dict__', state)
3077
3078
3079class TensorFlowOpLayer(Layer):
3080  """Wraps a TensorFlow Operation in a Layer.
3081
3082  This class is used internally by the Functional API. When a user
3083  uses a raw TensorFlow Operation on symbolic tensors originating
3084  from an `Input` Layer, the resultant operation will be wrapped
3085  with this Layer object in order to make the operation compatible
3086  with the Keras API.
3087
3088  This Layer will create a new, identical operation (except for inputs
3089  and outputs) every time it is called. If `run_eagerly` is `True`,
3090  the op creation and calculation will happen inside an Eager function.
3091
3092  Instances of this Layer are created when `autolambda` is called, which
3093  is whenever a Layer's `__call__` encounters symbolic inputs that do
3094  not have Keras metadata, or when a Network's `__init__` encounters
3095  outputs that do not have Keras metadata.
3096
3097  Attributes:
3098    node_def: String, the serialized NodeDef of the Op this layer will wrap.
3099    name: String, the name of the Layer.
3100    constants: Dict of NumPy arrays, the values of any Tensors needed for this
3101      Operation that do not originate from a Keras `Input` Layer. Since all
3102      placeholders must come from Keras `Input` Layers, these Tensors must be
3103      treated as constant in the Functional API.
3104    trainable: Bool, whether this Layer is trainable. Currently Variables are
3105      not supported, and so this parameter has no effect.
3106    dtype: The default dtype of this Layer. Inherited from `Layer` and has no
3107      effect on this class; however, it is used in `get_config`.
3108  """
3109
3110  @trackable.no_automatic_dependency_tracking
3111  def __init__(self,
3112               node_def,
3113               name,
3114               constants=None,
3115               trainable=True,
3116               dtype=None):
3117    # Pass autocast=False, as if inputs are cast, input types might not match
3118    # Operation type.
3119    super(TensorFlowOpLayer, self).__init__(
3120        name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype,
3121        autocast=False)
3122    if isinstance(node_def, dict):
3123      self.node_def = json_format.ParseDict(node_def, node_def_pb2.NodeDef())
3124    else:
3125      if not isinstance(node_def, bytes):
3126        node_def = node_def.encode('utf-8')
3127      self.node_def = node_def_pb2.NodeDef.FromString(node_def)
3128    # JSON serialization stringifies keys which are integer input indices.
3129    self.constants = ({
3130        int(index): constant for index, constant in constants.items()
3131    } if constants is not None else {})
3132    # Layer uses original op unless it is called on new inputs.
3133    # This means `built` is not set in `__call__`.
3134    self.built = True
3135
3136    # Do not individually trace TensorflowOpLayers in the SavedModel.
3137    self._must_restore_from_config = True
3138
3139  def call(self, inputs):
3140    if context.executing_eagerly():
3141      return self._defun_call(inputs)
3142    return self._make_op(inputs)
3143
3144  def _make_node_def(self, graph):
3145    node_def = node_def_pb2.NodeDef()
3146    node_def.CopyFrom(self.node_def)
3147    # Used in TPUReplicateContext to indicate whether this node has been cloned
3148    # and to not add TPU attributes.
3149    node_def.attr['_cloned'].b = True
3150    node_def.name = graph.unique_name(node_def.name)
3151    return node_def
3152
3153  def _make_op(self, inputs):
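    # Recreate the wrapped NodeDef as a new op in the graph of the (symbolic)
    # inputs, re-inserting any captured constants at their recorded input
    # indices, and record the gradient so that backprop works for this
    # manually created op.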
3154    inputs = nest.flatten(inputs)
3155    graph = inputs[0].graph
3156    node_def = self._make_node_def(graph)
3157    with graph.as_default():
3158      for index, constant in self.constants.items():
3159        # Recreate constant in graph to add distribution context.
3160        value = tensor_util.constant_value(constant)
3161        if value is not None:
3162          constant = constant_op.constant(value, name=node_def.input[index])
3163        inputs.insert(index, constant)
3164      # TODO(b/183990973): We should drop or consolidate these private api calls
3165      # for adding an op to the graph and recording its gradient.
3166      c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[])
3167      op = graph._create_op_from_tf_operation(c_op)
3168      op._control_flow_post_processing()
3169
3170      # Record the gradient because custom-made ops don't go through the
3171      # code-gen'd eager call path
3172      op_type = compat.as_str(op.op_def.name)
3173      attr_names = [compat.as_str(attr.name) for attr in op.op_def.attr]
3174      attrs = []
3175      for attr_name in attr_names:
3176        attrs.append(attr_name)
3177        attrs.append(op.get_attr(attr_name))
3178      attrs = tuple(attrs)
3179      backprop.record_gradient(op_type, op.inputs, attrs, op.outputs)
3180
3181      if len(op.outputs) == 1:
3182        return op.outputs[0]
3183      return op.outputs
3184
3185  @def_function.function
3186  def _defun_call(self, inputs):
3187    """Wraps the op creation method in an Eager function for `run_eagerly`."""
3188    return self._make_op(inputs)
3189
3190  def get_config(self):
3191    config = super(TensorFlowOpLayer, self).get_config()
3192    config.update({
3193        # `__init__` prefixes the name. Revert to the constructor argument.
3194        'name': config['name'][len(_TF_OP_LAYER_NAME_PREFIX):],
3195        'node_def': json_format.MessageToDict(self.node_def),
3196        'constants': {
3197            i: backend.get_value(c) for i, c in self.constants.items()
3198        }
3199    })
3200    return config
3201
3202
3203class AddLoss(Layer):
3204  """Adds its inputs as a loss.
3205
3206  Attributes:
3207    unconditional: Whether the loss is unconditional (not conditioned on inputs).
3208  """
3209
3210  def __init__(self, unconditional, **kwargs):
3211    # Pass autocast=False, as there is no reason to cast loss to a different
3212    # dtype.
3213    kwargs['autocast'] = False
3214    super(AddLoss, self).__init__(**kwargs)
3215    self.unconditional = unconditional
3216
3217  def call(self, inputs):
3218    self.add_loss(inputs, inputs=(not self.unconditional))
3219    return inputs
3220
3221  def get_config(self):
3222    config = super(AddLoss, self).get_config()
3223    config.update({'unconditional': self.unconditional})
3224    return config
3225
3226
3227class AddMetric(Layer):
3228  """Adds its inputs as a metric.
3229
3230  Attributes:
3231    aggregation: 'mean' or None. How the inputs should be aggregated.
3232    metric_name: The name to use for this metric.
3233  """
3234
3235  def __init__(self, aggregation=None, metric_name=None, **kwargs):
3236    super(AddMetric, self).__init__(**kwargs)
3237    self.aggregation = aggregation
3238    self.metric_name = metric_name
3239
3240  def call(self, inputs):
3241    self.add_metric(inputs, aggregation=self.aggregation, name=self.metric_name)
3242    return inputs
3243
3244  def get_config(self):
3245    config = super(AddMetric, self).get_config()
3246    config.update({
3247        'aggregation': self.aggregation,
3248        'metric_name': self.metric_name
3249    })
3250    return config
3251
3252
3253def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list):  # pylint: disable=unused-argument
3254  """Check the arguments to see if we are constructing a functional model."""
3255  # We are constructing a functional model if any of the inputs
3256  # are KerasTensors
3257  return any(
3258      isinstance(tensor, keras_tensor.KerasTensor)
3259      for tensor in nest.flatten([inputs, args, kwargs]))
3260
3261
3262def _convert_numpy_or_python_types(x):
3263  if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
3264    return ops.convert_to_tensor_v2_with_dispatch(x)
3265  return x
3266
3267
3268# Avoid breaking users who directly import this symbol from this file.
3269# TODO(fchollet): remove this.
3270InputSpec = input_spec.InputSpec  # pylint:disable=invalid-name
3271