# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Contains the base Layer class, from which all layers inherit."""

import collections
import copy
import functools
import itertools
import threading
import warnings
import weakref

import numpy as np

from google.protobuf import json_format
from tensorflow.core.framework import node_def_pb2
from tensorflow.python import tf2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.mixed_precision import autocast_variable
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
# Modules that only depend on `keras.layers` import these from here.
from tensorflow.python.keras.utils.generic_utils import to_snake_case  # pylint: disable=unused-import
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list  # pylint: disable=unused-import

from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls

# pylint: disable=g-inconsistent-quotes
metrics_mod = generic_utils.LazyLoader(
    "metrics_mod", globals(),
    "tensorflow.python.keras.metrics")
# pylint: enable=g-inconsistent-quotes

# Prefix that is added to the TF op layer names.
_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_'

# TODO(mdan): Should we have a single generic type for types that can be passed
# to tf.cast?
_AUTOCAST_TYPES = (ops.Tensor, sparse_tensor.SparseTensor,
                   ragged_tensor.RaggedTensor)


@keras_export('keras.layers.Layer')
class Layer(module.Module, version_utils.LayerVersionSelector):
  """This is the class from which all layers inherit.

  A layer is a callable object that takes as input one or more tensors and
  that outputs one or more tensors. It involves *computation*, defined
  in the `call()` method, and a *state* (weight variables), defined
  either in the constructor `__init__()` or in the `build()` method.

  Users will just instantiate a layer and then treat it as a callable.

  Args:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: The dtype of the layer's computations and weights. Can also be a
      `tf.keras.mixed_precision.Policy`, which allows the computation and
      weight dtype to differ. Default of `None` means to use
      `tf.keras.mixed_precision.global_policy()`, which is a float32 policy
      unless set to a different value.
    dynamic: Set this to `True` if your layer should only be run eagerly, and
      should not be used to generate a static computation graph.
      This would be the case for a Tree-RNN or a recursive network,
      for example, or generally for any layer that manipulates tensors
      using Python control flow. If `False`, we assume that the layer can
      safely be used to generate a static computation graph.

  Attributes:
    name: The name of the layer (string).
    dtype: The dtype of the layer's weights.
    variable_dtype: Alias of `dtype`.
    compute_dtype: The dtype of the layer's computations. Layers automatically
      cast inputs to this dtype, which causes the computations and output to
      also be in this dtype. When mixed precision is used with a
      `tf.keras.mixed_precision.Policy`, this will be different from
      `variable_dtype`.
    dtype_policy: The layer's dtype policy. See the
      `tf.keras.mixed_precision.Policy` documentation for details.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).
    trainable: Whether the layer should be trained (boolean), i.e. whether
      its potentially-trainable weights should be returned as part of
      `layer.trainable_weights`.
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.

  We recommend that descendants of `Layer` implement the following methods:

  * `__init__()`: Defines custom layer attributes, and creates layer state
    variables that do not depend on input shapes, using `add_weight()`.
  * `build(self, input_shape)`: This method can be used to create weights that
    depend on the shape(s) of the input(s), using `add_weight()`. `__call__()`
    will automatically build the layer (if it has not been built yet) by
    calling `build()`.
  * `call(self, inputs, *args, **kwargs)`: Called in `__call__` after making
    sure `build()` has been called. `call()` performs the logic of applying the
    layer to the input tensors (which should be passed in as argument).
    Two reserved keyword arguments you can optionally use in `call()` are:
    - `training` (boolean, whether the call is in inference mode or training
      mode). See more details in [the layer/model subclassing guide](
      https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_training_argument_in_the_call_method)
    - `mask` (boolean tensor encoding masked timesteps in the input, used
      in RNN layers). See more details in [the layer/model subclassing guide](
      https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_mask_argument_in_the_call_method)
    A typical signature for this method is `call(self, inputs)`, and users can
    optionally add `training` and `mask` if the layer needs them. `*args` and
    `**kwargs` are only useful for future extension, when more input parameters
    are planned to be added.
  * `get_config(self)`: Returns a dictionary containing the configuration used
    to initialize this layer. If the keys differ from the arguments
    in `__init__`, then override `from_config(self)` as well.
    This method is used when saving
    the layer or a model that contains this layer.

  Examples:

  Here's a basic example: a layer with two variables, `w` and `b`,
  that returns `y = w . x + b`.
  It shows how to implement `build()` and `call()`.
  Variables set as attributes of a layer are tracked as weights
  of the layer (in `layer.weights`).

  ```python
  class SimpleDense(Layer):

    def __init__(self, units=32):
      super(SimpleDense, self).__init__()
      self.units = units

    def build(self, input_shape):  # Create the state of the layer (weights)
      w_init = tf.random_normal_initializer()
      self.w = tf.Variable(
          initial_value=w_init(shape=(input_shape[-1], self.units),
                               dtype='float32'),
          trainable=True)
      b_init = tf.zeros_initializer()
      self.b = tf.Variable(
          initial_value=b_init(shape=(self.units,), dtype='float32'),
          trainable=True)

    def call(self, inputs):  # Defines the computation from inputs to outputs
      return tf.matmul(inputs, self.w) + self.b

  # Instantiates the layer.
  linear_layer = SimpleDense(4)

  # This will also call `build(input_shape)` and create the weights.
  y = linear_layer(tf.ones((2, 2)))
  assert len(linear_layer.weights) == 2

  # These weights are trainable, so they're listed in `trainable_weights`:
  assert len(linear_layer.trainable_weights) == 2
  ```

  Note that the method `add_weight()` offers a shortcut to create weights:

  ```python
  class SimpleDense(Layer):

    def __init__(self, units=32):
      super(SimpleDense, self).__init__()
      self.units = units

    def build(self, input_shape):
      self.w = self.add_weight(shape=(input_shape[-1], self.units),
                               initializer='random_normal',
                               trainable=True)
      self.b = self.add_weight(shape=(self.units,),
                               initializer='random_normal',
                               trainable=True)

    def call(self, inputs):
      return tf.matmul(inputs, self.w) + self.b
  ```

  Besides trainable weights, updated via backpropagation during training,
  layers can also have non-trainable weights. These weights are meant to
  be updated manually during `call()`. Here's an example layer that computes
  the running sum of its inputs:

  ```python
  class ComputeSum(Layer):

    def __init__(self, input_dim):
      super(ComputeSum, self).__init__()
      # Create a non-trainable weight.
      self.total = tf.Variable(initial_value=tf.zeros((input_dim,)),
                               trainable=False)

    def call(self, inputs):
      self.total.assign_add(tf.reduce_sum(inputs, axis=0))
      return self.total

  my_sum = ComputeSum(2)
  x = tf.ones((2, 2))

  y = my_sum(x)
  print(y.numpy())  # [2. 2.]

  y = my_sum(x)
  print(y.numpy())  # [4. 4.]

  assert my_sum.weights == [my_sum.total]
  assert my_sum.non_trainable_weights == [my_sum.total]
  assert my_sum.trainable_weights == []
  ```

  For more information about creating layers, see the guide
  [Making new Layers and Models via subclassing](
    https://www.tensorflow.org/guide/keras/custom_layers_and_models)
  """

  # See tf.Module for the usage of this property.
  # The key for _obj_reference_counts_dict is a Trackable, which could be a
  # variable or layer etc. tf.Module._flatten will fail to flatten the key
  # since it is trying to convert Trackable to a string. This attribute can be
  # ignored even after the fix of nest lib, since the trackable object should
  # already be available as individual attributes. _obj_reference_counts_dict
  # just contains a copy of them.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_obj_reference_counts_dict',),
      module.Module._TF_MODULE_IGNORED_PROPERTIES
  ))

  # When loading from a SavedModel, Layers typically can be revived into a
  # generic Layer wrapper. Sometimes, however, layers may implement methods
  # that go beyond this wrapper, as in the case of PreprocessingLayers'
  # `adapt` method. When this is the case, layer implementers can override
  # must_restore_from_config to return True; layers with this property must
  # be restored into their actual objects (and will fail if the object is
  # not available to the restoration code).
  _must_restore_from_config = False

  def _get_cell_name(self):
    canonical_name = get_canonical_name_for_symbol(
        self.__class__, api_name='keras', add_prefix_to_v1_names=True)
    if canonical_name is not None:
      return 'tf.{}'.format(canonical_name)
    return self.__class__.__module__ + '.' + self.__class__.__name__

  def _instrument_layer_creation(self):
    self._instrumented_keras_api = False
    self._instrumented_keras_layer_class = False
    self._instrumented_keras_model_class = False
    if not getattr(self, '_disable_keras_instrumentation', False):
      self._instrumented_keras_api = True
      if getattr(self, '_is_model_for_instrumentation', False):
        self._instrumented_keras_model_class = True
      else:
        self._instrumented_keras_layer_class = True

  @trackable.no_automatic_dependency_tracking
  def __init__(self,
               trainable=True,
               name=None,
               dtype=None,
               dynamic=False,
               **kwargs):
    self._instrument_layer_creation()

    # These properties should be set by the user via keyword arguments.
    # Note that 'dtype', 'input_shape' and 'batch_input_shape'
    # are only applicable to input layers: do not pass these keywords
    # to non-input layers.
    allowed_kwargs = {
        'input_dim',
        'input_shape',
        'batch_input_shape',
        'batch_size',
        'weights',
        'activity_regularizer',
        'autocast',
        'implementation',
    }
    # Validate optional keyword arguments.
    generic_utils.validate_kwargs(kwargs, allowed_kwargs)

    # Mutable properties
    # Indicates whether the layer's weights are updated during training
    # and whether the layer's updates are run during training.
    self._trainable = trainable
    # A stateful layer is a layer whose updates are run during inference too,
    # for instance stateful RNNs.
    self._stateful = False
    # Indicates whether `build` needs to be called upon layer call, to create
    # the layer's weights.
    self.built = False
    # Provides information about which inputs are compatible with the layer.
    self._input_spec = None

    # SavedModel-related attributes.
    # Record the build input shape for loading purposes.
    # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is
    # submitted.
    self._build_input_shape = None
    self._saved_model_inputs_spec = None

    # `Layer.compute_mask` will be called at the end of `Layer.__call__` if
    # `Layer.compute_mask` is overridden, or if the `Layer` subclass sets
    # `self.supports_masking=True`.
    self._supports_masking = not generic_utils.is_default(self.compute_mask)

    self._init_set_name(name)
    self._activity_regularizer = regularizers.get(
        kwargs.pop('activity_regularizer', None))
    self._maybe_create_attribute('_trainable_weights', [])
    self._maybe_create_attribute('_non_trainable_weights', [])
    self._updates = []
    # Object to store all thread local layer properties.
    self._thread_local = threading.local()
    # A list of zero-argument lambdas which return Tensors, used for variable
    # regularizers.
    self._callable_losses = []
    # A list of symbolic Tensors containing activity regularizers and losses
    # manually added through `add_loss` in graph-building mode.
    self._losses = []
    # A list of metric instances corresponding to the symbolic metric tensors
    # added using the `add_metric` API.
    self._metrics = []
    # Ensures the same metric is not added multiple times in `MirroredStrategy`.
    self._metrics_lock = threading.Lock()

    # Both graph and subclassed networks have a dtype policy. For graph
    # networks, the policy's compute and variable dtypes are ignored. Such
    # networks only use the policy if it is a PolicyV1, in which case it uses
    # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
    # subclassed networks, the compute and variable dtypes are used as they
    # would be for any ordinary layer.
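    # Added illustration (not original source commentary; a hedged sketch of
    # the mixed precision behavior described above): under the global policy
    # 'mixed_float16', a layer created afterwards computes in float16 while
    # keeping float32 variables, e.g.
    #   tf.keras.mixed_precision.set_global_policy('mixed_float16')
    #   layer = tf.keras.layers.Dense(8)
    #   layer.compute_dtype   # 'float16'
    #   layer.variable_dtype  # 'float32'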
    self._set_dtype_policy(dtype)
    # Boolean indicating whether the layer automatically casts its inputs to
    # the layer's compute_dtype.
    self._autocast = kwargs.get('autocast',
                                base_layer_utils.v2_dtype_behavior_enabled())

    # Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s.
    # Ordered by when the object was assigned as an attr.
    # Entries are unique.
    self._maybe_create_attribute('_self_tracked_trackables', [])

    # These lists will be filled via successive calls
    # to self._add_inbound_node().
    # Used in symbolic mode only, only in conjunction with graph-networks.
    self._inbound_nodes_value = []
    self._outbound_nodes_value = []

    self._init_call_fn_args()

    # Whether the `call` method can be used to build a TF graph without issues.
    # This attribute has no effect if the model is created using the Functional
    # API. Instead, `model.dynamic` is determined based on the internal layers.
    self._dynamic = dynamic

    # Manage input shape information if passed.
    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
      kwargs['input_shape'] = (kwargs['input_dim'],)
    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
      # In this case we will later create an input layer
      # to insert before the current layer.
      if 'batch_input_shape' in kwargs:
        batch_input_shape = tuple(kwargs['batch_input_shape'])
      elif 'input_shape' in kwargs:
        if 'batch_size' in kwargs:
          batch_size = kwargs['batch_size']
        else:
          batch_size = None
        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
      self._batch_input_shape = batch_input_shape

    # Manage initial weight values if passed.
    self._initial_weights = kwargs.get('weights', None)

    # Whether the layer will track any layers that are set as attributes on
    # itself as sub-layers; the weights from the sub-layers will be included
    # in the parent layer's variables() as well.
    # Defaults to True, which means auto tracking is turned on. Certain
    # subclasses, such as the Sequential model, might want to turn it off.
    self._auto_track_sub_layers = True

    # For backwards compat reasons, most built-in layers do not guarantee
    # that they will 100% preserve the structure of input args when saving
    # / loading configs. E.g. they may un-nest an arg that is
    # a list with one element.
    self._preserve_input_structure_in_config = False

  @trackable.no_automatic_dependency_tracking
  @generic_utils.default
  def build(self, input_shape):
    """Creates the variables of the layer (optional, for subclass implementers).

    This is a method that implementers of subclasses of `Layer` or `Model`
    can override if they need a state-creation step in-between
    layer instantiation and layer call.

    This is typically used to create the weights of `Layer` subclasses.

    Args:
      input_shape: Instance of `TensorShape`, or list of instances of
        `TensorShape` if the layer expects a list of inputs
        (one instance per input).
    """
    # Only record the build input shapes of overridden build methods.
    if not hasattr(self.build, '_is_default'):
      self._build_input_shape = input_shape
    self.built = True

  @doc_controls.for_subclass_implementers
  def call(self, inputs, *args, **kwargs):  # pylint: disable=unused-argument
    """This is where the layer's logic lives.

    Note that the `call()` method in `tf.keras` is a little bit different from
    the `keras` API. In the `keras` API, you could pass masking support to
    layers via additional arguments, whereas `tf.keras` has a `compute_mask()`
    method to support masking.

    Args:
      inputs: Input tensor, or dict/list/tuple of input tensors.
        The first positional `inputs` argument is subject to special rules:
        - `inputs` must be explicitly passed. A layer cannot have zero
          arguments, and `inputs` cannot be provided via the default value
          of a keyword argument.
        - NumPy array or Python scalar values in `inputs` get cast as tensors.
        - Keras mask metadata is only collected from `inputs`.
        - Layers are built (`build(input_shape)` method)
          using shape info from `inputs` only.
        - `input_spec` compatibility is only checked against `inputs`.
        - Mixed precision input casting is only applied to `inputs`.
          If a layer has tensor arguments in `*args` or `**kwargs`, their
          casting behavior in mixed precision should be handled manually.
        - The SavedModel input specification is generated using `inputs` only.
        - Integration with various ecosystem packages like TFMOT, TFLite,
          TF.js, etc. is only supported for `inputs` and not for tensors in
          positional and keyword arguments.
      *args: Additional positional arguments. May contain tensors, although
        this is not recommended, for the reasons above.
      **kwargs: Additional keyword arguments. May contain tensors, although
        this is not recommended, for the reasons above.
        The following optional keyword arguments are reserved:
        - `training`: Boolean scalar tensor or Python boolean indicating
          whether the `call` is meant for training or inference.
        - `mask`: Boolean input mask. If the layer's `call()` method takes a
          `mask` argument, its default value will be set to the mask generated
          for `inputs` by the previous layer (if `inputs` did come from a layer
          that generated a corresponding mask, i.e. if it came from a Keras
          layer with masking support).

    Returns:
      A tensor or list/tuple of tensors.
    """
    return inputs

  @doc_controls.for_subclass_implementers
  def _add_trackable(self, trackable_object, trainable):
    """Adds a Trackable object to this layer's state.

    Args:
      trackable_object: The tf.tracking.Trackable object to add.
      trainable: Boolean, whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases) or
        "non_trainable_variables" (e.g. BatchNorm mean and variance).

    Returns:
      The TrackableWeightHandler used to track this object.
    """
    if isinstance(trackable_object, base_layer_utils.TrackableWeightHandler):
      handler = trackable_object
    else:
      handler = base_layer_utils.TrackableWeightHandler(trackable_object)
    if trainable:
      self._trainable_weights.append(handler)
    else:
      self._non_trainable_weights.append(handler)
    return handler

  @doc_controls.for_subclass_implementers
  def add_weight(self,
                 name=None,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 use_resource=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE,
                 **kwargs):
    """Adds a new variable to the layer.

    Args:
      name: Variable name.
      shape: Variable shape. Defaults to scalar if unspecified.
      dtype: The type of the variable. Defaults to `self.dtype`.
      initializer: Initializer instance (callable).
      regularizer: Regularizer instance (callable).
      trainable: Boolean, whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean and variance).
        Note that `trainable` cannot be `True` if `synchronization`
        is set to `ON_READ`.
      constraint: Constraint instance (callable).
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      **kwargs: Additional keyword arguments. Accepted values are `getter`,
        `collections`, `experimental_autocast` and `caching_device`.

    Returns:
      The variable created.

    Raises:
      ValueError: When giving an unsupported dtype and no initializer, or when
        trainable has been set to True with synchronization set as `ON_READ`.
    """
    if shape is None:
      shape = ()
    kwargs.pop('partitioner', None)  # Ignored.
    # Validate optional keyword arguments.
    for kwarg in kwargs:
      if kwarg not in ['collections', 'experimental_autocast',
                       'caching_device', 'getter']:
        raise TypeError('Unknown keyword argument:', kwarg)
    collections_arg = kwargs.pop('collections', None)
    # 'experimental_autocast' can be set to False by the caller to indicate an
    # AutoCastVariable should never be created.
    autocast = kwargs.pop('experimental_autocast', True)
    # See the docstring for tf.Variable about the details for caching_device.
    caching_device = kwargs.pop('caching_device', None)

    if dtype is None:
      dtype = self.dtype or backend.floatx()
    dtype = dtypes.as_dtype(dtype)
    if self._dtype_policy.variable_dtype is None:
      # The policy is "_infer", so we infer the policy from the variable dtype.
      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
    initializer = initializers.get(initializer)
    regularizer = regularizers.get(regularizer)
    constraint = constraints.get(constraint)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    # Initialize variable when no initializer provided.
    if initializer is None:
      # If dtype is DT_FLOAT, provide a uniform unit scaling initializer.
      if dtype.is_floating:
        initializer = initializers.get('glorot_uniform')
      # If dtype is DT_INT/DT_UINT, provide a default value `zero`.
      # If dtype is DT_BOOL, provide a default value `FALSE`.
      elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
        initializer = initializers.get('zeros')
      # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
      elif 'getter' not in kwargs:
        # When `getter` is specified, it's possibly fine for `initializer` to be
        # None since it's up to the custom `getter` to raise an error in case it
        # indeed needs `initializer`.
        raise ValueError('An initializer for variable %s of type %s is required'
                         ' for layer %s' % (name, dtype.base_dtype, self.name))

    getter = kwargs.pop('getter', base_layer_utils.make_variable)
    if (autocast and
        self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
        and dtype.is_floating):
      old_getter = getter
      # Wrap variable constructor to return an AutoCastVariable.
      def getter(*args, **kwargs):  # pylint: disable=function-redefined
        variable = old_getter(*args, **kwargs)
        return autocast_variable.create_autocast_variable(variable)
      # Also, the caching_device does not work with the mixed precision API;
      # disable it if it is specified.
      # TODO(b/142020079): Reenable it once the bug is fixed.
      if caching_device is not None:
        tf_logging.warning(
            '`caching_device` does not work with mixed precision API. Ignoring '
            'user specified `caching_device`.')
        caching_device = None

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        # TODO(allenl): a `make_variable` equivalent should be added as a
        # `Trackable` method.
        getter=getter,
        # Manage errors in Layer rather than Trackable.
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        constraint=constraint,
        trainable=trainable,
        use_resource=use_resource,
        collections=collections_arg,
        synchronization=synchronization,
        aggregation=aggregation,
        caching_device=caching_device)
    if regularizer is not None:
      # TODO(fchollet): in the future, this should be handled at the
      # level of variable creation, and weight regularization losses
      # should be variable attributes.
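      # Added illustration (not original source commentary): a weight created
      # with a regularizer, e.g.
      #   self.kernel = self.add_weight(
      #       'kernel', shape=(4, 4),
      #       regularizer=tf.keras.regularizers.l2(1e-4))
      # contributes its penalty as a zero-argument callable that is evaluated
      # whenever `layer.losses` is accessed.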
      name_in_scope = variable.name[:variable.name.find(':')]
      self._handle_weight_regularization(name_in_scope,
                                         variable,
                                         regularizer)
    if base_layer_utils.is_split_variable(variable):
      for v in variable:
        backend.track_variable(v)
        if trainable:
          self._trainable_weights.append(v)
        else:
          self._non_trainable_weights.append(v)
    else:
      backend.track_variable(variable)
      if trainable:
        self._trainable_weights.append(variable)
      else:
        self._non_trainable_weights.append(variable)
    return variable

  @generic_utils.default
  def get_config(self):
    """Returns the config of the layer.

    A layer config is a Python dictionary (serializable)
    containing the configuration of a layer.
    The same layer can be reinstantiated later
    (without its trained weights) from this configuration.

    The config of a layer does not include connectivity
    information, nor the layer class name. These are handled
    by `Network` (one layer of abstraction above).

    Note that `get_config()` is not guaranteed to return a fresh copy of the
    dict every time it is called. Callers should make a copy of the returned
    dict if they want to modify it.

    Returns:
      Python dictionary.
    """
    all_args = tf_inspect.getfullargspec(self.__init__).args
    config = {
        'name': self.name,
        'trainable': self.trainable,
    }
    if hasattr(self, '_batch_input_shape'):
      config['batch_input_shape'] = self._batch_input_shape
    config['dtype'] = policy.serialize(self._dtype_policy)
    if hasattr(self, 'dynamic'):
      # Only include `dynamic` in the `config` if it is `True`.
      if self.dynamic:
        config['dynamic'] = self.dynamic
      elif 'dynamic' in all_args:
        all_args.remove('dynamic')
    expected_args = config.keys()
    # Finds all arguments in the `__init__` that are not in the config:
    extra_args = [arg for arg in all_args if arg not in expected_args]
    # Check that either the only argument in the `__init__` is `self`,
    # or that `get_config` has been overridden:
    if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
      raise NotImplementedError('Layer %s has arguments in `__init__` and '
                                'therefore must override `get_config`.' %
                                self.__class__.__name__)
    return config

  @classmethod
  def from_config(cls, config):
    """Creates a layer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same layer from the config
    dictionary. It does not handle layer connectivity
    (handled by Network), nor weights (handled by `set_weights`).

    Args:
      config: A Python dictionary, typically the
        output of `get_config`.

    Returns:
      A layer instance.
    """
    return cls(**config)

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer.

    If the layer has not been built, this method will call `build` on the
    layer. This assumes that the layer will later be used with inputs that
    match the input shape provided here.

    Args:
      input_shape: Shape tuple (tuple of integers)
        or list of shape tuples (one per output tensor of the layer).
        Shape tuples can include None for free dimensions,
        instead of an integer.

    Returns:
      An output shape tuple.
    """
    if context.executing_eagerly():
      # In this case we build the model first in order to do shape inference.
      # This is acceptable because the framework only calls
      # `compute_output_shape` on shape values that the layer would later be
      # built for. It would however cause issues in case a user attempts to
      # use `compute_output_shape` manually with shapes that are incompatible
      # with the shape the Layer will be called on (these users will have to
      # implement `compute_output_shape` themselves).
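      # Added illustration (not original source commentary): the eager path
      # lets shape inference work out of the box for most layers, e.g.
      #   layer = tf.keras.layers.Dense(3)
      #   layer.compute_output_shape((None, 4))  # -> TensorShape([None, 3])
      # (Dense actually overrides this method; the example only illustrates
      # the contract.)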
      self._maybe_build(input_shape)
      with func_graph.FuncGraph(str(self.name) + '_scratch_graph').as_default():
        input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
        def _make_placeholder_like(shape):
          ph = backend.placeholder(shape=shape, dtype=self.dtype)
          ph._keras_mask = None
          return ph
        inputs = nest.map_structure(_make_placeholder_like, input_shape)
        try:
          outputs = self(inputs, training=False)
        except TypeError as e:
          raise NotImplementedError(
              'We could not automatically infer the static shape of the '
              'layer\'s output. Please implement the '
              '`compute_output_shape` method on your layer (%s).' %
              self.__class__.__name__) from e
      return nest.map_structure(lambda t: t.shape, outputs)
    raise NotImplementedError(
        'Please run in eager mode or implement the `compute_output_shape` '
        'method on your layer (%s).' % self.__class__.__name__)

  @doc_controls.for_subclass_implementers
  def compute_output_signature(self, input_signature):
    """Compute the output tensor signature of the layer based on the inputs.

    Unlike a TensorShape object, a TensorSpec object contains both shape
    and dtype information for a tensor. This method allows layers to provide
    output dtype information if it is different from the input dtype.
    For any layer that doesn't implement this function,
    the framework will fall back to using `compute_output_shape`, and will
    assume that the output dtype matches the input dtype.

    Args:
      input_signature: Single TensorSpec or nested structure of TensorSpec
        objects, describing a candidate input for the layer.

    Returns:
      Single TensorSpec or nested structure of TensorSpec objects, describing
        how the layer would transform the provided input.

    Raises:
      TypeError: If input_signature contains a non-TensorSpec object.
    """
    def check_type_return_shape(s):
      if not isinstance(s, tensor_spec.TensorSpec):
        raise TypeError('Only TensorSpec signature types are supported, '
                        'but saw signature entry: {}.'.format(s))
      return s.shape
    input_shape = nest.map_structure(check_type_return_shape, input_signature)
    output_shape = self.compute_output_shape(input_shape)
    dtype = self._compute_dtype
    if dtype is None:
      input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
      # Default behavior when self.dtype is None, is to use the first input's
      # dtype.
      dtype = input_dtypes[0]
    return nest.map_structure(
        lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
        output_shape)

  def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs):
    if self.dynamic:
      # We will use static shape inference to return symbolic tensors
      # matching the specifications of the layer outputs.
      # Since `self.dynamic` is True, we will never attempt to
      # run the underlying TF graph (which is disconnected).
      # TODO(fchollet): consider py_func as an alternative, which
      # would enable us to run the underlying graph if needed.
      input_signature = nest.map_structure(
          lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype),
          inputs)
      output_signature = self.compute_output_signature(input_signature)
      return nest.map_structure(keras_tensor.KerasTensor, output_signature)
    else:
      return self._infer_output_signature(inputs, args, kwargs, input_masks)

  def _infer_output_signature(self, inputs, args, kwargs, input_masks):
    """TODO(kaftan): Docstring."""

    call_fn = self.call
    # Wrapping `call` function in autograph to allow for dynamic control
    # flow and control dependencies in call. We are limiting this to
    # subclassed layers as autograph is strictly needed only for
    # subclassed layers and models.
    # tf_convert will respect the value of autograph setting in the
    # enclosing tf.function, if any.
    if (base_layer_utils.is_subclassed(self) and
        not base_layer_utils.from_saved_model(self)):
      call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())

    # We enter a scratch graph and build placeholder inputs inside of it that
    # match the input args.
    # We then call the layer inside of the scratch graph to identify the
    # output signatures, then we build KerasTensors corresponding to those
    # outputs.
    scratch_graph = func_graph.FuncGraph(str(self.name) + '_scratch_graph')
    with scratch_graph.as_default():
      inputs = nest.map_structure(
          keras_tensor.keras_tensor_to_placeholder, inputs)
      args = nest.map_structure(
          keras_tensor.keras_tensor_to_placeholder, args)
      kwargs = nest.map_structure(
          keras_tensor.keras_tensor_to_placeholder, kwargs)
      input_masks = nest.map_structure(
          keras_tensor.keras_tensor_to_placeholder, input_masks)

      with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
        with autocast_variable.enable_auto_cast_variables(
            self._compute_dtype_object):
          # Build layer if applicable (if the `build` method has been
          # overridden).
          # TODO(kaftan): do we maybe_build here, or have we already done it?
          self._maybe_build(inputs)
          inputs = self._maybe_cast_inputs(inputs)
          outputs = call_fn(inputs, *args, **kwargs)

        self._handle_activity_regularization(inputs, outputs)
      self._set_mask_metadata(inputs, outputs, input_masks,
                              build_graph=False)
      outputs = nest.map_structure(
          keras_tensor.keras_tensor_from_tensor, outputs)

    if hasattr(self, '_set_inputs') and not self.inputs:
      # TODO(kaftan): figure out if we need to do this at all
      # Subclassed network: explicitly set metadata normally set by
      # a call to self._set_inputs().
      self._set_inputs(inputs, outputs)
    del scratch_graph
    return outputs

  @generic_utils.default
  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
    """Computes an output mask tensor.

    Args:
      inputs: Tensor or list of tensors.
      mask: Tensor or list of tensors.

    Returns:
      None or a tensor (or list of tensors,
        one per output tensor of the layer).
    """
    if not self._supports_masking:
      if any(m is not None for m in nest.flatten(mask)):
        raise TypeError('Layer ' + self.name + ' does not support masking, '
                        'but was passed an input_mask: ' + str(mask))
      # masking not explicitly supported: return None as mask.
      return None
    # if masking is explicitly supported, by default
    # carry over the input mask
    return mask

  def __call__(self, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      *args: Positional arguments to be passed to `self.call`.
      **kwargs: Keyword arguments to be passed to `self.call`.

    Returns:
      Output tensor(s).

    Note:
      - The following optional keyword arguments are reserved for specific uses:
        * `training`: Boolean scalar tensor or Python boolean indicating
          whether the `call` is meant for training or inference.
        * `mask`: Boolean input mask.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).
      - If the layer is not built, the method will call `build`.

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
      RuntimeError: if `super().__init__()` was not called in the constructor.
    """
    if not hasattr(self, '_thread_local'):
      raise RuntimeError(
          'You must call `super().__init__()` in the layer constructor.')

    # `inputs` (the first arg in the method spec) is special-cased in
    # layer call due to historical reasons.
    # This special casing currently takes the form of:
    # - `inputs` must be explicitly passed. A layer cannot have zero arguments,
    #   and inputs cannot have been provided via the default value of a kwarg.
    # - numpy/scalar values in `inputs` get converted to tensors
    # - implicit masks / mask metadata are only collected from `inputs`
    # - Layers are built using shape info from `inputs` only
    # - input_spec compatibility is only checked against `inputs`
    # - mixed precision casting (autocast) is only applied to `inputs`,
    #   not to any other argument.
    # - setting the SavedModel saving spec.
    inputs, args, kwargs = self._split_out_first_arg(args, kwargs)
    input_list = nest.flatten(inputs)

    # Functional Model construction mode is invoked when `Layer`s are called on
    # symbolic `KerasTensor`s, i.e.:
    # >> inputs = tf.keras.Input(10)
    # >> outputs = MyLayer()(inputs)  # Functional construction mode.
    # >> model = tf.keras.Model(inputs, outputs)
    if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
      return self._functional_construction_call(inputs, args, kwargs,
                                                input_list)

    # Maintains info about the `Layer.call` stack.
    call_context = base_layer_utils.call_context()

    # Accept NumPy and scalar inputs by converting to Tensors.
    if any(isinstance(x, (
        np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
      inputs = nest.map_structure(_convert_numpy_or_python_types, inputs)
      input_list = nest.flatten(inputs)

    # Handle `mask` propagation from previous layer to current layer. Masks can
    # be propagated explicitly via the `mask` argument, or implicitly via
    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
    # explicitly take priority.
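    # Added illustration (not original source commentary; `token_ids` is a
    # placeholder input name): an implicit mask travels on the tensors
    # themselves via `_keras_mask`, e.g.
    #   x = tf.keras.layers.Embedding(100, 8, mask_zero=True)(token_ids)
    #   y = tf.keras.layers.LSTM(16)(x)  # picks up x._keras_mask automatically
    # whereas a mask passed explicitly via the `mask` argument overrides
    # whatever is attached to the inputs.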
    input_masks, mask_is_implicit = self._get_input_masks(
        inputs, input_list, args, kwargs)
    if self._expects_mask_arg and mask_is_implicit:
      kwargs['mask'] = input_masks

    # Training mode for `Layer.call` is set via (in order of priority):
    # (1) The `training` argument passed to this `Layer.call`, if it is not None
    # (2) The training mode of an outer `Layer.call`.
    # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set)
    # (4) Any non-None default value for `training` specified in the call
    #     signature
    # (5) False (treating the layer as if it's in inference)
    args, kwargs, training_mode = self._set_training_mode(
        args, kwargs, call_context)

    # Losses are cleared for all sublayers on the outermost `Layer.call`.
    # Losses are not cleared on inner `Layer.call`s, because sublayers can be
    # called multiple times.
    if not call_context.in_call:
      self._clear_losses()

    eager = context.executing_eagerly()
    with call_context.enter(
        layer=self,
        inputs=inputs,
        build_graph=not eager,
        training=training_mode):

      input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
      if eager:
        call_fn = self.call
        name_scope = self._name
      else:
        name_scope = self._name_scope()  # Avoid autoincrementing.  # pylint: disable=not-callable
        call_fn = self._autographed_call()

      with ops.name_scope_v2(name_scope):
        if not self.built:
          self._maybe_build(inputs)

        if self._autocast:
          inputs = self._maybe_cast_inputs(inputs, input_list)

        with autocast_variable.enable_auto_cast_variables(
            self._compute_dtype_object):
          outputs = call_fn(inputs, *args, **kwargs)

        if self._activity_regularizer:
          self._handle_activity_regularization(inputs, outputs)
        if self._supports_masking:
          self._set_mask_metadata(inputs, outputs, input_masks, not eager)
        if self._saved_model_inputs_spec is None:
          self._set_save_spec(inputs)

        return outputs

  def _functional_construction_call(self, inputs, args, kwargs, input_list):
    call_context = base_layer_utils.call_context()

    # Accept NumPy and scalar inputs by converting to Tensors.
    if any(isinstance(x, (
        np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):

      def _convert_non_tensor(x):
        # Don't call `ops.convert_to_tensor` on all `inputs` because
        # `SparseTensors` can't be converted to `Tensor`.
        if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
          return ops.convert_to_tensor_v2_with_dispatch(x)
        return x

      inputs = nest.map_structure(_convert_non_tensor, inputs)
      input_list = nest.flatten(inputs)

    # Handle `mask` propagation from previous layer to current layer. Masks can
    # be propagated explicitly via the `mask` argument, or implicitly via
    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
    # explicitly take priority.
    mask_arg_passed_by_framework = False
    input_masks, mask_is_implicit = self._get_input_masks(
        inputs, input_list, args, kwargs)
    if self._expects_mask_arg and mask_is_implicit:
      kwargs['mask'] = input_masks
      mask_arg_passed_by_framework = True

    # If `training` argument is None or not explicitly passed,
    # propagate `training` value from this layer's calling layer.
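    # Added illustration (not original source commentary): calling
    #   model(x, training=True)
    # records the training mode in the call context, so nested layers such as
    # Dropout or BatchNormalization that receive no explicit `training`
    # argument inherit True from the outer call.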
    training_value = None
    training_arg_passed_by_framework = False
    # Priority 1: `training` was explicitly passed a non-None value.
    if self._call_arg_was_passed('training', args, kwargs):
      training_value = self._get_call_arg_value('training', args, kwargs)
      if not self._expects_training_arg:
        kwargs.pop('training')

    if training_value is None:
      # Priority 2: `training` was passed to a parent layer.
      if call_context.training is not None:
        training_value = call_context.training
      # Priority 3: `learning_phase()` has been set.
      elif backend.global_learning_phase_is_set():
        training_value = backend.learning_phase()
        # Force training_value to be a bool type, which matches the contract
        # for layer/model call args.
        if tensor_util.is_tf_type(training_value):
          training_value = math_ops.cast(training_value, dtypes.bool)
        else:
          training_value = bool(training_value)
      # Priority 4: trace layer with the default training argument specified
      # in the `call` signature (or in inference mode if the `call` signature
      # specifies no non-None default).
      else:
        training_value = self._default_training_arg
      # In cases (2), (3), (4) the training argument is passed automatically
      # by the framework, and will not be hard-coded into the model.
      if self._expects_training_arg:
        args, kwargs = self._set_call_arg_value('training', training_value,
                                                args, kwargs)
        training_arg_passed_by_framework = True

    with call_context.enter(
        layer=self, inputs=inputs, build_graph=True, training=training_value):
      # Check input assumptions set after layer building, e.g. input shape.
      outputs = self._keras_tensor_symbolic_call(
          inputs, input_masks, args, kwargs)

      if outputs is None:
        raise ValueError('A layer\'s `call` method should return a '
                         'Tensor or a list of Tensors, not None '
                         '(layer: ' + self.name + ').')
      if training_arg_passed_by_framework:
        args, kwargs = self._set_call_arg_value(
            'training', None, args, kwargs, pop_kwarg_if_none=True)
      if mask_arg_passed_by_framework:
        kwargs.pop('mask')
      # Node connectivity does not special-case the first argument.
      outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
                                                outputs)
      return outputs

  def _set_training_mode(self, args, kwargs, call_context):
    training_mode = None
    if self._expects_training_arg:
      # (1) `training` was passed to this `Layer.call`.
      if self._call_arg_was_passed('training', args, kwargs):
        training_mode = self._get_call_arg_value('training', args, kwargs)
      # If no `training` arg was passed, or `None` was explicitly passed,
      # the framework will decide the training mode.
      if training_mode is None:
        call_ctx_training = call_context.training
        # (2) `training` mode is inferred from an outer `Layer.call`.
        if call_ctx_training is not None:
          training_mode = call_ctx_training
        # (3) User set `tf.keras.backend.set_learning_phase`.
        elif backend.global_learning_phase_is_set():
          training_mode = backend.learning_phase()
          # Ensure value is a `bool` or `tf.bool`.
          if isinstance(training_mode, bool):
            pass
          elif tensor_util.is_tf_type(training_mode):
            training_mode = math_ops.cast(training_mode, dtypes.bool)
          else:
            training_mode = bool(training_mode)
        # (4) We default to using `call`'s default value for `training`,
        # or treating the layer as if it is in inference if no non-None default
        # is specified in the `call` signature.
        else:
          training_mode = self._default_training_arg

        # For case (2), (3), (4) `training` arg is passed by framework.
        args, kwargs = self._set_call_arg_value('training', training_mode, args,
                                                kwargs)
    else:
      if 'training' in kwargs:
        # `training` was passed to this `Layer` but is not needed for
        # `Layer.call`. It will set the default mode for inner `Layer.call`s.
        training_mode = kwargs.pop('training')
      else:
        # Grab the current `training` mode from any outer `Layer.call`.
        training_mode = call_context.training

    return args, kwargs, training_mode

  def _autographed_call(self):
    # Wrapping `call` function in autograph to allow for dynamic control
    # flow and control dependencies in call. We are limiting this to
    # subclassed layers as autograph is strictly needed only for
    # subclassed layers and models.
    # tf_convert will respect the value of autograph setting in the
    # enclosing tf.function, if any.
    if (base_layer_utils.is_subclassed(self) and
        not base_layer_utils.from_saved_model(self)):
      return autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
    else:
      return self.call

  @property
  def dtype(self):
    """The dtype of the layer weights.

    This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless
    mixed precision is used, this is the same as `Layer.compute_dtype`, the
    dtype of the layer's computations.
    """
    return self._dtype_policy.variable_dtype

  @property
  def name(self):
    """Name of the layer (string), set in the constructor."""
    return self._name

  @property
  def supports_masking(self):
    """Whether this layer supports computing a mask using `compute_mask`."""
    return self._supports_masking

  @supports_masking.setter
  def supports_masking(self, value):
    self._supports_masking = value

  @property
  def dynamic(self):
    """Whether the layer is dynamic (eager-only); set in the constructor."""
    return any(layer._dynamic for layer in self._flatten_layers())

  @property
  @doc_controls.do_not_doc_inheritable
  def stateful(self):
    return any(layer._stateful for layer in self._flatten_layers())

  @stateful.setter
  def stateful(self, value):
    self._stateful = value

  @property
  def trainable(self):
    return self._trainable

  @trainable.setter
  def trainable(self, value):
    for layer in self._flatten_layers():
      layer._trainable = value

  @property
  def activity_regularizer(self):
    """Optional regularizer function for the output of this layer."""
    return self._activity_regularizer

  @activity_regularizer.setter
  def activity_regularizer(self, regularizer):
    """Optional regularizer function for the output of this layer."""
    self._activity_regularizer = regularizer

  @property
  def input_spec(self):
    """`InputSpec` instance(s) describing the input format for this layer.

    When you create a layer subclass, you can set `self.input_spec` to enable
    the layer to run input compatibility checks when it is called.
    Consider a `Conv2D` layer: it can only be called on a single input tensor
    of rank 4. As such, you can set, in `__init__()`:

    ```python
    self.input_spec = tf.keras.layers.InputSpec(ndim=4)
    ```

    Now, if you try to call the layer on an input that isn't rank 4
    (for instance, an input of shape `(2,)`), it will raise a nicely-formatted
    error:

    ```
    ValueError: Input 0 of layer conv2d is incompatible with the layer:
    expected ndim=4, found ndim=1. Full shape received: [2]
    ```

    Input checks that can be specified via `input_spec` include:
    - Structure (e.g. a single input, a list of 2 inputs, etc)
    - Shape
    - Rank (ndim)
    - Dtype

    For more information, see `tf.keras.layers.InputSpec`.

    Returns:
      A `tf.keras.layers.InputSpec` instance, or nested structure thereof.
    """
    return self._input_spec

  @input_spec.setter
  # Must be decorated to prevent tracking, since the input_spec can be nested
  # InputSpec objects.
  @trackable.no_automatic_dependency_tracking
  def input_spec(self, value):
    for v in nest.flatten(value):
      if v is not None and not isinstance(v, InputSpec):
        raise TypeError('Layer input_spec must be an instance of InputSpec. '
                        'Got: {}'.format(v))
    self._input_spec = value

  @property
  def trainable_weights(self):
    """List of all trainable weights tracked by this layer.

    Trainable weights are updated via gradient descent during training.

    Returns:
      A list of trainable variables.
    """
    if self.trainable:
      children_weights = self._gather_children_attribute('trainable_variables')
      return self._dedup_weights(self._trainable_weights + children_weights)
    else:
      return []

  @property
  def non_trainable_weights(self):
    """List of all non-trainable weights tracked by this layer.

    Non-trainable weights are *not* updated during training. They are expected
    to be updated manually in `call()`.

    Returns:
      A list of non-trainable variables.
    """
    if self.trainable:
      children_weights = self._gather_children_attribute(
          'non_trainable_variables')
      non_trainable_weights = self._non_trainable_weights + children_weights
    else:
      children_weights = self._gather_children_attribute('variables')
      non_trainable_weights = (
          self._trainable_weights + self._non_trainable_weights +
          children_weights)
    return self._dedup_weights(non_trainable_weights)

  @property
  def weights(self):
    """Returns the list of all layer variables/weights.

    Returns:
      A list of variables.
    """
    return self.trainable_weights + self.non_trainable_weights

  @property
  @doc_controls.do_not_generate_docs
  def updates(self):
    warnings.warn('`layer.updates` will be removed in a future version. '
                  'This property should not be used in TensorFlow 2.0, '
                  'as `updates` are applied automatically.')
    return []

  @property
  def losses(self):
    """List of losses added using the `add_loss()` API.

    Variable regularization tensors are created when this property is accessed,
    so it is eager safe: accessing `losses` under a `tf.GradientTape` will
    propagate gradients back to the corresponding variables.
1357 1358 Examples: 1359 1360 >>> class MyLayer(tf.keras.layers.Layer): 1361 ... def call(self, inputs): 1362 ... self.add_loss(tf.abs(tf.reduce_mean(inputs))) 1363 ... return inputs 1364 >>> l = MyLayer() 1365 >>> l(np.ones((10, 1))) 1366 >>> l.losses 1367 [1.0] 1368 1369 >>> inputs = tf.keras.Input(shape=(10,)) 1370 >>> x = tf.keras.layers.Dense(10)(inputs) 1371 >>> outputs = tf.keras.layers.Dense(1)(x) 1372 >>> model = tf.keras.Model(inputs, outputs) 1373 >>> # Activity regularization. 1374 >>> len(model.losses) 1375 0 1376 >>> model.add_loss(tf.abs(tf.reduce_mean(x))) 1377 >>> len(model.losses) 1378 1 1379 1380 >>> inputs = tf.keras.Input(shape=(10,)) 1381 >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones') 1382 >>> x = d(inputs) 1383 >>> outputs = tf.keras.layers.Dense(1)(x) 1384 >>> model = tf.keras.Model(inputs, outputs) 1385 >>> # Weight regularization. 1386 >>> model.add_loss(lambda: tf.reduce_mean(d.kernel)) 1387 >>> model.losses 1388 [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>] 1389 1390 Returns: 1391 A list of tensors. 1392 """ 1393 collected_losses = [] 1394 for layer in self._flatten_layers(): 1395 # If any eager losses are present, we assume the model to be part of an 1396 # eager training loop (either a custom one or the one used when 1397 # `run_eagerly=True`) and so we always return just the eager losses. 1398 if layer._eager_losses: 1399 # Filter placeholder losses that may have been added by revived layers. 1400 # (see base_layer_utils for details). 1401 if (layer._eager_losses[0] is 1402 not base_layer_utils.REVIVED_LOSS_PLACEHOLDER): 1403 collected_losses.extend(layer._eager_losses) 1404 else: 1405 collected_losses.extend(layer._losses) 1406 for regularizer in layer._callable_losses: 1407 loss_tensor = regularizer() 1408 if loss_tensor is not None: 1409 collected_losses.append(loss_tensor) 1410 return collected_losses 1411 1412 def add_loss(self, losses, **kwargs): 1413 """Add loss tensor(s), potentially dependent on layer inputs. 1414 1415 Some losses (for instance, activity regularization losses) may be dependent 1416 on the inputs passed when calling a layer. Hence, when reusing the same 1417 layer on different inputs `a` and `b`, some entries in `layer.losses` may 1418 be dependent on `a` and some on `b`. This method automatically keeps track 1419 of dependencies. 1420 1421 This method can be used inside a subclassed layer or model's `call` 1422 function, in which case `losses` should be a Tensor or list of Tensors. 1423 1424 Example: 1425 1426 ```python 1427 class MyLayer(tf.keras.layers.Layer): 1428 def call(self, inputs): 1429 self.add_loss(tf.abs(tf.reduce_mean(inputs))) 1430 return inputs 1431 ``` 1432 1433 This method can also be called directly on a Functional Model during 1434 construction. In this case, any loss Tensors passed to this Model must 1435 be symbolic and be able to be traced back to the model's `Input`s. These 1436 losses become part of the model's topology and are tracked in `get_config`. 1437 1438 Example: 1439 1440 ```python 1441 inputs = tf.keras.Input(shape=(10,)) 1442 x = tf.keras.layers.Dense(10)(inputs) 1443 outputs = tf.keras.layers.Dense(1)(x) 1444 model = tf.keras.Model(inputs, outputs) 1445 # Activity regularization. 1446 model.add_loss(tf.abs(tf.reduce_mean(x))) 1447 ``` 1448 1449 If this is not the case for your loss (if, for example, your loss references 1450 a `Variable` of one of the model's layers), you can wrap your loss in a 1451 zero-argument lambda. 
These losses are not tracked as part of the model's 1452 topology since they can't be serialized. 1453 1454 Example: 1455 1456 ```python 1457 inputs = tf.keras.Input(shape=(10,)) 1458 d = tf.keras.layers.Dense(10) 1459 x = d(inputs) 1460 outputs = tf.keras.layers.Dense(1)(x) 1461 model = tf.keras.Model(inputs, outputs) 1462 # Weight regularization. 1463 model.add_loss(lambda: tf.reduce_mean(d.kernel)) 1464 ``` 1465 1466 Args: 1467 losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses 1468 may also be zero-argument callables which create a loss tensor. 1469 **kwargs: Additional keyword arguments for backward compatibility. 1470 Accepted values: 1471 inputs - Deprecated, will be automatically inferred. 1472 """ 1473 kwargs.pop('inputs', None) 1474 if kwargs: 1475 raise TypeError('Unknown keyword arguments: %s' % (kwargs.keys(),)) 1476 1477 def _tag_callable(loss): 1478 """Tags callable loss tensor as `_unconditional_loss`.""" 1479 if callable(loss): 1480 # We run the loss without autocasting, as regularizers are often 1481 # numerically unstable in float16. 1482 with autocast_variable.enable_auto_cast_variables(None): 1483 loss = loss() 1484 if loss is None: 1485 return None # Will be filtered out when computing the .losses property 1486 if not tensor_util.is_tf_type(loss): 1487 loss = ops.convert_to_tensor_v2_with_dispatch( 1488 loss, dtype=backend.floatx()) 1489 loss._unconditional_loss = True # pylint: disable=protected-access 1490 return loss 1491 1492 losses = nest.flatten(losses) 1493 1494 callable_losses = [] 1495 eager_losses = [] 1496 symbolic_losses = [] 1497 for loss in losses: 1498 if callable(loss): 1499 callable_losses.append(functools.partial(_tag_callable, loss)) 1500 continue 1501 if loss is None: 1502 continue 1503 if not tensor_util.is_tf_type(loss) and not isinstance( 1504 loss, keras_tensor.KerasTensor): 1505 loss = ops.convert_to_tensor_v2_with_dispatch( 1506 loss, dtype=backend.floatx()) 1507 # TF Functions should take the eager path. 1508 if ((tf_utils.is_symbolic_tensor(loss) or 1509 isinstance(loss, keras_tensor.KerasTensor)) and 1510 not base_layer_utils.is_in_tf_function()): 1511 symbolic_losses.append(loss) 1512 elif tensor_util.is_tf_type(loss): 1513 eager_losses.append(loss) 1514 1515 self._callable_losses.extend(callable_losses) 1516 1517 in_call_context = base_layer_utils.call_context().in_call 1518 if eager_losses and not in_call_context: 1519 raise ValueError( 1520 'Expected a symbolic Tensor or a callable for the loss value. ' 1521 'Please wrap your loss computation in a zero argument `lambda`.') 1522 1523 self._eager_losses.extend(eager_losses) 1524 1525 for symbolic_loss in symbolic_losses: 1526 if getattr(self, '_is_graph_network', False): 1527 self._graph_network_add_loss(symbolic_loss) 1528 else: 1529 # Possibly a loss was added in a Layer's `build`. 1530 self._losses.append(symbolic_loss) 1531 1532 def _clear_losses(self): 1533 """Used every step in eager to reset losses.""" 1534 # Set to thread local directly to avoid Layer.__setattr__ overhead. 1535 if not getattr(self, '_self_tracked_trackables', 1536 None): # Fast path for single Layer. 1537 self._thread_local._eager_losses = [] 1538 else: 1539 for layer in self._flatten_layers(): 1540 layer._thread_local._eager_losses = [] 1541 1542 @property 1543 def metrics(self): 1544 """List of metrics added using the `add_metric()` API.
1545 1546 Example: 1547 1548 >>> input = tf.keras.layers.Input(shape=(3,)) 1549 >>> d = tf.keras.layers.Dense(2) 1550 >>> output = d(input) 1551 >>> d.add_metric(tf.reduce_max(output), name='max') 1552 >>> d.add_metric(tf.reduce_min(output), name='min') 1553 >>> [m.name for m in d.metrics] 1554 ['max', 'min'] 1555 1556 Returns: 1557 A list of `Metric` objects. 1558 """ 1559 collected_metrics = [] 1560 for layer in self._flatten_layers(): 1561 with layer._metrics_lock: 1562 collected_metrics.extend(layer._metrics) 1563 return collected_metrics 1564 1565 def add_metric(self, value, name=None, **kwargs): 1566 """Adds metric tensor to the layer. 1567 1568 This method can be used inside the `call()` method of a subclassed layer 1569 or model. 1570 1571 ```python 1572 class MyMetricLayer(tf.keras.layers.Layer): 1573 def __init__(self): 1574 super(MyMetricLayer, self).__init__(name='my_metric_layer') 1575 self.mean = tf.keras.metrics.Mean(name='metric_1') 1576 1577 def call(self, inputs): 1578 self.add_metric(self.mean(inputs)) 1579 self.add_metric(tf.reduce_sum(inputs), name='metric_2') 1580 return inputs 1581 ``` 1582 1583 This method can also be called directly on a Functional Model during 1584 construction. In this case, any tensor passed to this Model must 1585 be symbolic and be able to be traced back to the model's `Input`s. These 1586 metrics become part of the model's topology and are tracked when you 1587 save the model via `save()`. 1588 1589 ```python 1590 inputs = tf.keras.Input(shape=(10,)) 1591 x = tf.keras.layers.Dense(10)(inputs) 1592 outputs = tf.keras.layers.Dense(1)(x) 1593 model = tf.keras.Model(inputs, outputs) 1594 model.add_metric(math_ops.reduce_sum(x), name='metric_1') 1595 ``` 1596 1597 Note: Calling `add_metric()` with the result of a metric object on a 1598 Functional Model, as shown in the example below, is not supported. This is 1599 because we cannot trace the metric result tensor back to the model's inputs. 1600 1601 ```python 1602 inputs = tf.keras.Input(shape=(10,)) 1603 x = tf.keras.layers.Dense(10)(inputs) 1604 outputs = tf.keras.layers.Dense(1)(x) 1605 model = tf.keras.Model(inputs, outputs) 1606 model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1') 1607 ``` 1608 1609 Args: 1610 value: Metric tensor. 1611 name: String metric name. 1612 **kwargs: Additional keyword arguments for backward compatibility. 1613 Accepted values: 1614 `aggregation` - When the `value` tensor provided is not the result of 1615 calling a `keras.Metric` instance, it will be aggregated by default 1616 using a `keras.Metric.Mean`. 1617 """ 1618 kwargs_keys = list(kwargs.keys()) 1619 if (len(kwargs_keys) > 1 or 1620 (len(kwargs_keys) == 1 and kwargs_keys[0] != 'aggregation')): 1621 raise TypeError('Unknown keyword arguments: ', str(kwargs.keys())) 1622 1623 from_metric_obj = hasattr(value, '_metric_obj') 1624 is_symbolic = isinstance(value, keras_tensor.KerasTensor) 1625 in_call_context = base_layer_utils.call_context().in_call 1626 1627 if name is None and not from_metric_obj: 1628 # Eg. `self.add_metric(math_ops.reduce_sum(x))` 1629 # In eager mode, we use metric name to lookup a metric. Without a name, 1630 # a new Mean metric wrapper will be created on every model/layer call. 1631 # So, we raise an error when no name is provided. 1632 # We will do the same for symbolic mode for consistency although a name 1633 # will be generated if no name is provided. 
1634 1635 # We will not raise this error in the following use case for the sake of 1636 # consistency, as the name is provided in the metric constructor. 1637 # mean = metrics.Mean(name='my_metric') 1638 # model.add_metric(mean(outputs)) 1639 raise ValueError('Please provide a name for your metric like ' 1640 '`self.add_metric(tf.reduce_sum(inputs), ' 1641 'name=\'mean_activation\')`') 1642 elif from_metric_obj: 1643 name = value._metric_obj.name 1644 1645 if not in_call_context and not is_symbolic: 1646 raise ValueError('Expected a symbolic Tensor for the metric value, ' 1647 'received: ' + str(value)) 1648 1649 # If a metric was added in a Layer's `call` or `build`. 1650 if in_call_context or not getattr(self, '_is_graph_network', False): 1651 # TF Function path should take the eager path. 1652 1653 # If the given metric is available in the `metrics` list we just update 1654 # state on it, otherwise we create a new metric instance and 1655 # add it to the `metrics` list. 1656 metric_obj = getattr(value, '_metric_obj', None) 1657 # Tensors that come from a Metric object already updated the Metric state. 1658 should_update_state = not metric_obj 1659 name = metric_obj.name if metric_obj else name 1660 1661 with self._metrics_lock: 1662 match = self._get_existing_metric(name) 1663 if match: 1664 metric_obj = match 1665 elif metric_obj: 1666 self._metrics.append(metric_obj) 1667 else: 1668 # Build the metric object with the value's dtype if it defines one 1669 metric_obj = metrics_mod.Mean( 1670 name=name, dtype=getattr(value, 'dtype', None)) 1671 self._metrics.append(metric_obj) 1672 1673 if should_update_state: 1674 metric_obj(value) 1675 else: 1676 if from_metric_obj: 1677 raise ValueError('Using the result of calling a `Metric` object ' 1678 'when calling `add_metric` on a Functional ' 1679 'Model is not supported. Please pass the ' 1680 'Tensor to monitor directly.') 1681 1682 # Insert layers into the Keras Graph Network. 1683 aggregation = None if from_metric_obj else 'mean' 1684 self._graph_network_add_metric(value, aggregation, name) 1685 1686 @doc_controls.do_not_doc_inheritable 1687 def add_update(self, updates, inputs=None): 1688 """Add update op(s), potentially dependent on layer inputs. 1689 1690 Weight updates (for instance, the updates of the moving mean and variance 1691 in a BatchNormalization layer) may be dependent on the inputs passed 1692 when calling a layer. Hence, when reusing the same layer on 1693 different inputs `a` and `b`, some entries in `layer.updates` may be 1694 dependent on `a` and some on `b`. This method automatically keeps track 1695 of dependencies. 1696 1697 This call is ignored when eager execution is enabled (in that case, variable 1698 updates are run on the fly and thus do not need to be tracked for later 1699 execution). 1700 1701 Args: 1702 updates: Update op, or list/tuple of update ops, or zero-arg callable 1703 that returns an update op. A zero-arg callable should be passed in 1704 order to disable running the updates by setting `trainable=False` 1705 on this Layer, when executing in Eager mode. 1706 inputs: Deprecated, will be automatically inferred. 1707 """ 1708 if inputs is not None: 1709 tf_logging.warning( 1710 '`add_update` `inputs` kwarg has been deprecated. You no longer need ' 1711 'to pass a value to `inputs` as it is being automatically inferred.') 1712 call_context = base_layer_utils.call_context() 1713 # No need to run updates during Functional API construction.
1714 if call_context.in_keras_graph: 1715 return 1716 1717 # Callable updates are disabled by setting `trainable=False`. 1718 if not call_context.frozen: 1719 for update in nest.flatten(updates): 1720 if callable(update): 1721 update() # pylint: disable=not-callable 1722 1723 def set_weights(self, weights): 1724 """Sets the weights of the layer, from NumPy arrays. 1725 1726 The weights of a layer represent the state of the layer. This function 1727 sets the weight values from numpy arrays. The weight values should be 1728 passed in the order they are created by the layer. Note that the layer's 1729 weights must be instantiated before calling this function, by calling 1730 the layer. 1731 1732 For example, a `Dense` layer returns a list of two values: the kernel matrix 1733 and the bias vector. These can be used to set the weights of another 1734 `Dense` layer: 1735 1736 >>> layer_a = tf.keras.layers.Dense(1, 1737 ... kernel_initializer=tf.constant_initializer(1.)) 1738 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]])) 1739 >>> layer_a.get_weights() 1740 [array([[1.], 1741 [1.], 1742 [1.]], dtype=float32), array([0.], dtype=float32)] 1743 >>> layer_b = tf.keras.layers.Dense(1, 1744 ... kernel_initializer=tf.constant_initializer(2.)) 1745 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]])) 1746 >>> layer_b.get_weights() 1747 [array([[2.], 1748 [2.], 1749 [2.]], dtype=float32), array([0.], dtype=float32)] 1750 >>> layer_b.set_weights(layer_a.get_weights()) 1751 >>> layer_b.get_weights() 1752 [array([[1.], 1753 [1.], 1754 [1.]], dtype=float32), array([0.], dtype=float32)] 1755 1756 Args: 1757 weights: a list of NumPy arrays. The number 1758 of arrays and their shape must match 1759 number of the dimensions of the weights 1760 of the layer (i.e. it should match the 1761 output of `get_weights`). 1762 1763 Raises: 1764 ValueError: If the provided weights list does not match the 1765 layer's specifications. 1766 """ 1767 params = self.weights 1768 1769 expected_num_weights = 0 1770 for param in params: 1771 if isinstance(param, base_layer_utils.TrackableWeightHandler): 1772 expected_num_weights += param.num_tensors 1773 else: 1774 expected_num_weights += 1 1775 1776 if expected_num_weights != len(weights): 1777 raise ValueError( 1778 'You called `set_weights(weights)` on layer "%s" ' 1779 'with a weight list of length %s, but the layer was ' 1780 'expecting %s weights. Provided weights: %s...' % 1781 (self.name, len(weights), expected_num_weights, str(weights)[:50])) 1782 1783 weight_index = 0 1784 weight_value_tuples = [] 1785 for param in params: 1786 if isinstance(param, base_layer_utils.TrackableWeightHandler): 1787 num_tensors = param.num_tensors 1788 tensors = weights[weight_index:weight_index + num_tensors] 1789 param.set_weights(tensors) 1790 weight_index += num_tensors 1791 else: 1792 weight = weights[weight_index] 1793 weight_shape = weight.shape if hasattr(weight, 'shape') else () 1794 ref_shape = param.shape 1795 if not ref_shape.is_compatible_with(weight_shape): 1796 raise ValueError( 1797 'Layer weight shape %s not compatible with provided weight ' 1798 'shape %s' % (ref_shape, weight_shape)) 1799 weight_value_tuples.append((param, weight)) 1800 weight_index += 1 1801 1802 backend.batch_set_value(weight_value_tuples) 1803 1804 # Perform any layer defined finalization of the layer state. 1805 for layer in self._flatten_layers(): 1806 layer.finalize_state() 1807 1808 def get_weights(self): 1809 """Returns the current weights of the layer, as NumPy arrays. 
1810 1811 The weights of a layer represent the state of the layer. This function 1812 returns both trainable and non-trainable weight values associated with this 1813 layer as a list of NumPy arrays, which can in turn be used to load state 1814 into similarly parameterized layers. 1815 1816 For example, a `Dense` layer returns a list of two values: the kernel matrix 1817 and the bias vector. These can be used to set the weights of another 1818 `Dense` layer: 1819 1820 >>> layer_a = tf.keras.layers.Dense(1, 1821 ... kernel_initializer=tf.constant_initializer(1.)) 1822 >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]])) 1823 >>> layer_a.get_weights() 1824 [array([[1.], 1825 [1.], 1826 [1.]], dtype=float32), array([0.], dtype=float32)] 1827 >>> layer_b = tf.keras.layers.Dense(1, 1828 ... kernel_initializer=tf.constant_initializer(2.)) 1829 >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]])) 1830 >>> layer_b.get_weights() 1831 [array([[2.], 1832 [2.], 1833 [2.]], dtype=float32), array([0.], dtype=float32)] 1834 >>> layer_b.set_weights(layer_a.get_weights()) 1835 >>> layer_b.get_weights() 1836 [array([[1.], 1837 [1.], 1838 [1.]], dtype=float32), array([0.], dtype=float32)] 1839 1840 Returns: 1841 Weights values as a list of NumPy arrays. 1842 """ 1843 weights = self.weights 1844 output_weights = [] 1845 for weight in weights: 1846 if isinstance(weight, base_layer_utils.TrackableWeightHandler): 1847 output_weights.extend(weight.get_tensors()) 1848 else: 1849 output_weights.append(weight) 1850 return backend.batch_get_value(output_weights) 1851 1852 @doc_controls.do_not_generate_docs 1853 def finalize_state(self): 1854 """Finalizes the layers state after updating layer weights. 1855 1856 This function can be subclassed in a layer and will be called after updating 1857 a layer weights. It can be overridden to finalize any additional layer state 1858 after a weight update. 1859 """ 1860 pass 1861 1862 @doc_controls.do_not_generate_docs 1863 def get_updates_for(self, inputs): 1864 """Deprecated, do NOT use! 1865 1866 Retrieves updates relevant to a specific set of inputs. 1867 1868 Args: 1869 inputs: Input tensor or list/tuple of input tensors. 1870 1871 Returns: 1872 List of update ops of the layer that depend on `inputs`. 1873 """ 1874 warnings.warn('`layer.get_updates_for` is deprecated and ' 1875 'will be removed in a future version. ' 1876 'Please use `layer.updates` method instead.') 1877 return self.updates 1878 1879 @doc_controls.do_not_generate_docs 1880 def get_losses_for(self, inputs): 1881 """Deprecated, do NOT use! 1882 1883 Retrieves losses relevant to a specific set of inputs. 1884 1885 Args: 1886 inputs: Input tensor or list/tuple of input tensors. 1887 1888 Returns: 1889 List of loss tensors of the layer that depend on `inputs`. 1890 """ 1891 warnings.warn('`layer.get_losses_for` is deprecated and ' 1892 'will be removed in a future version. ' 1893 'Please use `layer.losses` instead.') 1894 return self.losses 1895 1896 @doc_controls.do_not_doc_inheritable 1897 def get_input_mask_at(self, node_index): 1898 """Retrieves the input mask tensor(s) of a layer at a given node. 1899 1900 Args: 1901 node_index: Integer, index of the node 1902 from which to retrieve the attribute. 1903 E.g. `node_index=0` will correspond to the 1904 first time the layer was called. 1905 1906 Returns: 1907 A mask tensor 1908 (or list of tensors if the layer has multiple inputs). 
1909 """ 1910 inputs = self.get_input_at(node_index) 1911 if isinstance(inputs, list): 1912 return [getattr(x, '_keras_mask', None) for x in inputs] 1913 else: 1914 return getattr(inputs, '_keras_mask', None) 1915 1916 @doc_controls.do_not_doc_inheritable 1917 def get_output_mask_at(self, node_index): 1918 """Retrieves the output mask tensor(s) of a layer at a given node. 1919 1920 Args: 1921 node_index: Integer, index of the node 1922 from which to retrieve the attribute. 1923 E.g. `node_index=0` will correspond to the 1924 first time the layer was called. 1925 1926 Returns: 1927 A mask tensor 1928 (or list of tensors if the layer has multiple outputs). 1929 """ 1930 output = self.get_output_at(node_index) 1931 if isinstance(output, list): 1932 return [getattr(x, '_keras_mask', None) for x in output] 1933 else: 1934 return getattr(output, '_keras_mask', None) 1935 1936 @property 1937 @doc_controls.do_not_doc_inheritable 1938 def input_mask(self): 1939 """Retrieves the input mask tensor(s) of a layer. 1940 1941 Only applicable if the layer has exactly one inbound node, 1942 i.e. if it is connected to one incoming layer. 1943 1944 Returns: 1945 Input mask tensor (potentially None) or list of input 1946 mask tensors. 1947 1948 Raises: 1949 AttributeError: if the layer is connected to 1950 more than one incoming layers. 1951 """ 1952 inputs = self.input 1953 if isinstance(inputs, list): 1954 return [getattr(x, '_keras_mask', None) for x in inputs] 1955 else: 1956 return getattr(inputs, '_keras_mask', None) 1957 1958 @property 1959 @doc_controls.do_not_doc_inheritable 1960 def output_mask(self): 1961 """Retrieves the output mask tensor(s) of a layer. 1962 1963 Only applicable if the layer has exactly one inbound node, 1964 i.e. if it is connected to one incoming layer. 1965 1966 Returns: 1967 Output mask tensor (potentially None) or list of output 1968 mask tensors. 1969 1970 Raises: 1971 AttributeError: if the layer is connected to 1972 more than one incoming layers. 1973 """ 1974 output = self.output 1975 if isinstance(output, list): 1976 return [getattr(x, '_keras_mask', None) for x in output] 1977 else: 1978 return getattr(output, '_keras_mask', None) 1979 1980 @doc_controls.do_not_doc_inheritable 1981 def get_input_shape_at(self, node_index): 1982 """Retrieves the input shape(s) of a layer at a given node. 1983 1984 Args: 1985 node_index: Integer, index of the node 1986 from which to retrieve the attribute. 1987 E.g. `node_index=0` will correspond to the 1988 first time the layer was called. 1989 1990 Returns: 1991 A shape tuple 1992 (or list of shape tuples if the layer has multiple inputs). 1993 1994 Raises: 1995 RuntimeError: If called in Eager mode. 1996 """ 1997 return self._get_node_attribute_at_index(node_index, 'input_shapes', 1998 'input shape') 1999 2000 @doc_controls.do_not_doc_inheritable 2001 def get_output_shape_at(self, node_index): 2002 """Retrieves the output shape(s) of a layer at a given node. 2003 2004 Args: 2005 node_index: Integer, index of the node 2006 from which to retrieve the attribute. 2007 E.g. `node_index=0` will correspond to the 2008 first time the layer was called. 2009 2010 Returns: 2011 A shape tuple 2012 (or list of shape tuples if the layer has multiple outputs). 2013 2014 Raises: 2015 RuntimeError: If called in Eager mode. 
2016 """ 2017 return self._get_node_attribute_at_index(node_index, 'output_shapes', 2018 'output shape') 2019 2020 @doc_controls.do_not_doc_inheritable 2021 def get_input_at(self, node_index): 2022 """Retrieves the input tensor(s) of a layer at a given node. 2023 2024 Args: 2025 node_index: Integer, index of the node 2026 from which to retrieve the attribute. 2027 E.g. `node_index=0` will correspond to the 2028 first input node of the layer. 2029 2030 Returns: 2031 A tensor (or list of tensors if the layer has multiple inputs). 2032 2033 Raises: 2034 RuntimeError: If called in Eager mode. 2035 """ 2036 return self._get_node_attribute_at_index(node_index, 'input_tensors', 2037 'input') 2038 2039 @doc_controls.do_not_doc_inheritable 2040 def get_output_at(self, node_index): 2041 """Retrieves the output tensor(s) of a layer at a given node. 2042 2043 Args: 2044 node_index: Integer, index of the node 2045 from which to retrieve the attribute. 2046 E.g. `node_index=0` will correspond to the 2047 first output node of the layer. 2048 2049 Returns: 2050 A tensor (or list of tensors if the layer has multiple outputs). 2051 2052 Raises: 2053 RuntimeError: If called in Eager mode. 2054 """ 2055 return self._get_node_attribute_at_index(node_index, 'output_tensors', 2056 'output') 2057 2058 @property 2059 def input(self): 2060 """Retrieves the input tensor(s) of a layer. 2061 2062 Only applicable if the layer has exactly one input, 2063 i.e. if it is connected to one incoming layer. 2064 2065 Returns: 2066 Input tensor or list of input tensors. 2067 2068 Raises: 2069 RuntimeError: If called in Eager mode. 2070 AttributeError: If no inbound nodes are found. 2071 """ 2072 if not self._inbound_nodes: 2073 raise AttributeError('Layer ' + self.name + 2074 ' is not connected, no input to return.') 2075 return self._get_node_attribute_at_index(0, 'input_tensors', 'input') 2076 2077 @property 2078 def output(self): 2079 """Retrieves the output tensor(s) of a layer. 2080 2081 Only applicable if the layer has exactly one output, 2082 i.e. if it is connected to one incoming layer. 2083 2084 Returns: 2085 Output tensor or list of output tensors. 2086 2087 Raises: 2088 AttributeError: if the layer is connected to more than one incoming 2089 layers. 2090 RuntimeError: if called in Eager mode. 2091 """ 2092 if not self._inbound_nodes: 2093 raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') 2094 return self._get_node_attribute_at_index(0, 'output_tensors', 'output') 2095 2096 @property 2097 @doc_controls.do_not_doc_inheritable 2098 def input_shape(self): 2099 """Retrieves the input shape(s) of a layer. 2100 2101 Only applicable if the layer has exactly one input, 2102 i.e. if it is connected to one incoming layer, or if all inputs 2103 have the same shape. 2104 2105 Returns: 2106 Input shape, as an integer shape tuple 2107 (or list of shape tuples, one tuple per input tensor). 2108 2109 Raises: 2110 AttributeError: if the layer has no defined input_shape. 2111 RuntimeError: if called in Eager mode. 2112 """ 2113 if not self._inbound_nodes: 2114 raise AttributeError('The layer has never been called ' 2115 'and thus has no defined input shape.') 2116 all_input_shapes = set( 2117 [str(node.input_shapes) for node in self._inbound_nodes]) 2118 if len(all_input_shapes) == 1: 2119 return self._inbound_nodes[0].input_shapes 2120 else: 2121 raise AttributeError('The layer "' + str(self.name) + 2122 ' has multiple inbound nodes, ' 2123 'with different input shapes. 
Hence ' 2124 'the notion of "input shape" is ' 2125 'ill-defined for the layer. ' 2126 'Use `get_input_shape_at(node_index)` ' 2127 'instead.') 2128 2129 def count_params(self): 2130 """Count the total number of scalars composing the weights. 2131 2132 Returns: 2133 An integer count. 2134 2135 Raises: 2136 ValueError: if the layer isn't yet built 2137 (in which case its weights aren't yet defined). 2138 """ 2139 if not self.built: 2140 if getattr(self, '_is_graph_network', False): 2141 with tf_utils.maybe_init_scope(self): 2142 self._maybe_build(self.inputs) 2143 else: 2144 raise ValueError('You tried to call `count_params` on ' + self.name + 2145 ', but the layer isn\'t built. ' 2146 'You can build it manually via: `' + self.name + 2147 '.build(batch_input_shape)`.') 2148 return layer_utils.count_params(self.weights) 2149 2150 @property 2151 @doc_controls.do_not_doc_inheritable 2152 def output_shape(self): 2153 """Retrieves the output shape(s) of a layer. 2154 2155 Only applicable if the layer has one output, 2156 or if all outputs have the same shape. 2157 2158 Returns: 2159 Output shape, as an integer shape tuple 2160 (or list of shape tuples, one tuple per output tensor). 2161 2162 Raises: 2163 AttributeError: if the layer has no defined output shape. 2164 RuntimeError: if called in Eager mode. 2165 """ 2166 if not self._inbound_nodes: 2167 raise AttributeError('The layer has never been called ' 2168 'and thus has no defined output shape.') 2169 all_output_shapes = set( 2170 [str(node.output_shapes) for node in self._inbound_nodes]) 2171 if len(all_output_shapes) == 1: 2172 return self._inbound_nodes[0].output_shapes 2173 else: 2174 raise AttributeError('The layer "%s"' 2175 ' has multiple inbound nodes, ' 2176 'with different output shapes. Hence ' 2177 'the notion of "output shape" is ' 2178 'ill-defined for the layer. ' 2179 'Use `get_output_shape_at(node_index)` ' 2180 'instead.' % self.name) 2181 2182 @property 2183 @doc_controls.do_not_doc_inheritable 2184 def inbound_nodes(self): 2185 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 2186 return self._inbound_nodes 2187 2188 @property 2189 @doc_controls.do_not_doc_inheritable 2190 def outbound_nodes(self): 2191 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 2192 return self._outbound_nodes 2193 2194 ############################################################################## 2195 # Methods & attributes below are public aliases of other methods. # 2196 ############################################################################## 2197 2198 @doc_controls.do_not_doc_inheritable 2199 def apply(self, inputs, *args, **kwargs): 2200 """Deprecated, do NOT use! 2201 2202 This is an alias of `self.__call__`. 2203 2204 Args: 2205 inputs: Input tensor(s). 2206 *args: additional positional arguments to be passed to `self.call`. 2207 **kwargs: additional keyword arguments to be passed to `self.call`. 2208 2209 Returns: 2210 Output tensor(s). 2211 """ 2212 warnings.warn('`layer.apply` is deprecated and ' 2213 'will be removed in a future version. ' 2214 'Please use `layer.__call__` method instead.') 2215 return self.__call__(inputs, *args, **kwargs) 2216 2217 @doc_controls.do_not_doc_inheritable 2218 def add_variable(self, *args, **kwargs): 2219 """Deprecated, do NOT use! Alias for `add_weight`.""" 2220 warnings.warn('`layer.add_variable` is deprecated and ' 2221 'will be removed in a future version. 
' 2222 'Please use `layer.add_weight` method instead.') 2223 return self.add_weight(*args, **kwargs) 2224 2225 @property 2226 @doc_controls.do_not_generate_docs 2227 def variables(self): 2228 """Returns the list of all layer variables/weights. 2229 2230 Alias of `self.weights`. 2231 2232 Note: This will not track the weights of nested `tf.Modules` that are not 2233 themselves Keras layers. 2234 2235 Returns: 2236 A list of variables. 2237 """ 2238 return self.weights 2239 2240 @property 2241 @doc_controls.do_not_generate_docs 2242 def trainable_variables(self): 2243 return self.trainable_weights 2244 2245 @property 2246 @doc_controls.do_not_generate_docs 2247 def non_trainable_variables(self): 2248 return self.non_trainable_weights 2249 2250 ############################################################################## 2251 # Methods & attributes below are all private and only used by the framework. # 2252 ############################################################################## 2253 2254 @property 2255 def _inbound_nodes(self): 2256 return self._inbound_nodes_value 2257 2258 @_inbound_nodes.setter 2259 @trackable.no_automatic_dependency_tracking 2260 def _inbound_nodes(self, value): 2261 self._inbound_nodes_value = value 2262 2263 @property 2264 def _outbound_nodes(self): 2265 return self._outbound_nodes_value 2266 2267 @_outbound_nodes.setter 2268 @trackable.no_automatic_dependency_tracking 2269 def _outbound_nodes(self, value): 2270 self._outbound_nodes_value = value 2271 2272 def _set_dtype_policy(self, dtype): 2273 """Sets self._dtype_policy.""" 2274 if isinstance(dtype, policy.Policy): 2275 self._dtype_policy = dtype 2276 elif isinstance(dtype, dict): 2277 self._dtype_policy = policy.deserialize(dtype) 2278 elif isinstance(dtype, str) and dtype in ('mixed_float16', 2279 'mixed_bfloat16'): 2280 # The isinstance check is required since np.dtype raises an error if 2281 # compared to a non-dtype string. 2282 self._dtype_policy = policy.Policy(dtype) 2283 elif dtype: 2284 self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name) 2285 else: 2286 self._dtype_policy = policy.global_policy() 2287 if (self._dtype_policy.name == 'mixed_float16' and 2288 not loss_scale_optimizer.strategy_supports_loss_scaling()): 2289 # Although only loss scaling doesn't support certain strategies, to avoid 2290 # confusion, we disallow the 'mixed_float16' policy with unsupported 2291 # strategies. This is because 'mixed_float16' requires loss scaling for 2292 # numeric stability. 2293 strategy = ds_context.get_strategy() 2294 raise ValueError('Mixed precision is not supported with the ' 2295 'tf.distribute.Strategy: %s. Either stop using mixed ' 2296 'precision by removing the use of the "%s" policy or ' 2297 'use a different Strategy, e.g. a MirroredStrategy.' % 2298 (strategy.__class__.__name__, self._dtype_policy.name)) 2299 2300 # Performance optimization: cache the compute dtype as a Dtype object or 2301 # None, so that str to Dtype conversion doesn't happen in Layer.__call__. 2302 # TODO(b/157486353): Investigate returning DTypes in Policy. 2303 if self._dtype_policy.compute_dtype: 2304 self._compute_dtype_object = dtypes.as_dtype( 2305 self._dtype_policy.compute_dtype) 2306 else: 2307 self._compute_dtype_object = None 2308 2309 @property 2310 def dtype_policy(self): 2311 """The dtype policy associated with this layer. 2312 2313 This is an instance of a `tf.keras.mixed_precision.Policy`. 
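A minimal illustrative sketch (the exact `repr` may vary across
TensorFlow versions):

>>> layer = tf.keras.layers.Dense(2, dtype='mixed_float16')
>>> layer.dtype_policy  # doctest: +SKIP
<Policy "mixed_float16">
>>> layer.dtype_policy.compute_dtype  # doctest: +SKIP
'float16'
>>> layer.dtype_policy.variable_dtype  # doctest: +SKIP
'float32'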
2314 """ 2315 return self._dtype_policy 2316 2317 @property 2318 def compute_dtype(self): 2319 """The dtype of the layer's computations. 2320 2321 This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless 2322 mixed precision is used, this is the same as `Layer.dtype`, the dtype of 2323 the weights. 2324 2325 Layers automatically cast their inputs to the compute dtype, which causes 2326 computations and the output to be in the compute dtype as well. This is done 2327 by the base Layer class in `Layer.__call__`, so you do not have to insert 2328 these casts if implementing your own layer. 2329 2330 Layers often perform certain internal computations in higher precision when 2331 `compute_dtype` is float16 or bfloat16 for numeric stability. The output 2332 will still typically be float16 or bfloat16 in such cases. 2333 2334 Returns: 2335 The layer's compute dtype. 2336 """ 2337 return self._dtype_policy.compute_dtype 2338 2339 @property 2340 def _compute_dtype(self): 2341 """Deprecated alias of `compute_dtype`.""" 2342 return self._dtype_policy.compute_dtype 2343 2344 @property 2345 def variable_dtype(self): 2346 """Alias of `Layer.dtype`, the dtype of the weights.""" 2347 return self.dtype 2348 2349 def _maybe_cast_inputs(self, inputs, input_list=None): 2350 """Maybe casts the inputs to the compute dtype. 2351 2352 If self._compute_dtype is floating-point, and self_autocast is True, 2353 floating-point inputs are casted to self._compute_dtype. 2354 2355 Args: 2356 inputs: Input tensor, or structure of input tensors. 2357 input_list: Flat list of input tensors. 2358 2359 Returns: 2360 `inputs`, but tensors may have been casted to self._compute_dtype 2361 """ 2362 if not input_list: 2363 input_list = nest.flatten(inputs) 2364 2365 compute_dtype_object = self._compute_dtype_object 2366 should_autocast = ( 2367 self._autocast and compute_dtype_object and 2368 compute_dtype_object.is_floating) 2369 2370 if (should_autocast and 2371 any(map(self._should_cast_single_input, input_list))): 2372 # Only perform expensive `nest` operation when needed. 2373 return nest.map_structure(self._cast_single_input, inputs) 2374 else: 2375 return inputs 2376 2377 def _should_cast_single_input(self, x): 2378 if isinstance(x, _AUTOCAST_TYPES): 2379 return (self._compute_dtype_object and 2380 x.dtype != self._compute_dtype_object and x.dtype.is_floating) 2381 return False 2382 2383 def _cast_single_input(self, x): 2384 """Cast a single Tensor or TensorSpec to the compute dtype.""" 2385 if self._should_cast_single_input(x): 2386 return math_ops.cast(x, self._compute_dtype_object) 2387 else: 2388 return x 2389 2390 # _dtype used to be an attribute set in the constructor. We still expose it 2391 # because some clients still use it. 2392 # TODO(reedwm): Deprecate, then remove the _dtype property. 2393 @property 2394 def _dtype(self): 2395 # This is equivalent to returning self.dtype . We do not return self.dtype 2396 # as it would cause infinite recursion in a few subclasses, which override 2397 # "dtype" to return self._dtype. 
2398 return self._dtype_policy.variable_dtype 2399 2400 @_dtype.setter 2401 def _dtype(self, value): 2402 value = dtypes.as_dtype(value).name 2403 self._set_dtype_policy(policy.Policy(value)) 2404 2405 def _name_scope(self): # pylint: disable=method-hidden 2406 if not tf2.enabled(): 2407 return self.name 2408 name_scope = self.name 2409 current_name_scope = ops.get_name_scope() 2410 if current_name_scope: 2411 name_scope = current_name_scope + '/' + name_scope 2412 if name_scope: 2413 # Note that the trailing `/` prevents autogenerated 2414 # numerical suffixes to get appended. It will also fully reset 2415 # nested name scope (i.e. the outer name scope has no effect). 2416 name_scope += '/' 2417 return name_scope 2418 2419 def _init_set_name(self, name, zero_based=True): 2420 if not name: 2421 self._name = backend.unique_object_name( 2422 generic_utils.to_snake_case(self.__class__.__name__), 2423 zero_based=zero_based) 2424 else: 2425 backend.observe_object_name(name) 2426 self._name = name 2427 2428 def _get_existing_metric(self, name=None): 2429 match = [m for m in self._metrics if m.name == name] 2430 if not match: 2431 return 2432 if len(match) > 1: 2433 raise ValueError( 2434 'Please provide different names for the metrics you have added. ' 2435 'We found {} metrics with the name: "{}"'.format(len(match), name)) 2436 return match[0] 2437 2438 def _handle_weight_regularization(self, name, variable, regularizer): 2439 """Create lambdas which compute regularization losses.""" 2440 2441 def _loss_for_variable(v): 2442 """Creates a regularization loss `Tensor` for variable `v`.""" 2443 with backend.name_scope(name + '/Regularizer'): 2444 regularization = regularizer(v) 2445 return regularization 2446 2447 if base_layer_utils.is_split_variable(variable): 2448 for v in variable: 2449 self.add_loss(functools.partial(_loss_for_variable, v)) 2450 else: 2451 self.add_loss(functools.partial(_loss_for_variable, variable)) 2452 2453 def _handle_activity_regularization(self, inputs, outputs): 2454 # Apply activity regularization. 2455 # Note that it should be applied every time the layer creates a new 2456 # output, since it is output-specific. 2457 if self._activity_regularizer: 2458 output_list = nest.flatten(outputs) 2459 with backend.name_scope('ActivityRegularizer'): 2460 for output in output_list: 2461 activity_loss = self._activity_regularizer(output) 2462 batch_size = math_ops.cast( 2463 array_ops.shape(output)[0], activity_loss.dtype) 2464 # Make activity regularization strength batch-agnostic. 2465 mean_activity_loss = activity_loss / batch_size 2466 self.add_loss(mean_activity_loss) 2467 2468 def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph): 2469 # Many `Layer`s don't need to call `compute_mask`. 2470 # This method is optimized to do as little work as needed for the common 2471 # case. 
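    # A quick sketch of the flow below: exit immediately when the layer does
    # not support masking at all; skip recomputation when `call` produced
    # masks jointly with the outputs or every output already carries a
    # `_keras_mask`; only otherwise call `compute_mask` and attach the
    # resulting masks to the flattened outputs.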
2472 if not self._supports_masking: 2473 return 2474 2475 flat_outputs = nest.flatten(outputs) 2476 2477 mask_already_computed = ( 2478 getattr(self, '_compute_output_and_mask_jointly', False) or 2479 all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs)) 2480 if mask_already_computed: 2481 if build_graph: 2482 self._set_mask_keras_history_checked(flat_outputs) 2483 return 2484 2485 output_masks = self.compute_mask(inputs, previous_mask) 2486 if output_masks is None: 2487 return 2488 2489 flat_masks = nest.flatten(output_masks) 2490 for tensor, mask in zip(flat_outputs, flat_masks): 2491 try: 2492 tensor._keras_mask = mask 2493 except AttributeError: 2494 # C Type such as np.ndarray. 2495 pass 2496 2497 if build_graph: 2498 self._set_mask_keras_history_checked(flat_outputs) 2499 2500 def _set_mask_keras_history_checked(self, flat_outputs): 2501 for output in flat_outputs: 2502 if getattr(output, '_keras_mask', None) is not None: 2503 # Do not track masks for `TensorFlowOpLayer` construction. 2504 output._keras_mask._keras_history_checked = True 2505 2506 def _get_input_masks(self, inputs, input_list, args, kwargs): 2507 if not self._supports_masking and not self._expects_mask_arg: 2508 # Input masks only need to be retrieved if they are needed for `call` 2509 # or `compute_mask`. 2510 input_masks = None 2511 implicit_mask = False 2512 elif self._call_arg_was_passed('mask', args, kwargs): 2513 input_masks = self._get_call_arg_value('mask', args, kwargs) 2514 implicit_mask = False 2515 else: 2516 input_masks = [getattr(t, '_keras_mask', None) for t in input_list] 2517 if all(mask is None for mask in input_masks): 2518 input_masks = None 2519 implicit_mask = False 2520 else: 2521 # Only do expensive `nest` op when masking is actually being used. 2522 input_masks = nest.pack_sequence_as(inputs, input_masks) 2523 implicit_mask = True 2524 return input_masks, implicit_mask 2525 2526 def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): 2527 # Performance optimization: do no work in most common case. 2528 if not args and not kwargs: 2529 return False 2530 2531 if arg_name in kwargs: 2532 return True 2533 call_fn_args = self._call_fn_args 2534 if not inputs_in_args: 2535 # Ignore `inputs` arg. 2536 call_fn_args = call_fn_args[1:] 2537 return arg_name in dict(zip(call_fn_args, args)) 2538 2539 def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False): 2540 if arg_name in kwargs: 2541 return kwargs[arg_name] 2542 call_fn_args = self._call_fn_args 2543 if not inputs_in_args: 2544 # Ignore `inputs` arg. 2545 call_fn_args = call_fn_args[1:] 2546 args_dict = dict(zip(call_fn_args, args)) 2547 return args_dict[arg_name] 2548 2549 def _set_call_arg_value( 2550 self, arg_name, new_value, args, 2551 kwargs, inputs_in_args=False, pop_kwarg_if_none=False): 2552 arg_pos = self._call_fn_arg_positions.get(arg_name, None) 2553 if arg_pos is not None: 2554 if not inputs_in_args: 2555 # Ignore `inputs` arg. 2556 arg_pos = arg_pos - 1 2557 if len(args) > arg_pos: 2558 args = list(args) 2559 args[arg_pos] = new_value 2560 return tuple(args), kwargs 2561 if new_value is None and pop_kwarg_if_none: 2562 kwargs.pop(arg_name, None) 2563 else: 2564 kwargs[arg_name] = new_value 2565 return args, kwargs 2566 2567 def _set_connectivity_metadata(self, args, kwargs, outputs): 2568 # If the layer returns tensors from its inputs unmodified, 2569 # we copy them to avoid loss of KerasHistory metadata. 
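    # Illustrative case for the copy below: a pass-through layer whose `call`
    # simply does `return inputs` would otherwise have this layer's new
    # `KerasHistory` written onto the very tensors produced by the upstream
    # layer, so such outputs are wrapped in `array_ops.identity` to give this
    # layer distinct tensors to annotate.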
2570 flat_outputs = nest.flatten(outputs) 2571 flat_inputs = nest.flatten((args, kwargs)) 2572 input_ids_set = {id(i) for i in flat_inputs} 2573 outputs_copy = [] 2574 for x in flat_outputs: 2575 if id(x) in input_ids_set: 2576 with backend.name_scope(self.name): 2577 x = array_ops.identity(x) 2578 outputs_copy.append(x) 2579 outputs = nest.pack_sequence_as(outputs, outputs_copy) 2580 2581 # Create node, Node wires itself to inbound and outbound layers. 2582 # The Node constructor actually updates this layer's self._inbound_nodes, 2583 # sets _keras_history on the outputs, and adds itself to the 2584 # `_outbound_nodes` of the layers that produced the inputs to this 2585 # layer call. 2586 node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs) 2587 return outputs 2588 2589 def _get_node_attribute_at_index(self, node_index, attr, attr_name): 2590 """Private utility to retrieves an attribute (e.g. inputs) from a node. 2591 2592 This is used to implement the methods: 2593 - get_input_shape_at 2594 - get_output_shape_at 2595 - get_input_at 2596 etc... 2597 2598 Args: 2599 node_index: Integer index of the node from which 2600 to retrieve the attribute. 2601 attr: Exact node attribute name. 2602 attr_name: Human-readable attribute name, for error messages. 2603 2604 Returns: 2605 The layer's attribute `attr` at the node of index `node_index`. 2606 2607 Raises: 2608 RuntimeError: If the layer has no inbound nodes, or if called in Eager 2609 mode. 2610 ValueError: If the index provided does not match any node. 2611 """ 2612 if not self._inbound_nodes: 2613 raise RuntimeError('The layer has never been called ' 2614 'and thus has no defined ' + attr_name + '.') 2615 if not len(self._inbound_nodes) > node_index: 2616 raise ValueError('Asked to get ' + attr_name + ' at node ' + 2617 str(node_index) + ', but the layer has only ' + 2618 str(len(self._inbound_nodes)) + ' inbound nodes.') 2619 values = getattr(self._inbound_nodes[node_index], attr) 2620 if isinstance(values, list) and len(values) == 1: 2621 return values[0] 2622 else: 2623 return values 2624 2625 def _maybe_build(self, inputs): 2626 # Check input assumptions set before layer building, e.g. input rank. 2627 if not self.built: 2628 input_spec.assert_input_compatibility( 2629 self.input_spec, inputs, self.name) 2630 input_list = nest.flatten(inputs) 2631 if input_list and self._dtype_policy.compute_dtype is None: 2632 try: 2633 dtype = input_list[0].dtype.base_dtype.name 2634 except AttributeError: 2635 pass 2636 else: 2637 self._set_dtype_policy(policy.Policy(dtype)) 2638 input_shapes = None 2639 # Converts Tensors / CompositeTensors to TensorShapes. 2640 if all(hasattr(x, 'shape') for x in input_list): 2641 input_shapes = tf_utils.get_shapes(inputs) 2642 else: 2643 # Converts input shape to TensorShapes. 2644 try: 2645 input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False) 2646 except ValueError: 2647 pass 2648 # Only call `build` if the user has manually overridden the build method. 2649 if not hasattr(self.build, '_is_default'): 2650 # Any setup work performed only once should happen in an `init_scope` 2651 # to avoid creating symbolic Tensors that will later pollute any eager 2652 # operations. 
2653 with tf_utils.maybe_init_scope(self): 2654 self.build(input_shapes) # pylint:disable=not-callable 2655 # We must also ensure that the layer is marked as built, and the build 2656 # shape is stored since user-defined build functions may not be calling 2657 # `super().build()` 2658 Layer.build(self, input_shapes) 2659 2660 # Optionally load weight values specified at layer instantiation. 2661 if self._initial_weights is not None: 2662 with ops.init_scope(): 2663 # Using `init_scope` since we want variable assignment in 2664 # `set_weights` to be treated like variable initialization. 2665 self.set_weights(self._initial_weights) 2666 self._initial_weights = None 2667 2668 def _symbolic_call(self, inputs): 2669 input_shapes = nest.map_structure(lambda x: x.shape, inputs) 2670 output_shapes = self.compute_output_shape(input_shapes) 2671 # Convert to TensorShape so that nest.map_structure will not map into 2672 # individual dim of the shape. 2673 output_shapes = tf_utils.convert_shapes(output_shapes, to_tuples=False) 2674 2675 def _make_placeholder_like(shape): 2676 ph = backend.placeholder(shape=shape, dtype=self.dtype) 2677 ph._keras_mask = None 2678 return ph 2679 return nest.map_structure(_make_placeholder_like, output_shapes) 2680 2681 def _get_trainable_state(self): 2682 """Get the `trainable` state of each sublayer. 2683 2684 Returns: 2685 A dict mapping all sublayers to their `trainable` value. 2686 """ 2687 trainable_state = weakref.WeakKeyDictionary() 2688 for layer in self._flatten_layers(): 2689 trainable_state[layer] = layer.trainable 2690 return trainable_state 2691 2692 def _set_trainable_state(self, trainable_state): 2693 """Set `trainable` state for each sublayer.""" 2694 for layer in self._flatten_layers(): 2695 if layer in trainable_state: 2696 layer.trainable = trainable_state[layer] 2697 2698 @property 2699 def _obj_reference_counts(self): 2700 """A dictionary counting the number of attributes referencing an object.""" 2701 self._maybe_create_attribute('_obj_reference_counts_dict', 2702 object_identity.ObjectIdentityDictionary()) 2703 return self._obj_reference_counts_dict 2704 2705 @trackable.no_automatic_dependency_tracking 2706 def _maybe_create_attribute(self, name, default_value): 2707 """Create the attribute with the default value if it hasn't been created. 2708 2709 This is useful for fields that are used for tracking purposes, such as 2710 _trainable_weights or _layers. Note that a user could create a layer 2711 subclass and assign an internal field before invoking Layer.__init__(), so 2712 __setattr__() needs to create the tracking fields and __init__() needs to 2713 not override them. 2714 2715 Args: 2716 name: String, the name of the attribute. 2717 default_value: Object, the default value of the attribute. 2718 """ 2719 if not hasattr(self, name): 2720 self.__setattr__(name, default_value) 2721 2722 def __delattr__(self, name): 2723 # For any super.__delattr__() call, we will directly use the implementation 2724 # in Trackable and skip the behavior in AutoTrackable. The Layer originally 2725 # used Trackable as its base class; the change to using Module as the base 2726 # class forced us to have AutoTrackable in the class hierarchy. 2727 # 2728 # TODO(b/180760306) Keeping the status quo of skipping __delattr__ and 2729 # __setattr__ in AutoTrackable may be unsustainable. 2730 existing_value = getattr(self, name, None) 2731 2732 # If this value is replacing an existing object assigned to an attribute, we 2733 # should clean it out to avoid leaking memory.
First we check if there are 2734 # other attributes referencing it. 2735 reference_counts = self._obj_reference_counts 2736 if existing_value not in reference_counts: 2737 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 2738 return 2739 2740 reference_count = reference_counts[existing_value] 2741 if reference_count > 1: 2742 # There are other remaining references. We can't remove this object from 2743 # _layers etc. 2744 reference_counts[existing_value] = reference_count - 1 2745 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 2746 return 2747 else: 2748 # This is the last remaining reference. 2749 del reference_counts[existing_value] 2750 2751 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 2752 2753 if (isinstance(existing_value, Layer) 2754 or base_layer_utils.has_weights(existing_value)): 2755 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 2756 '_self_tracked_trackables', 2757 [l for l in self._self_tracked_trackables if l is not existing_value]) 2758 if isinstance(existing_value, tf_variables.Variable): 2759 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 2760 '_trainable_weights', 2761 [w for w in self._trainable_weights if w is not existing_value]) 2762 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 2763 '_non_trainable_weights', 2764 [w for w in self._non_trainable_weights if w is not existing_value]) 2765 2766 def __setattr__(self, name, value): 2767 if (name == '_self_setattr_tracking' or 2768 not getattr(self, '_self_setattr_tracking', True) or 2769 # Exclude @property.setters from tracking 2770 hasattr(self.__class__, name)): 2771 try: 2772 super(autotrackable.AutoTrackable, self).__setattr__(name, value) # pylint: disable=bad-super-call 2773 except AttributeError: 2774 raise AttributeError( 2775 ('Can\'t set the attribute "{}", likely because it conflicts with ' 2776 'an existing read-only @property of the object. Please choose a ' 2777 'different name.').format(name)) 2778 return 2779 2780 # Wraps data structures in `Trackable`, unwraps `NoDependency` objects. 2781 value = data_structures.sticky_attribute_assignment( 2782 trackable=self, value=value, name=name) 2783 2784 reference_counts = self._obj_reference_counts 2785 reference_counts[value] = reference_counts.get(value, 0) + 1 2786 2787 # Clean out the old attribute, which clears _layers and _trainable_weights 2788 # if necessary. 2789 try: 2790 self.__delattr__(name) 2791 except AttributeError: 2792 pass 2793 2794 # Keep track of metric instance created in subclassed layer. 2795 for val in nest.flatten(value): 2796 if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'): 2797 self._metrics.append(val) 2798 2799 # Append value to self._self_tracked_trackables if relevant 2800 if (getattr(self, '_auto_track_sub_layers', True) and 2801 (isinstance(value, module.Module) or 2802 base_layer_utils.has_weights(value))): 2803 self._maybe_create_attribute('_self_tracked_trackables', []) 2804 # We need to check object identity to avoid de-duplicating empty 2805 # container types which compare equal. 2806 if not any((layer is value for layer in self._self_tracked_trackables)): 2807 self._self_tracked_trackables.append(value) 2808 if hasattr(value, '_use_resource_variables'): 2809 # Legacy layers (V1 tf.layers) must always use 2810 # resource variables. 
2811 value._use_resource_variables = True 2812 2813 # Append value to list of trainable / non-trainable weights if relevant 2814 # TODO(b/125122625): This won't pick up on any variables added to a 2815 # list/dict after creation. 2816 for val in nest.flatten(value, expand_composites=True): 2817 if not isinstance(val, tf_variables.Variable): 2818 continue 2819 2820 # Users may add extra weights/variables 2821 # simply by assigning them to attributes (invalid for graph networks) 2822 self._maybe_create_attribute('_trainable_weights', []) 2823 self._maybe_create_attribute('_non_trainable_weights', []) 2824 if val.trainable: 2825 if any(val is w for w in self._trainable_weights): 2826 continue 2827 self._trainable_weights.append(val) 2828 else: 2829 if any(val is w for w in self._non_trainable_weights): 2830 continue 2831 self._non_trainable_weights.append(val) 2832 2833 backend.track_variable(val) 2834 2835 # TODO(b/180760306) Skip the auto trackable from tf.Module to keep status 2836 # quo. See the comment at __delattr__. 2837 super(autotrackable.AutoTrackable, self).__setattr__(name, value) # pylint: disable=bad-super-call 2838 2839 def _gather_children_attribute(self, attribute): 2840 assert attribute in { 2841 'variables', 'trainable_variables', 'non_trainable_variables' 2842 } 2843 if hasattr(self, '_self_tracked_trackables'): 2844 nested_layers = self._flatten_modules(include_self=False, recursive=False) 2845 return list( 2846 itertools.chain.from_iterable( 2847 getattr(layer, attribute) for layer in nested_layers)) 2848 return [] 2849 2850 def _flatten_layers(self, recursive=True, include_self=True): 2851 for m in self._flatten_modules( 2852 recursive=recursive, include_self=include_self): 2853 if isinstance(m, Layer): 2854 yield m 2855 2856 def _flatten_modules(self, recursive=True, include_self=True): 2857 """Flattens `tf.Module` instances (excluding `Metrics`). 2858 2859 Args: 2860 recursive: Whether to recursively flatten through submodules. 2861 include_self: Whether to include this `Layer` instance. 2862 2863 Yields: 2864 `tf.Module` instance tracked by this `Layer`. 2865 """ 2866 if include_self: 2867 yield self 2868 2869 # Only instantiate set and deque if needed. 2870 trackables = getattr(self, '_self_tracked_trackables', None) 2871 if trackables: 2872 seen_object_ids = set() 2873 deque = collections.deque(trackables) 2874 while deque: 2875 trackable_obj = deque.popleft() 2876 trackable_id = id(trackable_obj) 2877 if trackable_id in seen_object_ids: 2878 continue 2879 seen_object_ids.add(trackable_id) 2880 2881 # Metrics are not considered part of the Layer's topology. 2882 if (isinstance(trackable_obj, module.Module) and 2883 not isinstance(trackable_obj, metrics_mod.Metric)): 2884 yield trackable_obj 2885 # Introspect recursively through sublayers. 2886 if recursive: 2887 subtrackables = getattr(trackable_obj, '_self_tracked_trackables', 2888 None) 2889 if subtrackables: 2890 deque.extendleft(reversed(subtrackables)) 2891 elif isinstance(trackable_obj, data_structures.TrackableDataStructure): 2892 # Data structures are introspected even with `recursive=False`. 2893 tracked_values = trackable_obj._values 2894 if tracked_values: 2895 deque.extendleft(reversed(tracked_values)) 2896 2897 # This is a hack so that the is_layer (within 2898 # training/trackable/layer_utils.py) check doesn't get the weights attr. 2899 # TODO(b/110718070): Remove when fixed. 
  # This is a hack so that the is_layer (within
  # training/trackable/layer_utils.py) check doesn't get the weights attr.
  # TODO(b/110718070): Remove when fixed.
  def _is_layer(self):
    return True

  def _init_call_fn_args(self, expects_training_arg=None):
    # Clear cached call function arguments.
    self.__class__._call_full_argspec.fget.cache.pop(self, None)
    self.__class__._call_fn_args.fget.cache.pop(self, None)
    self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)

    call_fn_args = self._call_fn_args
    call_fn_args += self._call_full_argspec.kwonlyargs or []
    if expects_training_arg is None:
      self._expects_training_arg = ('training' in call_fn_args or
                                    self._call_accepts_kwargs)
    else:
      # Use value encoded into the metadata when loading from the SavedModel.
      self._expects_training_arg = expects_training_arg
    # The default training arg will be any (non-None) default specified in the
    # method signature, or None if no value is specified.
    call_fn_arg_defaults = self._call_fn_arg_defaults.copy()
    call_fn_arg_defaults.update(self._call_full_argspec.kwonlydefaults or {})
    self._default_training_arg = call_fn_arg_defaults.get('training')

    self._expects_mask_arg = ('mask' in call_fn_args or
                              self._call_accepts_kwargs)

  @property
  @layer_utils.cached_per_instance
  def _call_full_argspec(self):
    # Argspec inspection is expensive and the call spec is used often, so it
    # makes sense to cache the result.
    return tf_inspect.getfullargspec(self.call)

  @property
  @layer_utils.cached_per_instance
  def _call_fn_args(self):
    all_args = self._call_full_argspec.args
    # Scrub `self` that appears if a decorator was applied.
    if all_args and all_args[0] == 'self':
      return all_args[1:]
    return all_args

  @property
  @layer_utils.cached_per_instance
  def _call_fn_arg_defaults(self):
    call_fn_args = self._call_fn_args
    call_fn_defaults = self._call_full_argspec.defaults or []
    defaults = dict()

    # The call arg defaults are an n-tuple of the last n elements of the args
    # list. (n = # of elements that have a default argument)
    for i in range(-1 * len(call_fn_defaults), 0):
      defaults[call_fn_args[i]] = call_fn_defaults[i]
    return defaults

  @property
  @layer_utils.cached_per_instance
  def _call_fn_arg_positions(self):
    call_fn_arg_positions = dict()
    for pos, arg in enumerate(self._call_fn_args):
      call_fn_arg_positions[arg] = pos
    return call_fn_arg_positions

  @property
  @layer_utils.cached_per_instance
  def _call_accepts_kwargs(self):
    return self._call_full_argspec.varkw is not None
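
  # Illustrative sketch (not part of the library): what the cached call-spec
  # properties above report for a hypothetical subclass (`MyLayer` exists only
  # for this example). `_init_call_fn_args` is invoked from `__init__`, so the
  # derived flags are available right after construction.
  #
  #   class MyLayer(tf.keras.layers.Layer):
  #     def call(self, inputs, training=None, mask=None):
  #       return inputs
  #
  #   layer = MyLayer()
  #   layer._call_fn_args          # ['inputs', 'training', 'mask']
  #   layer._call_fn_arg_defaults  # {'training': None, 'mask': None}
  #   layer._call_accepts_kwargs   # False (no **kwargs in `call`)
  #   layer._expects_training_arg  # True ('training' is in the call signature)
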
  @property
  def _eager_losses(self):
    # A list of loss values containing activity regularizers and losses
    # manually added through `add_loss` during eager execution. It is cleared
    # after every batch.
    # Because we plan on eventually allowing the same model instance to be
    # trained alternately in eager mode and graph mode, we need to keep track
    # of eager losses and symbolic losses via separate attributes.
    if not hasattr(self._thread_local, '_eager_losses'):
      self._thread_local._eager_losses = []
    return self._thread_local._eager_losses

  @_eager_losses.setter
  def _eager_losses(self, losses):
    self._thread_local._eager_losses = losses

  def _dedup_weights(self, weights):
    """Dedupe weights while maintaining order as much as possible."""
    output, seen_ids = [], set()
    for w in weights:
      if id(w) not in seen_ids:
        output.append(w)
        # Track the Variable's identity to avoid __eq__ issues.
        seen_ids.add(id(w))

    return output

  def _split_out_first_arg(self, args, kwargs):
    # Grab the argument corresponding to the first argument in the
    # layer's `call` method spec. This will either be the first positional
    # argument, or it will be provided as a keyword argument.
    if args:
      inputs = args[0]
      args = args[1:]
    elif self._call_fn_args[0] in kwargs:
      kwargs = copy.copy(kwargs)
      inputs = kwargs.pop(self._call_fn_args[0])
    else:
      raise ValueError(
          'The first argument to `Layer.call` must always be passed.')
    return inputs, args, kwargs

  # SavedModel properties. Please see keras/saving/saved_model for details.

  @trackable.no_automatic_dependency_tracking
  def _set_save_spec(self, inputs):
    if self._saved_model_inputs_spec is not None:
      return  # Already set.

    self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec,
                                                       inputs)

  def _get_save_spec(self, dynamic_batch=True):
    if self._saved_model_inputs_spec is None:
      return None

    return nest.map_structure(
        lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
        self._saved_model_inputs_spec)

  @property
  def _trackable_saved_model_saver(self):
    return layer_serialization.LayerSavedModelSaver(self)

  @property
  def _object_identifier(self):
    return self._trackable_saved_model_saver.object_identifier

  @property
  def _tracking_metadata(self):
    """Info about this layer to be saved into the SavedModel."""
    return self._trackable_saved_model_saver.tracking_metadata

  def _trackable_children(self, save_type='checkpoint', **kwargs):
    if save_type == 'savedmodel':
      cache = kwargs['cache']
      # TODO(b/213628533): This must be called before super() to ensure
      # that any input shape changes are applied before getting the config of
      # the model.
      children = self._trackable_saved_model_saver.trackable_children(cache)
    else:
      children = {}
    children.update(super()._trackable_children(save_type, **kwargs))
    return children

  @property
  def _use_input_spec_as_call_signature(self):
    # Whether input spec can be used as the call signature when tracing the
    # Layer for SavedModel. By default, this is set to `True` for layers
    # exported from the Keras library, because the layers more rigidly define
    # the `input_specs` property (many custom layers only set the `ndims`).
    return get_canonical_name_for_symbol(type(self),
                                         api_name='keras') is not None
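
  # Illustrative sketch (not part of the library): what `_get_save_spec`
  # returns once `_set_save_spec` has recorded an input structure. The layer
  # and shapes below are hypothetical; only the batch dimension is affected by
  # `dynamic_batch`.
  #
  #   layer = tf.keras.layers.Dense(2)
  #   layer._set_save_spec(tf.ones([8, 4]))
  #   layer._get_save_spec()                     # TensorSpec with shape (None, 4)
  #   layer._get_save_spec(dynamic_batch=False)  # TensorSpec with shape (8, 4)
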
  def __getstate__(self):
    # Override to support `copy.deepcopy` and pickling.
    # Thread-local objects cannot be copied in Python 3, so pop these.
    # Thread-local objects are used to cache losses in MirroredStrategy, and
    # so shouldn't be copied.
    state = self.__dict__.copy()
    state.pop('_thread_local', None)
    state.pop('_metrics_lock', None)
    return state

  def __setstate__(self, state):
    state['_thread_local'] = threading.local()
    state['_metrics_lock'] = threading.Lock()
    # Bypass Trackable logic as `__dict__` already contains this info.
    object.__setattr__(self, '__dict__', state)


class TensorFlowOpLayer(Layer):
  """Wraps a TensorFlow Operation in a Layer.

  This class is used internally by the Functional API. When a user
  uses a raw TensorFlow Operation on symbolic tensors originating
  from an `Input` Layer, the resultant operation will be wrapped
  with this Layer object in order to make the operation compatible
  with the Keras API.

  This Layer will create a new, identical operation (except for inputs
  and outputs) every time it is called. If `run_eagerly` is `True`,
  the op creation and calculation will happen inside an Eager function.

  Instances of this Layer are created when `autolambda` is called, which
  is whenever a Layer's `__call__` encounters symbolic inputs that do
  not have Keras metadata, or when a Network's `__init__` encounters
  outputs that do not have Keras metadata.

  Attributes:
    node_def: String, the serialized NodeDef of the Op this layer will wrap.
    name: String, the name of the Layer.
    constants: Dict of NumPy arrays, the values of any Tensors needed for this
      Operation that do not originate from a Keras `Input` Layer. Since all
      placeholders must come from Keras `Input` Layers, these Tensors must be
      treated as constant in the Functional API.
    trainable: Bool, whether this Layer is trainable. Currently Variables are
      not supported, and so this parameter has no effect.
    dtype: The default dtype of this Layer. Inherited from `Layer` and has no
      effect on this class, however it is used in `get_config`.
  """

  @trackable.no_automatic_dependency_tracking
  def __init__(self,
               node_def,
               name,
               constants=None,
               trainable=True,
               dtype=None):
    # Pass autocast=False, as if inputs are cast, input types might not match
    # Operation type.
    super(TensorFlowOpLayer, self).__init__(
        name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype,
        autocast=False)
    if isinstance(node_def, dict):
      self.node_def = json_format.ParseDict(node_def, node_def_pb2.NodeDef())
    else:
      if not isinstance(node_def, bytes):
        node_def = node_def.encode('utf-8')
      self.node_def = node_def_pb2.NodeDef.FromString(node_def)
    # JSON serialization stringifies keys which are integer input indices.
    self.constants = ({
        int(index): constant for index, constant in constants.items()
    } if constants is not None else {})
    # Layer uses original op unless it is called on new inputs.
    # This means `built` is not set in `__call__`.
    self.built = True

    # Do not individually trace TensorflowOpLayers in the SavedModel.
    self._must_restore_from_config = True

  def call(self, inputs):
    if context.executing_eagerly():
      return self._defun_call(inputs)
    return self._make_op(inputs)

  def _make_node_def(self, graph):
    node_def = node_def_pb2.NodeDef()
    node_def.CopyFrom(self.node_def)
    # Used in TPUReplicateContext to indicate whether this node has been cloned
    # and to not add TPU attributes.
    node_def.attr['_cloned'].b = True
    node_def.name = graph.unique_name(node_def.name)
    return node_def

  def _make_op(self, inputs):
    inputs = nest.flatten(inputs)
    graph = inputs[0].graph
    node_def = self._make_node_def(graph)
    with graph.as_default():
      for index, constant in self.constants.items():
        # Recreate constant in graph to add distribution context.
        value = tensor_util.constant_value(constant)
        if value is not None:
          constant = constant_op.constant(value, name=node_def.input[index])
        inputs.insert(index, constant)
      # TODO(b/183990973): We should drop or consolidate these private api calls
      # for adding an op to the graph and recording its gradient.
      c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[])
      op = graph._create_op_from_tf_operation(c_op)
      op._control_flow_post_processing()

      # Record the gradient because custom-made ops don't go through the
      # code-gen'd eager call path
      op_type = compat.as_str(op.op_def.name)
      attr_names = [compat.as_str(attr.name) for attr in op.op_def.attr]
      attrs = []
      for attr_name in attr_names:
        attrs.append(attr_name)
        attrs.append(op.get_attr(attr_name))
      attrs = tuple(attrs)
      backprop.record_gradient(op_type, op.inputs, attrs, op.outputs)

      if len(op.outputs) == 1:
        return op.outputs[0]
      return op.outputs

  @def_function.function
  def _defun_call(self, inputs):
    """Wraps the op creation method in an Eager function for `run_eagerly`."""
    return self._make_op(inputs)

  def get_config(self):
    config = super(TensorFlowOpLayer, self).get_config()
    config.update({
        # `__init__` prefixes the name. Revert to the constructor argument.
        'name': config['name'][len(_TF_OP_LAYER_NAME_PREFIX):],
        'node_def': json_format.MessageToDict(self.node_def),
        'constants': {
            i: backend.get_value(c) for i, c in self.constants.items()
        }
    })
    return config


class AddLoss(Layer):
  """Adds its inputs as a loss.

  Attributes:
    unconditional: Whether the loss is unconditional (not conditioned on the
      layer's inputs).
  """

  def __init__(self, unconditional, **kwargs):
    # Pass autocast=False, as there is no reason to cast loss to a different
    # dtype.
    kwargs['autocast'] = False
    super(AddLoss, self).__init__(**kwargs)
    self.unconditional = unconditional

  def call(self, inputs):
    self.add_loss(inputs, inputs=(not self.unconditional))
    return inputs

  def get_config(self):
    config = super(AddLoss, self).get_config()
    config.update({'unconditional': self.unconditional})
    return config
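

# Illustrative sketch (not part of the library): `AddLoss` above and
# `AddMetric` below are inserted by the Functional API when `add_loss` /
# `add_metric` is called with symbolic (KerasTensor) values, so the loss or
# metric becomes part of the layer graph. The model below is hypothetical.
#
#   inputs = tf.keras.Input(shape=(4,))
#   outputs = tf.keras.layers.Dense(1)(inputs)
#   model = tf.keras.Model(inputs, outputs)
#   model.add_loss(tf.reduce_mean(outputs))  # wrapped in an AddLoss layer
#   model.add_metric(tf.reduce_sum(outputs), name='out_sum',
#                    aggregation='mean')     # wrapped in an AddMetric layer

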
3233 """ 3234 3235 def __init__(self, aggregation=None, metric_name=None, **kwargs): 3236 super(AddMetric, self).__init__(**kwargs) 3237 self.aggregation = aggregation 3238 self.metric_name = metric_name 3239 3240 def call(self, inputs): 3241 self.add_metric(inputs, aggregation=self.aggregation, name=self.metric_name) 3242 return inputs 3243 3244 def get_config(self): 3245 config = super(AddMetric, self).get_config() 3246 config.update({ 3247 'aggregation': self.aggregation, 3248 'metric_name': self.metric_name 3249 }) 3250 return config 3251 3252 3253def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument 3254 """Check the arguments to see if we are constructing a functional model.""" 3255 # We are constructing a functional model if any of the inputs 3256 # are KerasTensors 3257 return any( 3258 isinstance(tensor, keras_tensor.KerasTensor) 3259 for tensor in nest.flatten([inputs, args, kwargs])) 3260 3261 3262def _convert_numpy_or_python_types(x): 3263 if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)): 3264 return ops.convert_to_tensor_v2_with_dispatch(x) 3265 return x 3266 3267 3268# Avoid breaking users who directly import this symbol from this file. 3269# TODO(fchollet): remove this. 3270InputSpec = input_spec.InputSpec # pylint:disable=invalid-name 3271