# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Contains the base Layer class, from which all layers inherit."""

import collections
import functools
import itertools
import threading
import warnings

import numpy as np

from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.mixed_precision import autocast_variable
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
# Modules that only depend on `keras.layers` import these from here.
from tensorflow.python.keras.utils.generic_utils import to_snake_case  # pylint: disable=unused-import
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list  # pylint: disable=unused-import
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest
from tensorflow.tools.docs import doc_controls


# pylint: disable=g-classes-have-attributes
class Layer(base_layer.Layer):
  """Base layer class.

  This is the class from which all layers inherit.

  A layer is a class implementing common neural network operations, such
  as convolution, batch norm, etc. These operations require managing weights,
  losses, updates, and inter-layer connectivity.

  Users will just instantiate a layer and then treat it as a callable.

  We recommend that descendants of `Layer` implement the following methods:

  * `__init__()`: Save configuration in member variables
  * `build()`: Called once from `__call__`, when we know the shapes of inputs
    and `dtype`. Should have the calls to `add_weight()`, and then
    call the super's `build()` (which sets `self.built = True`, which is
    nice in case the user wants to call `build()` manually before the
    first `__call__`).
  * `call()`: Called in `__call__` after making sure `build()` has been called
    once. Should actually perform the logic of applying the layer to the
    input tensors (which should be passed in as the first argument).

  Args:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: The dtype of the layer's computations and weights (default of
      `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type
      of the first input in TensorFlow 1).
    dynamic: Set this to `True` if your layer should only be run eagerly, and
      should not be used to generate a static computation graph.
      This would be the case for a Tree-RNN or a recursive network,
      for example, or generally for any layer that manipulates tensors
      using Python control flow. If `False`, we assume that the layer can
      safely be used to generate a static computation graph.

  Attributes:
    name: The name of the layer (string).
    dtype: The dtype of the layer's computations and weights. If mixed
      precision is used with a `tf.keras.mixed_precision.Policy`, this is
      instead just the dtype of the layer's weights, as the computations are
      done in a different dtype.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.

  Each layer has a dtype, which is typically the dtype of the layer's
  computations and variables. A layer's dtype can be queried via the
  `Layer.dtype` property. The dtype is specified with the `dtype` constructor
  argument. In TensorFlow 2, the dtype defaults to `tf.keras.backend.floatx()`
  if no dtype is passed. `floatx()` itself defaults to "float32".
  Additionally, layers will cast their inputs to the layer's dtype in
  TensorFlow 2. When mixed precision is used, layers may have different
  computation and variable dtypes. See `tf.keras.mixed_precision.Policy` for
  details on layer dtypes.
  """

  # See tf.Module for the usage of this property.
  # The key for _obj_reference_counts_dict is a Trackable, which could be a
  # variable or layer etc. tf.Module._flatten will fail to flatten the key
  # since it is trying to convert Trackable to a string.
  # This attribute can be ignored even after the fix of the nest lib, since
  # the trackable object should already be available as individual attributes.
  # _obj_reference_counts_dict just contains a copy of them.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_obj_reference_counts_dict',),
      module.Module._TF_MODULE_IGNORED_PROPERTIES
  ))

  @trackable.no_automatic_dependency_tracking
  def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
               **kwargs):
    self._instrument_layer_creation()

    # These properties should be set by the user via keyword arguments.
    # Note that 'dtype', 'input_shape' and 'batch_input_shape'
    # are only applicable to input layers: do not pass these keywords
    # to non-input layers.
    allowed_kwargs = {
        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
        'weights', 'activity_regularizer', 'autocast', 'implementation'
    }
    # Validate optional keyword arguments.
    generic_utils.validate_kwargs(kwargs, allowed_kwargs)

    # Mutable properties
    # Indicates whether the layer's weights are updated during training
    # and whether the layer's updates are run during training.
    self._trainable = trainable
    # A stateful layer is a layer whose updates are run during inference too,
    # for instance stateful RNNs.
    self._stateful = False
    # Indicates whether `build` needs to be called upon layer call, to create
    # the layer's weights.
    self.built = False
    self._build_input_shape = None
    # Provides information about which inputs are compatible with the layer.
    self._input_spec = None
    self.supports_masking = False

    self._init_set_name(name)
    self._activity_regularizer = regularizers.get(
        kwargs.pop('activity_regularizer', None))
    self._maybe_create_attribute('_trainable_weights', [])
    self._maybe_create_attribute('_non_trainable_weights', [])
    self._updates = []
    # Object to store all thread local layer properties.
    self._thread_local = threading.local()
    # A list of zero-argument lambdas which return Tensors, used for variable
    # regularizers.
    self._callable_losses = []
    # A list of symbolic Tensors containing activity regularizers and losses
    # manually added through `add_loss` in graph-building mode.
    self._losses = []
    # A list of metric instances corresponding to the symbolic metric tensors
    # added using the `add_metric` API.
    self._metrics = []

    # Both graph and subclassed networks have a dtype policy. For graph
    # networks, the policy's compute and variable dtypes are ignored. Such
    # networks only use the policy if it is a PolicyV1, in which case it uses
    # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
    # subclassed networks, the compute and variable dtypes are used as with
    # any ordinary layer.
    self._set_dtype_policy(dtype)
    # Boolean indicating whether the layer automatically casts its inputs to
    # the layer's compute_dtype.
    self._autocast = kwargs.get('autocast',
                                base_layer_utils.v2_dtype_behavior_enabled())

    # Dependencies tracked via attribute assignment.
    # All layers in order of horizontal graph traversal.
    # Entries are unique. For models includes input and output layers.
    self._maybe_create_attribute('_self_tracked_trackables', [])

    # These lists will be filled via successive calls
    # to self._add_inbound_node().
    # Used in symbolic mode only, in conjunction with graph networks.
    self._inbound_nodes_value = []
    self._outbound_nodes_value = []

    self._init_call_fn_args()

    # Whether the `call` method can be used to build a TF graph without issues.
    # This attribute has no effect if the model is created using the Functional
    # API. Instead, `model.dynamic` is determined based on the internal layers.
    self._dynamic = dynamic

    # Manage input shape information if passed.
    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
      kwargs['input_shape'] = (kwargs['input_dim'],)
    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
      # In this case we will later create an input layer
      # to insert before the current layer
      if 'batch_input_shape' in kwargs:
        batch_input_shape = tuple(kwargs['batch_input_shape'])
      elif 'input_shape' in kwargs:
        if 'batch_size' in kwargs:
          batch_size = kwargs['batch_size']
        else:
          batch_size = None
        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
      self._batch_input_shape = batch_input_shape

    # Manage initial weight values if passed.
    self._initial_weights = kwargs.get('weights', None)

    # Whether the layer will track any layers that are set as attributes on
    # itself as sub-layers; the weights from the sub-layers will be included
    # in the parent layer's variables() as well.
    # Defaults to True, which means auto tracking is turned on. Certain
    # subclasses might want to turn it off, like the Sequential model.
    self._auto_track_sub_layers = True

    # Mark this layer as having been originally built as a tf1 layer/model
    self._originally_built_as_v1 = True

    # For backwards compat reasons, most built-in layers do not guarantee
    # that they will 100% preserve the structure of input args when saving
    # / loading configs. E.g. they may un-nest an arg that is
    # a list with one element.
    self._preserve_input_structure_in_config = False

  @trackable.no_automatic_dependency_tracking
  @generic_utils.default
  def build(self, input_shape):
    """Creates the variables of the layer (optional, for subclass implementers).

    This is a method that implementers of subclasses of `Layer` or `Model`
    can override if they need a state-creation step in-between
    layer instantiation and layer call.

    This is typically used to create the weights of `Layer` subclasses.

    Args:
      input_shape: Instance of `TensorShape`, or list of instances of
        `TensorShape` if the layer expects a list of inputs
        (one instance per input).
    """
    if not hasattr(self.build, '_is_default'):
      self._build_input_shape = input_shape
    self.built = True

  @doc_controls.for_subclass_implementers
  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """This is where the layer's logic lives.

    Args:
      inputs: Input tensor, or list/tuple of input tensors.
      **kwargs: Additional keyword arguments.

    Returns:
      A tensor or list/tuple of tensors.
    """
    return inputs
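  # A minimal sketch of the subclassing pattern recommended in the class
  # docstring: `__init__` saves configuration, `build` creates weights with
  # `add_weight`, and `call` applies the layer. The `Linear` name and sizes
  # below are illustrative, not part of this module:
  #
  #   class Linear(Layer):
  #
  #     def __init__(self, units=32, **kwargs):
  #       super(Linear, self).__init__(**kwargs)
  #       self.units = units
  #
  #     def build(self, input_shape):
  #       self.kernel = self.add_weight(
  #           'kernel', shape=(int(input_shape[-1]), self.units),
  #           initializer='glorot_uniform', trainable=True)
  #       self.bias = self.add_weight(
  #           'bias', shape=(self.units,), initializer='zeros', trainable=True)
  #       super(Linear, self).build(input_shape)  # Sets `self.built = True`.
  #
  #     def call(self, inputs):
  #       return math_ops.matmul(inputs, self.kernel) + self.bias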
  @doc_controls.for_subclass_implementers
  def _add_trackable(self, trackable_object, trainable):
    """Adds a Trackable object to this layer's state.

    Args:
      trackable_object: The tf.tracking.Trackable object to add.
      trainable: Boolean, whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases) or
        "non_trainable_variables" (e.g. BatchNorm mean and variance).

    Returns:
      The TrackableWeightHandler used to track this object.
    """
    if isinstance(trackable_object, base_layer_utils.TrackableWeightHandler):
      handler = trackable_object
    else:
      handler = base_layer_utils.TrackableWeightHandler(trackable_object)
    if trainable:
      self._trainable_weights.append(handler)
    else:
      self._non_trainable_weights.append(handler)
    return handler

  @doc_controls.for_subclass_implementers
  def add_weight(self,
                 name=None,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 partitioner=None,
                 use_resource=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE,
                 **kwargs):
    """Adds a new variable to the layer.

    Args:
      name: Variable name.
      shape: Variable shape. Defaults to scalar if unspecified.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: Initializer instance (callable).
      regularizer: Regularizer instance (callable).
      trainable: Boolean, whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean and variance).
        Note that `trainable` cannot be `True` if `synchronization`
        is set to `ON_READ`.
      constraint: Constraint instance (callable).
      partitioner: Partitioner to be passed to the `Trackable` API.
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      **kwargs: Additional keyword arguments. Accepted values are `getter`,
        `collections`, `experimental_autocast` and `caching_device`.

    Returns:
      The created variable. Usually either a `Variable` or `ResourceVariable`
      instance. If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When giving unsupported dtype and no initializer or when
        trainable has been set to True with synchronization set as `ON_READ`.
    """
    if shape is None:
      shape = ()
    # Validate optional keyword arguments.
    for kwarg in kwargs:
      if kwarg not in ['getter', 'collections', 'experimental_autocast',
                       'caching_device']:
        raise TypeError('Unknown keyword argument:', kwarg)
    has_custom_getter = 'getter' in kwargs
    getter = kwargs.pop('getter', base_layer_utils.make_variable)
    collections_arg = kwargs.pop('collections', None)
    # 'experimental_autocast' can be set to False by the caller to indicate an
    # AutoCastVariable should never be created.
    autocast = kwargs.pop('experimental_autocast', True)
    # See the docstring for tf.Variable about the details for caching_device.
    caching_device = kwargs.pop('caching_device', None)

    if dtype is None:
      dtype = self.dtype or backend.floatx()
    dtype = dtypes.as_dtype(dtype)
    if self._dtype_policy.variable_dtype is None:
      # The policy is "_infer", so we infer the policy from the variable dtype.
      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
    initializer = initializers.get(initializer)
    regularizer = regularizers.get(regularizer)
    constraint = constraints.get(constraint)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    # Initialize variable when no initializer provided
    if initializer is None:
      # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
      if dtype.is_floating:
        initializer = initializers.get('glorot_uniform')
      # If dtype is DT_INT/DT_UINT, provide a default value `zero`
      # If dtype is DT_BOOL, provide a default value `FALSE`
      elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
        initializer = initializers.zeros()
      # NOTE: Do we need to support handling of DT_STRING and DT_COMPLEX here?
      elif not has_custom_getter:
        # When `getter` is specified, it's possibly fine for `initializer` to
        # be None since it's up to the custom `getter` to raise an error in
        # case it indeed needs `initializer`.
        raise ValueError('An initializer for variable %s of type %s is required'
                         ' for layer %s' % (name, dtype.base_dtype, self.name))

    if (autocast and
        self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
        and dtype.is_floating):
      # Wrap 'getter' with a version that returns an AutoCastVariable.
      old_getter = getter

      def getter(*args, **kwargs):  # pylint: disable=function-redefined
        variable = old_getter(*args, **kwargs)
        return autocast_variable.create_autocast_variable(variable)

      # Also the caching_device does not work with the mixed precision API,
      # disable it if it is specified.
      # TODO(b/142020079): Reenable it once the bug is fixed.
      if caching_device is not None:
        tf_logging.warning(
            '`caching_device` does not work with mixed precision API. Ignoring '
            'user specified `caching_device`.')
        caching_device = None

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        # TODO(allenl): a `make_variable` equivalent should be added as a
        # `Trackable` method.
        getter=getter,
        # Manage errors in Layer rather than Trackable.
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        constraint=constraint,
        trainable=trainable,
        partitioner=partitioner,
        use_resource=use_resource,
        collections=collections_arg,
        synchronization=synchronization,
        aggregation=aggregation,
        caching_device=caching_device)
    if regularizer is not None:
      # TODO(fchollet): in the future, this should be handled at the
      # level of variable creation, and weight regularization losses
      # should be variable attributes.
      name_in_scope = variable.name[:variable.name.find(':')]
      self._handle_weight_regularization(name_in_scope,
                                         variable,
                                         regularizer)
    if base_layer_utils.is_split_variable(variable):
      for v in variable:
        backend.track_variable(v)
        if trainable:
          self._trainable_weights.append(v)
        else:
          self._non_trainable_weights.append(v)
    else:
      backend.track_variable(variable)
      if trainable:
        self._trainable_weights.append(variable)
      else:
        self._non_trainable_weights.append(variable)
    return variable
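  # A sketch of `add_weight` with the regularizer/constraint arguments
  # documented above, inside a hypothetical `build` (names and sizes are
  # illustrative):
  #
  #   def build(self, input_shape):
  #     self.kernel = self.add_weight(
  #         name='kernel',
  #         shape=(int(input_shape[-1]), 16),
  #         initializer=initializers.get('glorot_uniform'),
  #         regularizer=regularizers.get('l2'),
  #         constraint=constraints.get('non_neg'),
  #         trainable=True)
  #     super().build(input_shape)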
  @generic_utils.default
  def get_config(self):
    """Returns the config of the layer.

    A layer config is a Python dictionary (serializable)
    containing the configuration of a layer.
    The same layer can be reinstantiated later
    (without its trained weights) from this configuration.

    The config of a layer does not include connectivity
    information, nor the layer class name. These are handled
    by `Network` (one layer of abstraction above).

    Returns:
      Python dictionary.
    """
    all_args = tf_inspect.getfullargspec(self.__init__).args
    config = {'name': self.name, 'trainable': self.trainable}
    if hasattr(self, '_batch_input_shape'):
      config['batch_input_shape'] = self._batch_input_shape
    config['dtype'] = policy.serialize(self._dtype_policy)
    if hasattr(self, 'dynamic'):
      # Only include `dynamic` in the `config` if it is `True`
      if self.dynamic:
        config['dynamic'] = self.dynamic
      elif 'dynamic' in all_args:
        all_args.remove('dynamic')
    expected_args = config.keys()
    # Finds all arguments in the `__init__` that are not in the config:
    extra_args = [arg for arg in all_args if arg not in expected_args]
    # Check that either the only argument in the `__init__` is `self`,
    # or that `get_config` has been overridden:
    if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
      raise NotImplementedError('Layers with arguments in `__init__` must '
                                'override `get_config`.')
    return config

  @classmethod
  def from_config(cls, config):
    """Creates a layer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same layer from the config
    dictionary. It does not handle layer connectivity
    (handled by Network), nor weights (handled by `set_weights`).

    Args:
      config: A Python dictionary, typically the
        output of get_config.

    Returns:
      A layer instance.
    """
    return cls(**config)
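  # A sketch of the `get_config` override required when `__init__` takes
  # extra arguments, continuing the illustrative `Linear` layer with a
  # `units` argument; `from_config` then rebuilds the layer via
  # `cls(**config)`:
  #
  #   def get_config(self):
  #     config = super(Linear, self).get_config()
  #     config.update({'units': self.units})
  #     return config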
  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer.

    If the layer has not been built, this method will call `build` on the
    layer. This assumes that the layer will later be used with inputs that
    match the input shape provided here.

    Args:
      input_shape: Shape tuple (tuple of integers)
        or list of shape tuples (one per output tensor of the layer).
        Shape tuples can include None for free dimensions,
        instead of an integer.

    Returns:
      An output shape tuple.
    """
    if context.executing_eagerly():
      # In this case we build the model first in order to do shape inference.
      # This is acceptable because the framework only calls
      # `compute_output_shape` on shape values that the layer would later be
      # built for. It would however cause issues in case a user attempts to
      # use `compute_output_shape` manually with shapes that are incompatible
      # with the shape the Layer will be called on (these users will have to
      # implement `compute_output_shape` themselves).
      self._maybe_build(input_shape)
      with ops.get_default_graph().as_default():
        graph = func_graph.FuncGraph('graph')
        with graph.as_default():
          input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
          inputs = nest.map_structure(
              base_layer_utils.generate_placeholders_from_shape, input_shape)
          try:
            outputs = self(inputs, training=False)
          except TypeError as e:
            raise NotImplementedError(
                'We could not automatically infer the static shape of the '
                'layer\'s output. Please implement the '
                '`compute_output_shape` method on your layer (%s).' %
                self.__class__.__name__) from e
      return nest.map_structure(lambda t: t.shape, outputs)
    raise NotImplementedError

  @doc_controls.for_subclass_implementers
  def compute_output_signature(self, input_signature):
    """Compute the output tensor signature of the layer based on the inputs.

    Unlike a TensorShape object, a TensorSpec object contains both shape
    and dtype information for a tensor. This method allows layers to provide
    output dtype information if it is different from the input dtype.
    For any layer that doesn't implement this function,
    the framework will fall back to use `compute_output_shape`, and will
    assume that the output dtype matches the input dtype.

    Args:
      input_signature: Single TensorSpec or nested structure of TensorSpec
        objects, describing a candidate input for the layer.

    Returns:
      Single TensorSpec or nested structure of TensorSpec objects, describing
        how the layer would transform the provided input.

    Raises:
      TypeError: If input_signature contains a non-TensorSpec object.
    """

    def check_type_return_shape(s):
      if not isinstance(s, tensor_spec.TensorSpec):
        raise TypeError('Only TensorSpec signature types are supported, '
                        'but saw signature entry: {}.'.format(s))
      return s.shape

    input_shape = nest.map_structure(check_type_return_shape, input_signature)
    output_shape = self.compute_output_shape(input_shape)
    dtype = self._compute_dtype
    if dtype is None:
      input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
      # Default behavior when self.dtype is None, is to use the first input's
      # dtype.
      dtype = input_dtypes[0]
    return nest.map_structure(
        lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
        output_shape)

  @generic_utils.default
  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
    """Computes an output mask tensor.

    Args:
      inputs: Tensor or list of tensors.
      mask: Tensor or list of tensors.

    Returns:
      None or a tensor (or list of tensors,
        one per output tensor of the layer).
    """
    if not self.supports_masking:
      if any(m is not None for m in nest.flatten(mask)):
        raise TypeError('Layer ' + self.name + ' does not support masking, '
                        'but was passed an input_mask: ' + str(mask))
      # masking not explicitly supported: return None as mask.
      return None
    # if masking is explicitly supported, by default
    # carry over the input mask
    return mask
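  # A sketch of a manual `compute_output_shape` override for a layer that
  # maps its last axis to a fixed width (`self.units` is illustrative and the
  # input is assumed to be a plain shape tuple):
  #
  #   def compute_output_shape(self, input_shape):
  #     # Keep the leading dimensions, replace the last one with `self.units`.
  #     return tuple(input_shape[:-1]) + (self.units,)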
  def __call__(self, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      *args: Positional arguments to be passed to `self.call`.
      **kwargs: Keyword arguments to be passed to `self.call`.

    Returns:
      Output tensor(s).

    Note:
      - The following optional keyword arguments are reserved for specific
        uses:
        * `training`: Boolean scalar tensor of Python boolean indicating
          whether the `call` is meant for training or inference.
        * `mask`: Boolean input mask.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
      RuntimeError: if `super().__init__()` was not called in the constructor.
    """
    self._assert_built_as_v1()

    if not hasattr(self, '_thread_local'):
      raise RuntimeError(
          'You must call `super().__init__()` in the layer constructor.')

    # Grab the first positional or keyword argument.
    if args:
      inputs = args[0]
      args = args[1:]
    elif self._call_fn_args[0] in kwargs:
      inputs = kwargs.pop(self._call_fn_args[0])
    else:
      raise ValueError(
          'The first argument to `Layer.call` must always be passed.')

    call_context = base_layer_utils.call_context()
    input_list = nest.flatten(inputs)

    # We will attempt to build a TF graph if & only if all inputs are symbolic.
    # This is always the case in graph mode. It can also be the case in eager
    # mode when all inputs can be traced back to `keras.Input()` (when building
    # models using the functional API).
    build_graph = tf_utils.are_all_symbolic_tensors(input_list)

    # Accept NumPy and scalar inputs by converting to Tensors.
    if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):

      def _convert_non_tensor(x):
        # Don't call `ops.convert_to_tensor` on all `inputs` because
        # `SparseTensors` can't be converted to `Tensor`.
        if isinstance(x, (np.ndarray, float, int)):
          return ops.convert_to_tensor_v2_with_dispatch(x)
        return x

      inputs = nest.map_structure(_convert_non_tensor, inputs)
      input_list = nest.flatten(inputs)

    # Handle `mask` propagation from previous layer to current layer. Masks
    # can be propagated explicitly via the `mask` argument, or implicitly via
    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks
    # passed explicitly take priority.
    mask_arg_passed_by_framework = False
    input_masks = self._collect_input_masks(inputs, args, kwargs)
    if (self._expects_mask_arg and input_masks is not None and
        not self._call_arg_was_passed('mask', args, kwargs)):
      mask_arg_passed_by_framework = True
      kwargs['mask'] = input_masks

    # If `training` argument is None or not explicitly passed,
    # propagate `training` value from this layer's calling layer.
    training_value = None
    training_arg_passed_by_framework = False
    # Priority 1: `training` was explicitly passed.
    if self._call_arg_was_passed('training', args, kwargs):
      training_value = self._get_call_arg_value('training', args, kwargs)
      if not self._expects_training_arg:
        kwargs.pop('training')

    if training_value is None:
      # Priority 2: `training` was passed to a parent layer.
      if call_context.training is not None:
        training_value = call_context.training
      # Priority 3a: `learning_phase()` has been set.
      elif backend.global_learning_phase_is_set():
        training_value = backend.learning_phase()
      # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
      elif build_graph:
        with backend.get_graph().as_default():
          if base_layer_utils.is_in_keras_graph():
            training_value = backend.learning_phase()

    if self._expects_training_arg and training_value is not None:
      # Force the training_value to be bool type, which matches the contract
      # for layer/model call args.
      if tensor_util.is_tf_type(training_value):
        training_value = math_ops.cast(training_value, dtypes.bool)
      else:
        training_value = bool(training_value)
      args, kwargs = self._set_call_arg_value(
          'training', training_value, args, kwargs)
      training_arg_passed_by_framework = True

    # Only create Keras history if at least one tensor originates from a
    # `keras.Input`. Otherwise this Layer may be being used outside the Keras
    # framework.
    if build_graph and base_layer_utils.needs_keras_history(inputs):
      base_layer_utils.create_keras_history(inputs)

    with call_context.enter(self, inputs, build_graph, training_value):
      # Check input assumptions set after layer building, e.g. input shape.
      if build_graph:
        # Symbolic execution on symbolic tensors. We will attempt to build
        # the corresponding TF subgraph inside `backend.get_graph()`
        input_spec.assert_input_compatibility(self.input_spec, inputs,
                                              self.name)
        graph = backend.get_graph()
        with graph.as_default(), backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
          # Build layer if applicable (if the `build` method has been
          # overridden).
          self._maybe_build(inputs)
          cast_inputs = self._maybe_cast_inputs(inputs)

          # Wrapping `call` function in autograph to allow for dynamic control
          # flow and control dependencies in call. We are limiting this to
          # subclassed layers as autograph is strictly needed only for
          # subclassed layers and models.
          # tf_convert will respect the value of autograph setting in the
          # enclosing tf.function, if any.
          if (base_layer_utils.is_subclassed(self) and
              not base_layer_utils.from_saved_model(self)):
            call_fn = autograph.tf_convert(
                self.call, ag_ctx.control_status_ctx())
          else:
            call_fn = self.call

          if not self.dynamic:
            try:
              with autocast_variable.enable_auto_cast_variables(
                  self._compute_dtype_object):
                outputs = call_fn(cast_inputs, *args, **kwargs)

            except errors.OperatorNotAllowedInGraphError as e:
              raise TypeError('You are attempting to use Python control '
                              'flow in a layer that was not declared to be '
                              'dynamic. Pass `dynamic=True` to the class '
                              'constructor.\nEncountered error:\n"""\n' +
                              str(e) + '\n"""')
          else:
            # We will use static shape inference to return symbolic tensors
            # matching the specifications of the layer outputs.
            # Since `self.dynamic` is True, we will never attempt to
            # run the underlying TF graph (which is disconnected).
            # TODO(fchollet): consider py_func as an alternative, which
            # would enable us to run the underlying graph if needed.
            outputs = self._symbolic_call(inputs)

          if outputs is None:
            raise ValueError('A layer\'s `call` method should return a '
                             'Tensor or a list of Tensors, not None '
                             '(layer: ' + self.name + ').')
          if base_layer_utils.have_all_keras_metadata(inputs):
            if training_arg_passed_by_framework:
              args, kwargs = self._set_call_arg_value(
                  'training', None, args, kwargs, pop_kwarg_if_none=True)
            if mask_arg_passed_by_framework:
              kwargs.pop('mask')
            outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
                                                      outputs)
          self._handle_activity_regularization(inputs, outputs)
          self._set_mask_metadata(inputs, outputs, input_masks)
          if hasattr(self, '_set_inputs') and not self.inputs:
            # Subclassed network: explicitly set metadata normally set by
            # a call to self._set_inputs().
            # TODO(b/120997007): This should be done in Eager as well, but
            # causes garbage collection issues because of the placeholders
            # created on the default Keras graph.
            self._set_inputs(inputs, outputs)
      else:
        # Eager execution on data tensors.
        with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
          self._maybe_build(inputs)
          cast_inputs = self._maybe_cast_inputs(inputs)
          with autocast_variable.enable_auto_cast_variables(
              self._compute_dtype_object):
            outputs = self.call(cast_inputs, *args, **kwargs)
          self._handle_activity_regularization(inputs, outputs)
          self._set_mask_metadata(inputs, outputs, input_masks)

    return outputs
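  # Rough usage sketch of the wrapper above, reusing the illustrative
  # `Linear` layer from the earlier comment (shapes and the `training` flag
  # are arbitrary):
  #
  #   dense = Linear(4)
  #   outputs = dense(np.ones((2, 3), dtype='float32'), training=True)
  #   # The first call triggers `build`; `call` then receives the (possibly
  #   # cast) inputs together with the resolved `training` value.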
  def _assert_built_as_v1(self):
    if not hasattr(self, '_originally_built_as_v1'):
      raise ValueError(
          'Your Layer or Model is in an invalid state. '
          'This can happen for the following cases:\n '
          '1. You might be interleaving estimator/non-estimator models or '
          'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
          'with models/layers created outside of it. '
          'Converting a model to an estimator (via model_to_estimator) '
          'invalidates all models/layers made before the conversion (even '
          'if they were not the model converted to an estimator). '
          'Similarly, making a layer or a model inside a '
          'tf.compat.v1.Graph invalidates all layers/models you previously '
          'made outside of the graph.\n'
          '2. You might be using a custom keras layer implementation with '
          'a custom __init__ which didn\'t call super().__init__. '
          'Please check the implementation of %s and its bases.' %
          (type(self),))

  @property
  def dtype(self):
    return self._dtype_policy.variable_dtype

  @property
  def name(self):
    return self._name

  @property
  def dynamic(self):
    return any(layer._dynamic for layer in self._flatten_layers())

  @property
  @doc_controls.do_not_generate_docs
  def stateful(self):
    return any(layer._stateful for layer in self._flatten_layers())

  @stateful.setter
  def stateful(self, value):
    self._stateful = value

  @property
  def trainable(self):
    return self._trainable

  @trainable.setter
  def trainable(self, value):
    self._trainable = value
    for layer in getattr(self, '_self_tracked_trackables', []):
      layer.trainable = value
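  # Setting `trainable` on a layer propagates to its tracked sub-layers (see
  # the setter above). A hedged sketch, with `outer` as an illustrative layer
  # that holds a `Dense` sub-layer as an attribute:
  #
  #   outer.trainable = False
  #   # Every sub-layer assigned as an attribute is now non-trainable too, so
  #   # its weights are excluded from `outer.trainable_weights`.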
  @property
  def activity_regularizer(self):
    """Optional regularizer function for the output of this layer."""
    return self._activity_regularizer

  @activity_regularizer.setter
  def activity_regularizer(self, regularizer):
    """Optional regularizer function for the output of this layer."""
    self._activity_regularizer = regularizer

  @property
  def input_spec(self):
    return self._input_spec

  @input_spec.setter
  # Must be decorated to prevent tracking, since the input_spec can be nested
  # InputSpec objects.
  @trackable.no_automatic_dependency_tracking
  def input_spec(self, value):
    for v in nest.flatten(value):
      if v is not None and not isinstance(v, base_layer.InputSpec):
        raise TypeError('Layer input_spec must be an instance of InputSpec. '
                        'Got: {}'.format(v))
    self._input_spec = value

  @property
  def updates(self):
    collected_updates = []
    all_layers = self._flatten_layers()
    with backend.get_graph().as_default():
      for layer in all_layers:
        if not layer.trainable and not layer.stateful:
          continue
        for u in layer._updates:
          if callable(u):
            try:
              u = u()
            except ValueError as e:
              if 'InaccessibleTensorError' in type(e).__name__:
                # For one specific case of error we try to raise
                # a more meaningful error message about the graph if we can.
                # This error is an internal TF symbol that is not
                # publicly exposed, so we check the name directly rather
                # than using a direct import.
                base_layer_utils.check_graph_consistency(
                    method='add_update', force_raise=True)
              raise  # check_graph_consistency may not always raise.
          base_layer_utils.check_graph_consistency(u, method='add_update')
          collected_updates.append(u)
    return collected_updates

  @property
  def losses(self):
    """Losses which are associated with this `Layer`.

    Variable regularization tensors are created when this property is
    accessed, so it is eager-safe: accessing `losses` under a
    `tf.GradientTape` will propagate gradients back to the corresponding
    variables.

    Returns:
      A list of tensors.
    """
    collected_losses = []
    all_layers = self._flatten_layers()
    for layer in all_layers:
      # If any eager losses are present, we assume the model to be part of an
      # eager training loop (either a custom one or the one used when
      # `run_eagerly=True`) and so we always return just the eager losses.
      collected_losses.extend(layer._losses)
      for regularizer in layer._callable_losses:
        loss_tensor = regularizer()
        if loss_tensor is not None:
          collected_losses.append(loss_tensor)
    return collected_losses

  @doc_controls.for_subclass_implementers
  def add_loss(self, losses, inputs=None):
    """Add loss tensor(s), potentially dependent on layer inputs.

    Some losses (for instance, activity regularization losses) may be
    dependent on the inputs passed when calling a layer. Hence, when reusing
    the same layer on different inputs `a` and `b`, some entries in
    `layer.losses` may be dependent on `a` and some on `b`. This method
    automatically keeps track of dependencies.

    This method can be used inside a subclassed layer or model's `call`
    function, in which case `losses` should be a Tensor or list of Tensors.

    Example:

    ```python
    class MyLayer(tf.keras.layers.Layer):
      def call(self, inputs):
        self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True)
        return inputs
    ```

    This method can also be called directly on a Functional Model during
    construction. In this case, any loss Tensors passed to this Model must
    be symbolic and be able to be traced back to the model's `Input`s. These
    losses become part of the model's topology and are tracked in
    `get_config`.

    Example:

    ```python
    inputs = tf.keras.Input(shape=(10,))
    x = tf.keras.layers.Dense(10)(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)
    # Activity regularization.
    model.add_loss(tf.abs(tf.reduce_mean(x)))
    ```

    If this is not the case for your loss (if, for example, your loss
    references a `Variable` of one of the model's layers), you can wrap your
    loss in a zero-argument lambda. These losses are not tracked as part of
    the model's topology since they can't be serialized.

    Example:

    ```python
    inputs = tf.keras.Input(shape=(10,))
    x = tf.keras.layers.Dense(10)(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)
    # Weight regularization.
    model.add_loss(lambda: tf.reduce_mean(x.kernel))
    ```

    The `get_losses_for` method allows retrieving the losses relevant to a
    specific set of inputs.

    Args:
      losses: Loss tensor, or list/tuple of tensors. Rather than tensors,
        losses may also be zero-argument callables which create a loss
        tensor.
      inputs: Ignored when executing eagerly. If anything other than None is
        passed, it signals the losses are conditional on some of the layer's
        inputs, and thus they should only be run where these inputs are
        available. This is the case for activity regularization losses, for
        instance. If `None` is passed, the losses are assumed
        to be unconditional, and will apply across all dataflows of the layer
        (e.g. weight regularization losses).
    """

    def _tag_unconditional(loss):
      """Process the loss and tag it by setting loss._unconditional_loss."""
      if callable(loss):
        # We run the loss without autocasting, as regularizers are often
        # numerically unstable in float16.
        with autocast_variable.enable_auto_cast_variables(None):
          loss = loss()
      if loss is None:
        return None  # Will be filtered out when computing the .losses property
      if not tensor_util.is_tf_type(loss):
        loss = ops.convert_to_tensor_v2_with_dispatch(
            loss, dtype=backend.floatx())
      loss._unconditional_loss = (inputs is None)  # pylint: disable=protected-access
      return loss

    losses = nest.flatten(losses)

    callable_losses = []
    symbolic_losses = []
    for loss in losses:
      if callable(loss):
        callable_losses.append(functools.partial(_tag_unconditional, loss))
        continue
      if loss is None:
        continue
      if not tensor_util.is_tf_type(loss):
        loss = ops.convert_to_tensor_v2_with_dispatch(
            loss, dtype=backend.floatx())
      # TF Functions should take the eager path.
      if (tf_utils.is_symbolic_tensor(loss) and
          not base_layer_utils.is_in_tf_function()):
        symbolic_losses.append(_tag_unconditional(loss))
        base_layer_utils.check_graph_consistency(loss, method='add_loss')

    self._callable_losses.extend(callable_losses)

    in_call_context = base_layer_utils.call_context().in_call

    if in_call_context:
      for symbolic_loss in symbolic_losses:
        self._losses.append(symbolic_loss)
    else:
      for symbolic_loss in symbolic_losses:
        if getattr(self, '_is_graph_network', False):
          self._graph_network_add_loss(symbolic_loss)
        else:
          # Possibly a loss was added in a Layer's `build`.
          self._losses.append(symbolic_loss)

  @property
  def metrics(self):
    collected_metrics = []
    for layer in self._flatten_layers():
      collected_metrics.extend(layer._metrics)
    return collected_metrics

  @doc_controls.for_subclass_implementers
  def add_metric(self, value, aggregation=None, name=None):
    """Adds metric tensor to the layer.

    Args:
      value: Metric tensor.
      aggregation: Sample-wise metric reduction function. If
        `aggregation=None`, it indicates that the metric tensor provided has
        been aggregated already. E.g., `bin_acc = BinaryAccuracy(name='acc')`
        followed by `model.add_metric(bin_acc(y_true, y_pred))`. If
        `aggregation='mean'`, the given metric tensor will be sample-wise
        reduced using the `mean` function. E.g.,
        `model.add_metric(tf.reduce_sum(outputs), name='output_mean',
        aggregation='mean')`.
      name: String metric name.

    Raises:
      ValueError: If `aggregation` is anything other than None or `mean`.
    """
    if aggregation is not None and aggregation != 'mean':
      raise ValueError(
          'We currently support only `mean` sample-wise metric aggregation. '
          'You provided aggregation=`%s`' % aggregation)

    from_metric_obj = hasattr(value, '_metric_obj')
    is_symbolic = tf_utils.is_symbolic_tensor(value)
    in_call_context = base_layer_utils.call_context().in_call

    if name is None and not from_metric_obj:
      # E.g. `self.add_metric(math_ops.reduce_sum(x), aggregation='mean')`.
      # In eager mode, we use the metric name to look up a metric. Without a
      # name, a new Mean metric wrapper will be created on every model/layer
      # call. So, we raise an error when no name is provided.
      # We will do the same for symbolic mode for consistency although a name
      # will be generated if no name is provided.

      # We will not raise this error in the following use case for the sake of
      # consistency, as the name is provided in the metric constructor:
      #   mean = metrics.Mean(name='my_metric')
      #   model.add_metric(mean(outputs))
      raise ValueError('Please provide a name for your metric like '
                       '`self.add_metric(tf.reduce_sum(inputs), '
                       'name=\'mean_activation\', aggregation=\'mean\')`')
    elif from_metric_obj:
      name = value._metric_obj.name

    if in_call_context:
      # TF Function path should take the eager path.
      self._symbolic_add_metric(value, aggregation, name)
    else:
      if not is_symbolic:
        raise ValueError('Expected a symbolic Tensor for the metric value, '
                         'received: ' + str(value))

      # Possibly a metric was added in a Layer's `build`.
      if not getattr(self, '_is_graph_network', False):
        with backend.get_graph().as_default():
          self._symbolic_add_metric(value, aggregation, name)
        return

      if from_metric_obj:
        raise ValueError('Using the result of calling a `Metric` object '
                         'when calling `add_metric` on a Functional '
                         'Model is not supported. Please pass the '
                         'Tensor to monitor directly.')

      # Insert layers into the Keras Graph Network.
      self._graph_network_add_metric(value, aggregation, name)
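  # A sketch of `add_metric` inside a subclassed layer's `call`; the metric
  # name is illustrative, and `aggregation='mean'` requests sample-wise
  # averaging of the provided tensor:
  #
  #   def call(self, inputs):
  #     self.add_metric(math_ops.reduce_sum(inputs),
  #                     name='activation_sum', aggregation='mean')
  #     return inputs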
  @doc_controls.for_subclass_implementers
  def add_update(self, updates, inputs=None):
    """Add update op(s), potentially dependent on layer inputs.

    Weight updates (for instance, the updates of the moving mean and variance
    in a BatchNormalization layer) may be dependent on the inputs passed
    when calling a layer. Hence, when reusing the same layer on
    different inputs `a` and `b`, some entries in `layer.updates` may be
    dependent on `a` and some on `b`. This method automatically keeps track
    of dependencies.

    The `get_updates_for` method allows retrieving the updates relevant to a
    specific set of inputs.

    This call is ignored when eager execution is enabled (in that case,
    variable updates are run on the fly and thus do not need to be tracked
    for later execution).

    Args:
      updates: Update op, or list/tuple of update ops, or zero-arg callable
        that returns an update op. A zero-arg callable should be passed in
        order to disable running the updates by setting `trainable=False`
        on this Layer, when executing in Eager mode.
      inputs: Deprecated, will be automatically inferred.
    """
    if inputs is not None:
      tf_logging.warning(
          '`add_update` `inputs` kwarg has been deprecated. You no longer need '
          'to pass a value to `inputs` as it is being automatically inferred.')
    call_context = base_layer_utils.call_context()

    if (ds_context.has_strategy() and
        ds_context.in_cross_replica_context() and
        # When saving the model, the distribution strategy context should be
        # ignored, following the default path for adding updates.
        not call_context.saving):
      # Updates don't need to be run in a cross-replica context.
      return

    updates = generic_utils.to_list(updates)

    if call_context.in_call:
      relevant_inputs = call_context.inputs
    else:
      inbound_nodes = getattr(self, '_inbound_nodes', [])
      relevant_inputs = [node.input_tensors for node in inbound_nodes]

    def process_update(x):
      """Standardize update ops.

      Args:
        x: Tensor, op, or callable.

      Returns:
        An update op.
      """
      if callable(x):
        update = lambda: process_update(x())
        return update()
      elif isinstance(x, ops.Operation):
        update = x
      elif hasattr(x, 'op'):
        update = x.op
      else:
        update = ops.convert_to_tensor_v2_with_dispatch(x)

      reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, [update])
      update._unconditional_update = update not in reachable
      return update

    updates = [process_update(x) for x in updates]
    self._updates.extend(updates)
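  # A sketch of `add_update` for a moving-average style state update inside
  # `call` (`self.moving_mean` and the 0.9 decay are illustrative). Passing a
  # zero-arg callable lets the update be skipped for layers frozen with
  # `trainable=False` when executing eagerly:
  #
  #   def call(self, inputs):
  #     new_mean = 0.9 * self.moving_mean + 0.1 * math_ops.reduce_mean(inputs)
  #     self.add_update(lambda: self.moving_mean.assign(new_mean))
  #     return inputs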
  def set_weights(self, weights):
    """Sets the weights of the layer, from Numpy arrays.

    The weights of a layer represent the state of the layer. This function
    sets the weight values from numpy arrays. The weight values should be
    passed in the order they are created by the layer. Note that the layer's
    weights must be instantiated before calling this function by calling
    the layer.

    For example, a Dense layer returns a list of two values-- per-output
    weights and the bias value. These can be used to set the weights of
    another Dense layer:

    >>> a = tf.keras.layers.Dense(1,
    ...   kernel_initializer=tf.constant_initializer(1.))
    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
    >>> a.get_weights()
    [array([[1.],
           [1.],
           [1.]], dtype=float32), array([0.], dtype=float32)]
    >>> b = tf.keras.layers.Dense(1,
    ...   kernel_initializer=tf.constant_initializer(2.))
    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
    >>> b.get_weights()
    [array([[2.],
           [2.],
           [2.]], dtype=float32), array([0.], dtype=float32)]
    >>> b.set_weights(a.get_weights())
    >>> b.get_weights()
    [array([[1.],
           [1.],
           [1.]], dtype=float32), array([0.], dtype=float32)]

    Args:
      weights: a list of Numpy arrays. The number
        of arrays and their shape must match
        number of the dimensions of the weights
        of the layer (i.e. it should match the
        output of `get_weights`).

    Raises:
      ValueError: If the provided weights list does not match the
        layer's specifications.
    """
    params = self.weights

    expected_num_weights = 0
    for param in params:
      if isinstance(param, base_layer_utils.TrackableWeightHandler):
        expected_num_weights += param.num_tensors
      else:
        expected_num_weights += 1

    if expected_num_weights != len(weights):
      raise ValueError(
          'You called `set_weights(weights)` on layer "%s" '
          'with a weight list of length %s, but the layer was '
          'expecting %s weights. Provided weights: %s...' %
          (self.name, len(weights), expected_num_weights, str(weights)[:50]))

    weight_index = 0
    weight_value_tuples = []
    for param in params:
      if isinstance(param, base_layer_utils.TrackableWeightHandler):
        num_tensors = param.num_tensors
        tensors = weights[weight_index:weight_index + num_tensors]
        param.set_weights(tensors)
        weight_index += num_tensors
      else:
        weight = weights[weight_index]
        weight_shape = weight.shape if hasattr(weight, 'shape') else ()
        ref_shape = param.shape
        if not ref_shape.is_compatible_with(weight_shape):
          raise ValueError(
              'Layer weight shape %s not compatible with provided weight '
              'shape %s' % (ref_shape, weight_shape))
        weight_value_tuples.append((param, weight))
        weight_index += 1

    backend.batch_set_value(weight_value_tuples)

  def get_weights(self):
    """Returns the current weights of the layer.

    The weights of a layer represent the state of the layer. This function
    returns both trainable and non-trainable weight values associated with
    this layer as a list of Numpy arrays, which can in turn be used to load
    state into similarly parameterized layers.

    For example, a Dense layer returns a list of two values-- per-output
    weights and the bias value. These can be used to set the weights of
    another Dense layer:

    >>> a = tf.keras.layers.Dense(1,
    ...   kernel_initializer=tf.constant_initializer(1.))
    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
    >>> a.get_weights()
    [array([[1.],
           [1.],
           [1.]], dtype=float32), array([0.], dtype=float32)]
    >>> b = tf.keras.layers.Dense(1,
    ...   kernel_initializer=tf.constant_initializer(2.))
    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
    >>> b.get_weights()
    [array([[2.],
           [2.],
           [2.]], dtype=float32), array([0.], dtype=float32)]
    >>> b.set_weights(a.get_weights())
    >>> b.get_weights()
    [array([[1.],
           [1.],
           [1.]], dtype=float32), array([0.], dtype=float32)]

    Returns:
      Weights values as a list of numpy arrays.
    """
    weights = self.weights
    output_weights = []
    for weight in weights:
      if isinstance(weight, base_layer_utils.TrackableWeightHandler):
        output_weights.extend(weight.get_tensors())
      else:
        output_weights.append(weight)
    return backend.batch_get_value(output_weights)

  def get_updates_for(self, inputs):
    """Retrieves updates relevant to a specific set of inputs.

    Args:
      inputs: Input tensor or list/tuple of input tensors.

    Returns:
      List of update ops of the layer that depend on `inputs`.
    """
    if inputs is None:
      # Requesting unconditional updates.
      return [u for u in self.updates if u._unconditional_update]

    # Requesting input-conditional updates.
    updates = [u for u in self.updates if not u._unconditional_update]
    inputs = nest.flatten(inputs)
    reachable = tf_utils.get_reachable_from_inputs(inputs, updates)
    return [u for u in updates if u in reachable]

  def get_losses_for(self, inputs):
    """Retrieves losses relevant to a specific set of inputs.

    Args:
      inputs: Input tensor or list/tuple of input tensors.

    Returns:
      List of loss tensors of the layer that depend on `inputs`.
    """
    if inputs is None:
      # Requesting unconditional losses.
      return [l for l in self.losses if l._unconditional_loss]

    # Requesting input-conditional losses.
    losses = [l for l in self.losses if not l._unconditional_loss]
    inputs = nest.flatten(inputs)
    reachable = tf_utils.get_reachable_from_inputs(inputs, losses)
    return [l for l in losses if l in reachable]
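  # A sketch of how these two getters separate conditional from unconditional
  # terms (the `model` object is illustrative):
  #
  #   unconditional_losses = model.get_losses_for(None)
  #   conditional_losses = model.get_losses_for(model.inputs)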
  def get_input_mask_at(self, node_index):
    """Retrieves the input mask tensor(s) of a layer at a given node.

    Args:
      node_index: Integer, index of the node
        from which to retrieve the attribute.
        E.g. `node_index=0` will correspond to the
        first time the layer was called.

    Returns:
      A mask tensor
      (or list of tensors if the layer has multiple inputs).
    """
    inputs = self.get_input_at(node_index)
    if isinstance(inputs, list):
      return [getattr(x, '_keras_mask', None) for x in inputs]
    else:
      return getattr(inputs, '_keras_mask', None)

  def get_output_mask_at(self, node_index):
    """Retrieves the output mask tensor(s) of a layer at a given node.

    Args:
      node_index: Integer, index of the node
        from which to retrieve the attribute.
        E.g. `node_index=0` will correspond to the
        first time the layer was called.

    Returns:
      A mask tensor
      (or list of tensors if the layer has multiple outputs).
    """
    output = self.get_output_at(node_index)
    if isinstance(output, list):
      return [getattr(x, '_keras_mask', None) for x in output]
    else:
      return getattr(output, '_keras_mask', None)

  @property
  def input_mask(self):
    """Retrieves the input mask tensor(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
      Input mask tensor (potentially None) or list of input
      mask tensors.

    Raises:
      AttributeError: if the layer is connected to
        more than one incoming layer.
    """
    inputs = self.input
    if isinstance(inputs, list):
      return [getattr(x, '_keras_mask', None) for x in inputs]
    else:
      return getattr(inputs, '_keras_mask', None)

  @property
  def output_mask(self):
    """Retrieves the output mask tensor(s) of a layer.

    Only applicable if the layer has exactly one inbound node,
    i.e. if it is connected to one incoming layer.

    Returns:
      Output mask tensor (potentially None) or list of output
      mask tensors.

    Raises:
      AttributeError: if the layer is connected to
        more than one incoming layer.
    """
    output = self.output
    if isinstance(output, list):
      return [getattr(x, '_keras_mask', None) for x in output]
    else:
      return getattr(output, '_keras_mask', None)

  def get_input_shape_at(self, node_index):
    """Retrieves the input shape(s) of a layer at a given node.

    Args:
      node_index: Integer, index of the node
        from which to retrieve the attribute.
        E.g. `node_index=0` will correspond to the
        first time the layer was called.

    Returns:
      A shape tuple
      (or list of shape tuples if the layer has multiple inputs).

    Raises:
      RuntimeError: If called in Eager mode.
    """
    return self._get_node_attribute_at_index(node_index, 'input_shapes',
                                             'input shape')

  def get_output_shape_at(self, node_index):
    """Retrieves the output shape(s) of a layer at a given node.
1489 1490 Args: 1491 node_index: Integer, index of the node 1492 from which to retrieve the attribute. 1493 E.g. `node_index=0` will correspond to the 1494 first time the layer was called. 1495 1496 Returns: 1497 A shape tuple 1498 (or list of shape tuples if the layer has multiple outputs). 1499 1500 Raises: 1501 RuntimeError: If called in Eager mode. 1502 """ 1503 return self._get_node_attribute_at_index(node_index, 'output_shapes', 1504 'output shape') 1505 1506 def get_input_at(self, node_index): 1507 """Retrieves the input tensor(s) of a layer at a given node. 1508 1509 Args: 1510 node_index: Integer, index of the node 1511 from which to retrieve the attribute. 1512 E.g. `node_index=0` will correspond to the 1513 first input node of the layer. 1514 1515 Returns: 1516 A tensor (or list of tensors if the layer has multiple inputs). 1517 1518 Raises: 1519 RuntimeError: If called in Eager mode. 1520 """ 1521 return self._get_node_attribute_at_index(node_index, 'input_tensors', 1522 'input') 1523 1524 def get_output_at(self, node_index): 1525 """Retrieves the output tensor(s) of a layer at a given node. 1526 1527 Args: 1528 node_index: Integer, index of the node 1529 from which to retrieve the attribute. 1530 E.g. `node_index=0` will correspond to the 1531 first output node of the layer. 1532 1533 Returns: 1534 A tensor (or list of tensors if the layer has multiple outputs). 1535 1536 Raises: 1537 RuntimeError: If called in Eager mode. 1538 """ 1539 return self._get_node_attribute_at_index(node_index, 'output_tensors', 1540 'output') 1541 1542 @property 1543 def input(self): 1544 """Retrieves the input tensor(s) of a layer. 1545 1546 Only applicable if the layer has exactly one input, 1547 i.e. if it is connected to one incoming layer. 1548 1549 Returns: 1550 Input tensor or list of input tensors. 1551 1552 Raises: 1553 RuntimeError: If called in Eager mode. 1554 AttributeError: If no inbound nodes are found. 1555 """ 1556 if not self._inbound_nodes: 1557 raise AttributeError('Layer ' + self.name + 1558 ' is not connected, no input to return.') 1559 return self._get_node_attribute_at_index(0, 'input_tensors', 'input') 1560 1561 @property 1562 def output(self): 1563 """Retrieves the output tensor(s) of a layer. 1564 1565 Only applicable if the layer has exactly one output, 1566 i.e. if it is connected to one incoming layer. 1567 1568 Returns: 1569 Output tensor or list of output tensors. 1570 1571 Raises: 1572 AttributeError: if the layer is connected to more than one incoming 1573 layers. 1574 RuntimeError: if called in Eager mode. 1575 """ 1576 if not self._inbound_nodes: 1577 raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') 1578 return self._get_node_attribute_at_index(0, 'output_tensors', 'output') 1579 1580 @property 1581 def input_shape(self): 1582 """Retrieves the input shape(s) of a layer. 1583 1584 Only applicable if the layer has exactly one input, 1585 i.e. if it is connected to one incoming layer, or if all inputs 1586 have the same shape. 1587 1588 Returns: 1589 Input shape, as an integer shape tuple 1590 (or list of shape tuples, one tuple per input tensor). 1591 1592 Raises: 1593 AttributeError: if the layer has no defined input_shape. 1594 RuntimeError: if called in Eager mode. 
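
    For example (an illustrative sketch using the functional API, where
    inbound nodes are created):

    >>> inputs = tf.keras.Input(shape=(8,))
    >>> layer = tf.keras.layers.Dense(2)
    >>> outputs = layer(inputs)
    >>> shape = layer.input_shape  # (None, 8)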
1595 """ 1596 if not self._inbound_nodes: 1597 raise AttributeError('The layer has never been called ' 1598 'and thus has no defined input shape.') 1599 all_input_shapes = set( 1600 [str(node.input_shapes) for node in self._inbound_nodes]) 1601 if len(all_input_shapes) == 1: 1602 return self._inbound_nodes[0].input_shapes 1603 else: 1604 raise AttributeError('The layer "' + str(self.name) + 1605 ' has multiple inbound nodes, ' 1606 'with different input shapes. Hence ' 1607 'the notion of "input shape" is ' 1608 'ill-defined for the layer. ' 1609 'Use `get_input_shape_at(node_index)` ' 1610 'instead.') 1611 1612 def count_params(self): 1613 """Count the total number of scalars composing the weights. 1614 1615 Returns: 1616 An integer count. 1617 1618 Raises: 1619 ValueError: if the layer isn't yet built 1620 (in which case its weights aren't yet defined). 1621 """ 1622 if not self.built: 1623 if getattr(self, '_is_graph_network', False): 1624 with tf_utils.maybe_init_scope(self): 1625 self._maybe_build(self.inputs) 1626 else: 1627 raise ValueError('You tried to call `count_params` on ' + self.name + 1628 ', but the layer isn\'t built. ' 1629 'You can build it manually via: `' + self.name + 1630 '.build(batch_input_shape)`.') 1631 return layer_utils.count_params(self.weights) 1632 1633 @property 1634 def output_shape(self): 1635 """Retrieves the output shape(s) of a layer. 1636 1637 Only applicable if the layer has one output, 1638 or if all outputs have the same shape. 1639 1640 Returns: 1641 Output shape, as an integer shape tuple 1642 (or list of shape tuples, one tuple per output tensor). 1643 1644 Raises: 1645 AttributeError: if the layer has no defined output shape. 1646 RuntimeError: if called in Eager mode. 1647 """ 1648 if not self._inbound_nodes: 1649 raise AttributeError('The layer has never been called ' 1650 'and thus has no defined output shape.') 1651 all_output_shapes = set( 1652 [str(node.output_shapes) for node in self._inbound_nodes]) 1653 if len(all_output_shapes) == 1: 1654 return self._inbound_nodes[0].output_shapes 1655 else: 1656 raise AttributeError('The layer "%s"' 1657 ' has multiple inbound nodes, ' 1658 'with different output shapes. Hence ' 1659 'the notion of "output shape" is ' 1660 'ill-defined for the layer. ' 1661 'Use `get_output_shape_at(node_index)` ' 1662 'instead.' % self.name) 1663 1664 @property 1665 @doc_controls.do_not_doc_inheritable 1666 def inbound_nodes(self): 1667 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 1668 return self._inbound_nodes 1669 1670 @property 1671 @doc_controls.do_not_doc_inheritable 1672 def outbound_nodes(self): 1673 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 1674 return self._outbound_nodes 1675 1676 ############################################################################## 1677 # Methods & attributes below are public aliases of other methods. # 1678 ############################################################################## 1679 1680 @doc_controls.do_not_doc_inheritable 1681 def apply(self, inputs, *args, **kwargs): 1682 """Deprecated, do NOT use! 1683 1684 This is an alias of `self.__call__`. 1685 1686 Args: 1687 inputs: Input tensor(s). 1688 *args: additional positional arguments to be passed to `self.call`. 1689 **kwargs: additional keyword arguments to be passed to `self.call`. 1690 1691 Returns: 1692 Output tensor(s). 1693 """ 1694 warnings.warn('`layer.apply` is deprecated and ' 1695 'will be removed in a future version. 
' 1696 'Please use `layer.__call__` method instead.') 1697 return self.__call__(inputs, *args, **kwargs) 1698 1699 @doc_controls.do_not_doc_inheritable 1700 def add_variable(self, *args, **kwargs): 1701 """Deprecated, do NOT use! Alias for `add_weight`.""" 1702 warnings.warn('`layer.add_variable` is deprecated and ' 1703 'will be removed in a future version. ' 1704 'Please use `layer.add_weight` method instead.') 1705 return self.add_weight(*args, **kwargs) 1706 1707 @property 1708 def variables(self): 1709 """Returns the list of all layer variables/weights. 1710 1711 Alias of `self.weights`. 1712 1713 Returns: 1714 A list of variables. 1715 """ 1716 return self.weights 1717 1718 @property 1719 def trainable_variables(self): 1720 return self.trainable_weights 1721 1722 @property 1723 def non_trainable_variables(self): 1724 return self.non_trainable_weights 1725 1726 ############################################################################## 1727 # Methods & attributes below are all private and only used by the framework. # 1728 ############################################################################## 1729 1730 @property 1731 def _inbound_nodes(self): 1732 return self._inbound_nodes_value 1733 1734 @_inbound_nodes.setter 1735 @trackable.no_automatic_dependency_tracking 1736 def _inbound_nodes(self, value): 1737 self._inbound_nodes_value = value 1738 1739 @property 1740 def _outbound_nodes(self): 1741 return self._outbound_nodes_value 1742 1743 @_outbound_nodes.setter 1744 @trackable.no_automatic_dependency_tracking 1745 def _outbound_nodes(self, value): 1746 self._outbound_nodes_value = value 1747 1748 def _set_dtype_policy(self, dtype): 1749 """Sets self._dtype_policy.""" 1750 if isinstance(dtype, policy.Policy): 1751 self._dtype_policy = dtype 1752 elif isinstance(dtype, dict): 1753 self._dtype_policy = policy.deserialize(dtype) 1754 elif isinstance(dtype, str) and dtype in ('mixed_float16', 1755 'mixed_bfloat16'): 1756 # The isinstance check is required since np.dtype raises an error if 1757 # compared to a non-dtype string. 1758 self._dtype_policy = policy.Policy(dtype) 1759 elif dtype: 1760 self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name) 1761 else: 1762 self._dtype_policy = policy.global_policy() 1763 if (self._dtype_policy.name == 'mixed_float16' and 1764 not loss_scale_optimizer.strategy_supports_loss_scaling()): 1765 # Although only loss scaling doesn't support certain strategies, to avoid 1766 # confusion, we disallow the 'mixed_float16' policy with unsupported 1767 # strategies. This is because 'mixed_float16' requires loss scaling for 1768 # numeric stability. 1769 strategy = ds_context.get_strategy() 1770 raise ValueError('Mixed precision is not supported with the ' 1771 'tf.distribute.Strategy: %s. Either stop using mixed ' 1772 'precision by removing the use of the "%s" policy or ' 1773 'use a different Strategy, e.g. a MirroredStrategy.' % 1774 (strategy.__class__.__name__, self._dtype_policy.name)) 1775 1776 # Performance optimization: cache the compute dtype as a Dtype object or 1777 # None, so that str to Dtype conversion doesn't happen in Layer.__call__. 1778 if self._dtype_policy.compute_dtype: 1779 self._compute_dtype_object = dtypes.as_dtype( 1780 self._dtype_policy.compute_dtype) 1781 else: 1782 self._compute_dtype_object = None 1783 1784 # TODO(reedwm): Expose this property? 1785 @property 1786 def _compute_dtype(self): 1787 """The layer's compute dtype. 1788 1789 Unless mixed-precision is used, this is the same as `Layer.dtype`. 
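
    For example (an illustrative sketch; the variable dtype is reported by
    `Layer.dtype`, while this property reports the dtype used for
    computations):

    >>> layer = tf.keras.layers.Dense(4, dtype='mixed_float16')
    >>> layer.dtype
    'float32'
    >>> layer._compute_dtype
    'float16'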

    If `self._autocast` is True, layers will cast floating-point inputs to
    this dtype.

    Returns:
      The layer's compute dtype.
    """
    return self._dtype_policy.compute_dtype

  def _maybe_cast_inputs(self, inputs):
    """Maybe casts the inputs to the compute dtype.

    If `self._compute_dtype` is floating-point and `self._autocast` is True,
    floating-point inputs are cast to `self._compute_dtype`.

    Args:
      inputs: Input tensor, or structure of input tensors.

    Returns:
      `inputs`, but tensors may have been cast to `self._compute_dtype`.
    """
    compute_dtype = self._compute_dtype
    if (self._autocast and compute_dtype and
        dtypes.as_dtype(compute_dtype).is_floating):
      def f(x):
        """Cast a single Tensor or TensorSpec to the compute dtype."""
        cast_types = (ops.Tensor, sparse_tensor.SparseTensor,
                      ragged_tensor.RaggedTensor)
        if (isinstance(x, cast_types) and x.dtype.is_floating and
            x.dtype.base_dtype.name != compute_dtype):
          return math_ops.cast(x, compute_dtype)
        elif isinstance(x, tensor_spec.TensorSpec) and x.dtype.is_floating:
          # Inputs may be TensorSpecs when this function is called from
          # model._set_inputs.
          return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name)
        else:
          return x
      return nest.map_structure(f, inputs)
    else:
      return inputs

  # _dtype used to be an attribute set in the constructor. We still expose it
  # because some clients still use it.
  # TODO(reedwm): Deprecate, then remove the _dtype property.
  @property
  def _dtype(self):
    # This is equivalent to returning self.dtype. We do not return self.dtype
    # as it would cause infinite recursion in a few subclasses, which override
    # "dtype" to return self._dtype.
    return self._dtype_policy.variable_dtype

  @_dtype.setter
  def _dtype(self, value):
    value = dtypes.as_dtype(value).name
    self._set_dtype_policy(policy.Policy(value))

  def _name_scope(self):  # pylint: disable=method-hidden
    return self.name

  def _init_set_name(self, name, zero_based=True):
    if not name:
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(self.__class__.__name__),
          zero_based=zero_based)
    else:
      self._name = name

  def _get_existing_metric(self, name=None):
    match = [m for m in self._metrics if m.name == name]
    if not match:
      return
    if len(match) > 1:
      raise ValueError(
          'Please provide different names for the metrics you have added. '
          'We found {} metrics with the name: "{}"'.format(len(match), name))
    return match[0]

  def _symbolic_add_metric(self, value, aggregation=None, name=None):
    base_layer_utils.check_graph_consistency(value, method='add_metric')
    match = self._get_existing_metric(name)
    if aggregation is None:
      # Iterate over the metrics and check if the given metric exists already.
      # This can happen when a metric instance is created in a subclassed
      # model/layer `__init__` and we have tracked that instance already in
      # model.__setattr__.
      if match:
        result_tensor = value
        metric_obj = match
      elif hasattr(value, '_metric_obj'):
        # We track the instance using the metadata on the result tensor.
1879 result_tensor = value 1880 metric_obj = result_tensor._metric_obj 1881 self._metrics.append(metric_obj) 1882 else: 1883 raise ValueError( 1884 'We do not support adding an aggregated metric result tensor that ' 1885 'is not the output of a `tf.keras.metrics.Metric` metric instance. ' 1886 'Without having access to the metric instance we cannot reset the ' 1887 'state of a metric after every epoch during training. You can ' 1888 'create a `tf.keras.metrics.Metric` instance and pass the result ' 1889 'here or pass an un-aggregated result with `aggregation` parameter ' 1890 'set as `mean`. For example: `self.add_metric(tf.reduce_sum(inputs)' 1891 ', name=\'mean_activation\', aggregation=\'mean\')`') 1892 else: 1893 # If a non-aggregated tensor is given as input (ie. `aggregation` is 1894 # explicitly set to `mean`), we wrap the tensor in `Mean` metric. 1895 if match: 1896 result_tensor = match(value) 1897 metric_obj = match 1898 else: 1899 metric_obj, result_tensor = base_layer_utils.create_mean_metric( 1900 value, name) 1901 self._metrics.append(metric_obj) 1902 1903 def _handle_weight_regularization(self, name, variable, regularizer): 1904 """Create lambdas which compute regularization losses.""" 1905 1906 def _loss_for_variable(v): 1907 """Creates a regularization loss `Tensor` for variable `v`.""" 1908 with backend.name_scope(name + '/Regularizer'): 1909 regularization = regularizer(v) 1910 return regularization 1911 1912 if base_layer_utils.is_split_variable(variable): 1913 for v in variable: 1914 self.add_loss(functools.partial(_loss_for_variable, v)) 1915 else: 1916 self.add_loss(functools.partial(_loss_for_variable, variable)) 1917 1918 def _handle_activity_regularization(self, inputs, outputs): 1919 # Apply activity regularization. 1920 # Note that it should be applied every time the layer creates a new 1921 # output, since it is output-specific. 1922 if self._activity_regularizer: 1923 output_list = nest.flatten(outputs) 1924 with backend.name_scope('ActivityRegularizer'): 1925 for output in output_list: 1926 activity_loss = self._activity_regularizer(output) 1927 batch_size = math_ops.cast( 1928 array_ops.shape(output)[0], activity_loss.dtype) 1929 # Make activity regularization strength batch-agnostic. 1930 mean_activity_loss = activity_loss / batch_size 1931 base_layer_utils.check_graph_consistency( 1932 mean_activity_loss, method='activity_regularizer') 1933 self.add_loss(mean_activity_loss, inputs=inputs) 1934 1935 def _set_mask_metadata(self, inputs, outputs, previous_mask): 1936 flat_outputs = nest.flatten(outputs) 1937 1938 mask_already_computed = ( 1939 getattr(self, '_compute_output_and_mask_jointly', False) or 1940 all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs)) 1941 1942 # Only compute the mask if the Layer explicitly supports masking or has 1943 # overridden `compute_mask`. 1944 should_compute_mask = ( 1945 hasattr(self, 'compute_mask') and 1946 (self.supports_masking or 1947 not getattr(self.compute_mask, '_is_default', False))) 1948 1949 if mask_already_computed: 1950 flat_masks = [getattr(x, '_keras_mask', None) for x in flat_outputs] 1951 elif not should_compute_mask: 1952 flat_masks = [None for _ in flat_outputs] 1953 else: 1954 output_masks = self.compute_mask(inputs, previous_mask) 1955 # `compute_mask` can return a single `None` even when a Layer 1956 # has multiple outputs. 
1957 if output_masks is None: 1958 flat_masks = [None for _ in flat_outputs] 1959 else: 1960 flat_masks = nest.flatten(output_masks) 1961 1962 for output, mask in zip(flat_outputs, flat_masks): 1963 try: 1964 output._keras_mask = mask 1965 except AttributeError: 1966 # C Type such as np.ndarray. 1967 pass 1968 1969 if tf_utils.are_all_symbolic_tensors(flat_outputs): 1970 for output in flat_outputs: 1971 if getattr(output, '_keras_mask', None) is not None: 1972 # Do not track masks for `TensorFlowOpLayer` construction. 1973 output._keras_mask._keras_history_checked = True 1974 1975 def _collect_input_masks(self, inputs, args, kwargs): 1976 """Checks if `mask` argument was passed, else gathers mask from inputs.""" 1977 if self._call_arg_was_passed('mask', args, kwargs): 1978 return self._get_call_arg_value('mask', args, kwargs) 1979 1980 if not self._should_compute_mask: 1981 return None 1982 1983 input_masks = nest.map_structure(lambda t: getattr(t, '_keras_mask', None), 1984 inputs) 1985 if generic_utils.is_all_none(input_masks): 1986 return None 1987 return input_masks 1988 1989 def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): 1990 if arg_name in kwargs: 1991 return True 1992 call_fn_args = self._call_fn_args 1993 if not inputs_in_args: 1994 # Ignore `inputs` arg. 1995 call_fn_args = call_fn_args[1:] 1996 if arg_name in dict(zip(call_fn_args, args)): 1997 return True 1998 return False 1999 2000 def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False): 2001 if arg_name in kwargs: 2002 return kwargs[arg_name] 2003 call_fn_args = self._call_fn_args 2004 if not inputs_in_args: 2005 # Ignore `inputs` arg. 2006 call_fn_args = call_fn_args[1:] 2007 args_dict = dict(zip(call_fn_args, args)) 2008 return args_dict[arg_name] 2009 2010 def _set_call_arg_value( 2011 self, arg_name, new_value, args, 2012 kwargs, inputs_in_args=False, pop_kwarg_if_none=False): 2013 arg_pos = self._call_fn_arg_positions.get(arg_name, None) 2014 if arg_pos is not None: 2015 if not inputs_in_args: 2016 # Ignore `inputs` arg. 2017 arg_pos = arg_pos - 1 2018 if len(args) > arg_pos: 2019 args = list(args) 2020 args[arg_pos] = new_value 2021 return args, kwargs 2022 if new_value is None and pop_kwarg_if_none: 2023 kwargs.pop(arg_name, None) 2024 else: 2025 kwargs[arg_name] = new_value 2026 return args, kwargs 2027 2028 def _get_node_attribute_at_index(self, node_index, attr, attr_name): 2029 """Private utility to retrieves an attribute (e.g. inputs) from a node. 2030 2031 This is used to implement the methods: 2032 - get_input_shape_at 2033 - get_output_shape_at 2034 - get_input_at 2035 etc... 2036 2037 Args: 2038 node_index: Integer index of the node from which 2039 to retrieve the attribute. 2040 attr: Exact node attribute name. 2041 attr_name: Human-readable attribute name, for error messages. 2042 2043 Returns: 2044 The layer's attribute `attr` at the node of index `node_index`. 2045 2046 Raises: 2047 RuntimeError: If the layer has no inbound nodes, or if called in Eager 2048 mode. 2049 ValueError: If the index provided does not match any node. 
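
    For example, `layer.get_input_at(1)` is implemented above as
    `self._get_node_attribute_at_index(1, 'input_tensors', 'input')`.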
2050 """ 2051 if not self._inbound_nodes: 2052 raise RuntimeError('The layer has never been called ' 2053 'and thus has no defined ' + attr_name + '.') 2054 if not len(self._inbound_nodes) > node_index: 2055 raise ValueError('Asked to get ' + attr_name + ' at node ' + 2056 str(node_index) + ', but the layer has only ' + 2057 str(len(self._inbound_nodes)) + ' inbound nodes.') 2058 values = getattr(self._inbound_nodes[node_index], attr) 2059 if isinstance(values, list) and len(values) == 1: 2060 return values[0] 2061 else: 2062 return values 2063 2064 def _maybe_build(self, inputs): 2065 # Check input assumptions set before layer building, e.g. input rank. 2066 if not self.built: 2067 input_spec.assert_input_compatibility( 2068 self.input_spec, inputs, self.name) 2069 input_list = nest.flatten(inputs) 2070 if input_list and self._dtype_policy.compute_dtype is None: 2071 try: 2072 dtype = input_list[0].dtype.base_dtype.name 2073 except AttributeError: 2074 pass 2075 else: 2076 self._set_dtype_policy(policy.Policy(dtype)) 2077 input_shapes = None 2078 if all(hasattr(x, 'shape') for x in input_list): 2079 input_shapes = nest.map_structure(lambda x: x.shape, inputs) 2080 # Only call `build` if the user has manually overridden the build method. 2081 if not hasattr(self.build, '_is_default'): 2082 # Any setup work performed only once should happen in an `init_scope` 2083 # to avoid creating symbolic Tensors that will later pollute any eager 2084 # operations. 2085 with tf_utils.maybe_init_scope(self): 2086 self.build(input_shapes) 2087 # We must set also ensure that the layer is marked as built, and the build 2088 # shape is stored since user defined build functions may not be calling 2089 # `super.build()` 2090 Layer.build(self, input_shapes) 2091 2092 # Optionally load weight values specified at layer instantiation. 2093 if self._initial_weights is not None: 2094 self.set_weights(self._initial_weights) 2095 self._initial_weights = None 2096 2097 def _symbolic_call(self, inputs): 2098 input_shapes = nest.map_structure(lambda x: x.shape, inputs) 2099 output_shapes = self.compute_output_shape(input_shapes) 2100 2101 def _make_placeholder_like(shape): 2102 ph = backend.placeholder(shape=shape, dtype=self.dtype) 2103 ph._keras_mask = None 2104 return ph 2105 2106 return nest.map_structure(_make_placeholder_like, output_shapes) 2107 2108 def _get_trainable_state(self): 2109 """Get the `trainable` state of each sublayer. 2110 2111 Returns: 2112 A dict mapping all sublayers to their `trainable` value. 
2113 """ 2114 layers = self._flatten_layers(include_self=False, recursive=False) 2115 trainable_state = {self: self.trainable} 2116 for l in layers: 2117 trainable_state.update(l._get_trainable_state()) 2118 return trainable_state 2119 2120 def _set_trainable_state(self, trainable_state): 2121 """Set `trainable` state for each sublayer.""" 2122 if self in trainable_state: 2123 self.trainable = trainable_state[self] 2124 layers = self._flatten_layers(include_self=False, recursive=False) 2125 for l in layers: 2126 if l in trainable_state: 2127 l._set_trainable_state(trainable_state) 2128 2129 @property 2130 def _obj_reference_counts(self): 2131 """A dictionary counting the number of attributes referencing an object.""" 2132 self._maybe_create_attribute('_obj_reference_counts_dict', 2133 object_identity.ObjectIdentityDictionary()) 2134 return self._obj_reference_counts_dict 2135 2136 @trackable.no_automatic_dependency_tracking 2137 def _maybe_create_attribute(self, name, default_value): 2138 """Create the attribute with the default value if it hasn't been created. 2139 2140 This is useful for fields that is used for tracking purpose, 2141 _trainable_weights, or _layers. Note that user could create a layer subclass 2142 and assign an internal field before invoking the Layer.__init__(), the 2143 __setattr__() need to create the tracking fields and __init__() need to not 2144 override them. 2145 2146 Args: 2147 name: String, the name of the attribute. 2148 default_value: Object, the default value of the attribute. 2149 """ 2150 if not hasattr(self, name): 2151 self.__setattr__(name, default_value) 2152 2153 def __delattr__(self, name): 2154 # For any super.__delattr__() call, we will directly use the implementation 2155 # in Trackable and skip the behavior in AutoTrackable. The Layer was 2156 # originally use Trackable as base class, the change of using Module as base 2157 # class forced us to have AutoTrackable in the class hierarchy. 2158 # 2159 # TODO(b/180760306) Keeping the status quo of skipping _delattr__ and 2160 # __setattr__ in AutoTrackable may be unsustainable. 2161 existing_value = getattr(self, name, None) 2162 2163 # If this value is replacing an existing object assigned to an attribute, we 2164 # should clean it out to avoid leaking memory. First we check if there are 2165 # other attributes referencing it. 2166 reference_counts = self._obj_reference_counts 2167 if existing_value not in reference_counts: 2168 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 2169 return 2170 2171 reference_count = reference_counts[existing_value] 2172 if reference_count > 1: 2173 # There are other remaining references. We can't remove this object from 2174 # _layers etc. 2175 reference_counts[existing_value] = reference_count - 1 2176 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 2177 return 2178 else: 2179 # This is the last remaining reference. 
2180 del reference_counts[existing_value] 2181 2182 super(autotrackable.AutoTrackable, self).__delattr__(name) # pylint: disable=bad-super-call 2183 2184 if (isinstance(existing_value, Layer) 2185 or base_layer_utils.has_weights(existing_value)): 2186 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 2187 '_self_tracked_trackables', 2188 [l for l in self._self_tracked_trackables if l is not existing_value]) 2189 if isinstance(existing_value, tf_variables.Variable): 2190 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 2191 '_trainable_weights', 2192 [w for w in self._trainable_weights if w is not existing_value]) 2193 super(autotrackable.AutoTrackable, self).__setattr__( # pylint: disable=bad-super-call 2194 '_non_trainable_weights', 2195 [w for w in self._non_trainable_weights if w is not existing_value]) 2196 2197 def __setattr__(self, name, value): 2198 if (name == '_self_setattr_tracking' or 2199 not getattr(self, '_self_setattr_tracking', True) or 2200 # Exclude @property.setters from tracking 2201 hasattr(self.__class__, name)): 2202 try: 2203 super(autotrackable.AutoTrackable, self).__setattr__(name, value) # pylint: disable=bad-super-call 2204 except AttributeError: 2205 raise AttributeError( 2206 ('Can\'t set the attribute "{}", likely because it conflicts with ' 2207 'an existing read-only @property of the object. Please choose a ' 2208 'different name.').format(name)) 2209 return 2210 2211 # Keep track of trackable objects, for the needs of `Network.save_weights`. 2212 value = data_structures.sticky_attribute_assignment( 2213 trackable=self, value=value, name=name) 2214 2215 reference_counts = self._obj_reference_counts 2216 reference_counts[value] = reference_counts.get(value, 0) + 1 2217 2218 # Clean out the old attribute, which clears _layers and _trainable_weights 2219 # if necessary. 2220 try: 2221 self.__delattr__(name) 2222 except AttributeError: 2223 pass 2224 2225 # Keep track of metric instance created in subclassed layer. 2226 from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top 2227 for val in nest.flatten(value): 2228 if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'): 2229 self._metrics.append(val) 2230 2231 # TODO(scottzhu): Need to track Module object as well for weight tracking. 2232 # Be careful about metric if it becomes a Module in future. 2233 # Append value to self._layers if relevant 2234 if (getattr(self, '_auto_track_sub_layers', True) and 2235 (isinstance(value, Layer) or base_layer_utils.has_weights(value))): 2236 self._maybe_create_attribute('_self_tracked_trackables', []) 2237 # We need to check object identity to avoid de-duplicating empty 2238 # container types which compare equal. 2239 if not any((layer is value for layer in self._self_tracked_trackables)): 2240 self._self_tracked_trackables.append(value) 2241 if hasattr(value, '_use_resource_variables'): 2242 # Legacy layers (V1 tf.layers) must always use 2243 # resource variables. 2244 value._use_resource_variables = True 2245 2246 # Append value to list of trainable / non-trainable weights if relevant 2247 # TODO(b/125122625): This won't pick up on any variables added to a 2248 # list/dict after creation. 
2249 for val in nest.flatten(value): 2250 if not isinstance(val, tf_variables.Variable): 2251 continue 2252 2253 # Users may add extra weights/variables 2254 # simply by assigning them to attributes (invalid for graph networks) 2255 self._maybe_create_attribute('_trainable_weights', []) 2256 self._maybe_create_attribute('_non_trainable_weights', []) 2257 if val.trainable: 2258 if any(val is w for w in self._trainable_weights): 2259 continue 2260 self._trainable_weights.append(val) 2261 else: 2262 if any(val is w for w in self._non_trainable_weights): 2263 continue 2264 self._non_trainable_weights.append(val) 2265 2266 backend.track_variable(val) 2267 2268 # TODO(b/180760306) Skip the auto trackable from tf.Module to keep status 2269 # quo. See the comment at __delattr__. 2270 super(autotrackable.AutoTrackable, self).__setattr__(name, value) # pylint: disable=bad-super-call 2271 2272 # This is a hack so that the is_layer (within 2273 # training/trackable/layer_utils.py) check doesn't get the weights attr. 2274 # TODO(b/110718070): Remove when fixed. 2275 def _is_layer(self): 2276 return True 2277 2278 def _init_call_fn_args(self, expects_training_arg=None): 2279 # Clear cached call function arguments. 2280 self.__class__._call_full_argspec.fget.cache.pop(self, None) 2281 self.__class__._call_fn_args.fget.cache.pop(self, None) 2282 self.__class__._call_accepts_kwargs.fget.cache.pop(self, None) 2283 2284 call_fn_args = self._call_fn_args 2285 if expects_training_arg is None: 2286 self._expects_training_arg = ('training' in call_fn_args or 2287 self._call_accepts_kwargs) 2288 else: 2289 # Use value encoded into the metadata when loading from the SavedModel. 2290 self._expects_training_arg = expects_training_arg 2291 self._expects_mask_arg = ('mask' in call_fn_args or 2292 self._call_accepts_kwargs) 2293 2294 @property 2295 @layer_utils.cached_per_instance 2296 def _call_full_argspec(self): 2297 # Argspec inspection is expensive and the call spec is used often, so it 2298 # makes sense to cache the result. 2299 return tf_inspect.getfullargspec(self.call) 2300 2301 @property 2302 @layer_utils.cached_per_instance 2303 def _call_fn_args(self): 2304 all_args = self._call_full_argspec.args 2305 # Scrub `self` that appears if a decorator was applied. 2306 if all_args and all_args[0] == 'self': 2307 return all_args[1:] 2308 return all_args 2309 2310 @property 2311 @layer_utils.cached_per_instance 2312 def _call_fn_arg_positions(self): 2313 call_fn_arg_positions = dict() 2314 for pos, arg in enumerate(self._call_fn_args): 2315 call_fn_arg_positions[arg] = pos 2316 return call_fn_arg_positions 2317 2318 @property 2319 @layer_utils.cached_per_instance 2320 def _call_accepts_kwargs(self): 2321 return self._call_full_argspec.varkw is not None 2322 2323 @property 2324 @layer_utils.cached_per_instance 2325 def _should_compute_mask(self): 2326 return ('mask' in self._call_fn_args or 2327 getattr(self, 'compute_mask', None) is not None) 2328 2329 def _dedup_weights(self, weights): 2330 """Dedupe weights while maintaining order as much as possible.""" 2331 output, seen_ids = [], set() 2332 for w in weights: 2333 if id(w) not in seen_ids: 2334 output.append(w) 2335 # Track the Variable's identity to avoid __eq__ issues. 2336 seen_ids.add(id(w)) 2337 2338 return output 2339 2340 # SavedModel properties. Please see keras/saving/saved_model for details. 

  @property
  def _trackable_saved_model_saver(self):
    return layer_serialization.LayerSavedModelSaver(self)

  @property
  def _object_identifier(self):
    return self._trackable_saved_model_saver.object_identifier

  @property
  def _tracking_metadata(self):
    return self._trackable_saved_model_saver.tracking_metadata

  def _trackable_children(self, save_type='checkpoint', **kwargs):
    if save_type == 'savedmodel':
      cache = kwargs['cache']
      # TODO(b/213628533): This must be called before super() to ensure
      # that any input shape changes are applied before getting the config of
      # the model.
      children = self._trackable_saved_model_saver.trackable_children(cache)
    else:
      children = {}
    children.update(super()._trackable_children(save_type, **kwargs))
    return children

  def __getstate__(self):
    # Override to support `copy.deepcopy` and pickling.
    # Thread-local objects cannot be copied in Python 3, so pop these.
    # Thread-local objects are used to cache losses in MirroredStrategy, and
    # so shouldn't be copied.
    state = self.__dict__.copy()
    state.pop('_thread_local', None)
    return state

  def __setstate__(self, state):
    state['_thread_local'] = threading.local()
    # Bypass Trackable logic as `__dict__` already contains this info.
    object.__setattr__(self, '__dict__', state)


class KerasHistory(
    collections.namedtuple('KerasHistory',
                           ['layer', 'node_index', 'tensor_index'])):
  """Tracks the Layer call that created a Tensor, for Keras Graph Networks.

  During construction of Keras Graph Networks, this metadata is added to
  each Tensor produced as the output of a Layer, starting with an
  `InputLayer`. This allows Keras to track how each Tensor was produced, and
  this information is later retraced by the `keras.engine.Network` class to
  reconstruct the Keras Graph Network.

  Attributes:
    layer: The Layer that produced the Tensor.
    node_index: The specific call to the Layer that produced this Tensor.
      Layers can be called multiple times in order to share weights. A new
      node is created every time a Layer is called.
    tensor_index: The output index for this Tensor. Always zero if the Layer
      that produced this Tensor only has one output. Nested structures of
      Tensors are deterministically assigned an index via `nest.flatten`.
  """
  # Added to maintain memory and performance characteristics of `namedtuple`
  # while subclassing.
  __slots__ = ()


# Avoid breaking users who directly import this symbol from this file.
# TODO(fchollet): remove this.
InputSpec = input_spec.InputSpec  # pylint:disable=invalid-name
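
# Example (illustrative only; not part of this module's API): a minimal custom
# layer whose `build` creates weights with `add_weight` and whose `call`
# applies them. `_maybe_build` above ensures `build` runs before the first
# `call`, and `__setattr__` above tracks the created variables in
# `trainable_weights`.
#
#   class Linear(tf.keras.layers.Layer):
#
#     def __init__(self, units, **kwargs):
#       super().__init__(**kwargs)
#       self.units = units
#
#     def build(self, input_shape):
#       self.kernel = self.add_weight(
#           'kernel', shape=(int(input_shape[-1]), self.units))
#       self.bias = self.add_weight(
#           'bias', shape=(self.units,), initializer='zeros')
#       super().build(input_shape)
#
#     def call(self, inputs):
#       return tf.matmul(inputs, self.kernel) + self.bias
#
#   layer = Linear(3)
#   _ = layer(tf.zeros((2, 4)))  # triggers `build`, then `call`
#   assert len(layer.trainable_weights) == 2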