1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16# pylint: disable=g-classes-have-attributes 17"""Recurrent layers and their base classes.""" 18 19import collections 20import warnings 21 22import numpy as np 23 24from tensorflow.python.distribute import distribution_strategy_context as ds_context 25from tensorflow.python.eager import context 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.keras import activations 29from tensorflow.python.keras import backend 30from tensorflow.python.keras import constraints 31from tensorflow.python.keras import initializers 32from tensorflow.python.keras import regularizers 33from tensorflow.python.keras.engine.base_layer import Layer 34from tensorflow.python.keras.engine.input_spec import InputSpec 35from tensorflow.python.keras.saving.saved_model import layer_serialization 36from tensorflow.python.keras.utils import control_flow_util 37from tensorflow.python.keras.utils import generic_utils 38from tensorflow.python.keras.utils import tf_utils 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import control_flow_ops 41from tensorflow.python.ops import math_ops 42from tensorflow.python.ops import state_ops 43from tensorflow.python.platform import tf_logging as logging 44from tensorflow.python.trackable import base as trackable 45from tensorflow.python.util import nest 46from tensorflow.python.util.tf_export import keras_export 47from tensorflow.tools.docs import doc_controls 48 49 50RECURRENT_DROPOUT_WARNING_MSG = ( 51 'RNN `implementation=2` is not supported when `recurrent_dropout` is set. ' 52 'Using `implementation=1`.') 53 54 55@keras_export('keras.layers.StackedRNNCells') 56class StackedRNNCells(Layer): 57 """Wrapper allowing a stack of RNN cells to behave as a single cell. 58 59 Used to implement efficient stacked RNNs. 60 61 Args: 62 cells: List of RNN cell instances. 63 64 Examples: 65 66 ```python 67 batch_size = 3 68 sentence_max_length = 5 69 n_features = 2 70 new_shape = (batch_size, sentence_max_length, n_features) 71 x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32) 72 73 rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)] 74 stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells) 75 lstm_layer = tf.keras.layers.RNN(stacked_lstm) 76 77 result = lstm_layer(x) 78 ``` 79 """ 80 81 def __init__(self, cells, **kwargs): 82 for cell in cells: 83 if not 'call' in dir(cell): 84 raise ValueError('All cells must have a `call` method. ' 85 'received cells:', cells) 86 if not 'state_size' in dir(cell): 87 raise ValueError('All cells must have a ' 88 '`state_size` attribute. ' 89 'received cells:', cells) 90 self.cells = cells 91 # reverse_state_order determines whether the state size will be in a reverse 92 # order of the cells' state. User might want to set this to True to keep the 93 # existing behavior. This is only useful when use RNN(return_state=True) 94 # since the state will be returned as the same order of state_size. 95 self.reverse_state_order = kwargs.pop('reverse_state_order', False) 96 if self.reverse_state_order: 97 logging.warning('reverse_state_order=True in StackedRNNCells will soon ' 98 'be deprecated. Please update the code to work with the ' 99 'natural order of states if you rely on the RNN states, ' 100 'eg RNN(return_state=True).') 101 super(StackedRNNCells, self).__init__(**kwargs) 102 103 @property 104 def state_size(self): 105 return tuple(c.state_size for c in 106 (self.cells[::-1] if self.reverse_state_order else self.cells)) 107 108 @property 109 def output_size(self): 110 if getattr(self.cells[-1], 'output_size', None) is not None: 111 return self.cells[-1].output_size 112 elif _is_multiple_state(self.cells[-1].state_size): 113 return self.cells[-1].state_size[0] 114 else: 115 return self.cells[-1].state_size 116 117 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 118 initial_states = [] 119 for cell in self.cells[::-1] if self.reverse_state_order else self.cells: 120 get_initial_state_fn = getattr(cell, 'get_initial_state', None) 121 if get_initial_state_fn: 122 initial_states.append(get_initial_state_fn( 123 inputs=inputs, batch_size=batch_size, dtype=dtype)) 124 else: 125 initial_states.append(_generate_zero_filled_state_for_cell( 126 cell, inputs, batch_size, dtype)) 127 128 return tuple(initial_states) 129 130 def call(self, inputs, states, constants=None, training=None, **kwargs): 131 # Recover per-cell states. 132 state_size = (self.state_size[::-1] 133 if self.reverse_state_order else self.state_size) 134 nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) 135 136 # Call the cells in order and store the returned states. 137 new_nested_states = [] 138 for cell, states in zip(self.cells, nested_states): 139 states = states if nest.is_nested(states) else [states] 140 # TF cell does not wrap the state into list when there is only one state. 141 is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None 142 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 143 if generic_utils.has_arg(cell.call, 'training'): 144 kwargs['training'] = training 145 else: 146 kwargs.pop('training', None) 147 # Use the __call__ function for callable objects, eg layers, so that it 148 # will have the proper name scopes for the ops, etc. 149 cell_call_fn = cell.__call__ if callable(cell) else cell.call 150 if generic_utils.has_arg(cell.call, 'constants'): 151 inputs, states = cell_call_fn(inputs, states, 152 constants=constants, **kwargs) 153 else: 154 inputs, states = cell_call_fn(inputs, states, **kwargs) 155 new_nested_states.append(states) 156 157 return inputs, nest.pack_sequence_as(state_size, 158 nest.flatten(new_nested_states)) 159 160 @tf_utils.shape_type_conversion 161 def build(self, input_shape): 162 if isinstance(input_shape, list): 163 input_shape = input_shape[0] 164 for cell in self.cells: 165 if isinstance(cell, Layer) and not cell.built: 166 with backend.name_scope(cell.name): 167 cell.build(input_shape) 168 cell.built = True 169 if getattr(cell, 'output_size', None) is not None: 170 output_dim = cell.output_size 171 elif _is_multiple_state(cell.state_size): 172 output_dim = cell.state_size[0] 173 else: 174 output_dim = cell.state_size 175 input_shape = tuple([input_shape[0]] + 176 tensor_shape.TensorShape(output_dim).as_list()) 177 self.built = True 178 179 def get_config(self): 180 cells = [] 181 for cell in self.cells: 182 cells.append(generic_utils.serialize_keras_object(cell)) 183 config = {'cells': cells} 184 base_config = super(StackedRNNCells, self).get_config() 185 return dict(list(base_config.items()) + list(config.items())) 186 187 @classmethod 188 def from_config(cls, config, custom_objects=None): 189 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 190 cells = [] 191 for cell_config in config.pop('cells'): 192 cells.append( 193 deserialize_layer(cell_config, custom_objects=custom_objects)) 194 return cls(cells, **config) 195 196 197@keras_export('keras.layers.RNN') 198class RNN(Layer): 199 """Base class for recurrent layers. 200 201 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 202 for details about the usage of RNN API. 203 204 Args: 205 cell: A RNN cell instance or a list of RNN cell instances. 206 A RNN cell is a class that has: 207 - A `call(input_at_t, states_at_t)` method, returning 208 `(output_at_t, states_at_t_plus_1)`. The call method of the 209 cell can also take the optional argument `constants`, see 210 section "Note on passing external constants" below. 211 - A `state_size` attribute. This can be a single integer 212 (single state) in which case it is the size of the recurrent 213 state. This can also be a list/tuple of integers (one size per state). 214 The `state_size` can also be TensorShape or tuple/list of 215 TensorShape, to represent high dimension state. 216 - A `output_size` attribute. This can be a single integer or a 217 TensorShape, which represent the shape of the output. For backward 218 compatible reason, if this attribute is not available for the 219 cell, the value will be inferred by the first element of the 220 `state_size`. 221 - A `get_initial_state(inputs=None, batch_size=None, dtype=None)` 222 method that creates a tensor meant to be fed to `call()` as the 223 initial state, if the user didn't specify any initial state via other 224 means. The returned initial state should have a shape of 225 [batch_size, cell.state_size]. The cell might choose to create a 226 tensor full of zeros, or full of other values based on the cell's 227 implementation. 228 `inputs` is the input tensor to the RNN layer, which should 229 contain the batch size as its shape[0], and also dtype. Note that 230 the shape[0] might be `None` during the graph construction. Either 231 the `inputs` or the pair of `batch_size` and `dtype` are provided. 232 `batch_size` is a scalar tensor that represents the batch size 233 of the inputs. `dtype` is `tf.DType` that represents the dtype of 234 the inputs. 235 For backward compatibility, if this method is not implemented 236 by the cell, the RNN layer will create a zero filled tensor with the 237 size of [batch_size, cell.state_size]. 238 In the case that `cell` is a list of RNN cell instances, the cells 239 will be stacked on top of each other in the RNN, resulting in an 240 efficient stacked RNN. 241 return_sequences: Boolean (default `False`). Whether to return the last 242 output in the output sequence, or the full sequence. 243 return_state: Boolean (default `False`). Whether to return the last state 244 in addition to the output. 245 go_backwards: Boolean (default `False`). 246 If True, process the input sequence backwards and return the 247 reversed sequence. 248 stateful: Boolean (default `False`). If True, the last state 249 for each sample at index i in a batch will be used as initial 250 state for the sample of index i in the following batch. 251 unroll: Boolean (default `False`). 252 If True, the network will be unrolled, else a symbolic loop will be used. 253 Unrolling can speed-up a RNN, although it tends to be more 254 memory-intensive. Unrolling is only suitable for short sequences. 255 time_major: The shape format of the `inputs` and `outputs` tensors. 256 If True, the inputs and outputs will be in shape 257 `(timesteps, batch, ...)`, whereas in the False case, it will be 258 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 259 efficient because it avoids transposes at the beginning and end of the 260 RNN calculation. However, most TensorFlow data is batch-major, so by 261 default this function accepts input and emits output in batch-major 262 form. 263 zero_output_for_mask: Boolean (default `False`). 264 Whether the output should use zeros for the masked timesteps. Note that 265 this field is only used when `return_sequences` is True and mask is 266 provided. It can useful if you want to reuse the raw output sequence of 267 the RNN without interference from the masked timesteps, eg, merging 268 bidirectional RNNs. 269 270 Call arguments: 271 inputs: Input tensor. 272 mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether 273 a given timestep should be masked. An individual `True` entry indicates 274 that the corresponding timestep should be utilized, while a `False` 275 entry indicates that the corresponding timestep should be ignored. 276 training: Python boolean indicating whether the layer should behave in 277 training mode or in inference mode. This argument is passed to the cell 278 when calling it. This is for use with cells that use dropout. 279 initial_state: List of initial state tensors to be passed to the first 280 call of the cell. 281 constants: List of constant tensors to be passed to the cell at each 282 timestep. 283 284 Input shape: 285 N-D tensor with shape `[batch_size, timesteps, ...]` or 286 `[timesteps, batch_size, ...]` when time_major is True. 287 288 Output shape: 289 - If `return_state`: a list of tensors. The first tensor is 290 the output. The remaining tensors are the last states, 291 each with shape `[batch_size, state_size]`, where `state_size` could 292 be a high dimension tensor shape. 293 - If `return_sequences`: N-D tensor with shape 294 `[batch_size, timesteps, output_size]`, where `output_size` could 295 be a high dimension tensor shape, or 296 `[timesteps, batch_size, output_size]` when `time_major` is True. 297 - Else, N-D tensor with shape `[batch_size, output_size]`, where 298 `output_size` could be a high dimension tensor shape. 299 300 Masking: 301 This layer supports masking for input data with a variable number 302 of timesteps. To introduce masks to your data, 303 use an [tf.keras.layers.Embedding] layer with the `mask_zero` parameter 304 set to `True`. 305 306 Note on using statefulness in RNNs: 307 You can set RNN layers to be 'stateful', which means that the states 308 computed for the samples in one batch will be reused as initial states 309 for the samples in the next batch. This assumes a one-to-one mapping 310 between samples in different successive batches. 311 312 To enable statefulness: 313 - Specify `stateful=True` in the layer constructor. 314 - Specify a fixed batch size for your model, by passing 315 If sequential model: 316 `batch_input_shape=(...)` to the first layer in your model. 317 Else for functional model with 1 or more Input layers: 318 `batch_shape=(...)` to all the first layers in your model. 319 This is the expected shape of your inputs 320 *including the batch size*. 321 It should be a tuple of integers, e.g. `(32, 10, 100)`. 322 - Specify `shuffle=False` when calling `fit()`. 323 324 To reset the states of your model, call `.reset_states()` on either 325 a specific layer, or on your entire model. 326 327 Note on specifying the initial state of RNNs: 328 You can specify the initial state of RNN layers symbolically by 329 calling them with the keyword argument `initial_state`. The value of 330 `initial_state` should be a tensor or list of tensors representing 331 the initial state of the RNN layer. 332 333 You can specify the initial state of RNN layers numerically by 334 calling `reset_states` with the keyword argument `states`. The value of 335 `states` should be a numpy array or list of numpy arrays representing 336 the initial state of the RNN layer. 337 338 Note on passing external constants to RNNs: 339 You can pass "external" constants to the cell using the `constants` 340 keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This 341 requires that the `cell.call` method accepts the same keyword argument 342 `constants`. Such constants can be used to condition the cell 343 transformation on additional static inputs (not changing over time), 344 a.k.a. an attention mechanism. 345 346 Examples: 347 348 ```python 349 # First, let's define a RNN Cell, as a layer subclass. 350 351 class MinimalRNNCell(keras.layers.Layer): 352 353 def __init__(self, units, **kwargs): 354 self.units = units 355 self.state_size = units 356 super(MinimalRNNCell, self).__init__(**kwargs) 357 358 def build(self, input_shape): 359 self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 360 initializer='uniform', 361 name='kernel') 362 self.recurrent_kernel = self.add_weight( 363 shape=(self.units, self.units), 364 initializer='uniform', 365 name='recurrent_kernel') 366 self.built = True 367 368 def call(self, inputs, states): 369 prev_output = states[0] 370 h = backend.dot(inputs, self.kernel) 371 output = h + backend.dot(prev_output, self.recurrent_kernel) 372 return output, [output] 373 374 # Let's use this cell in a RNN layer: 375 376 cell = MinimalRNNCell(32) 377 x = keras.Input((None, 5)) 378 layer = RNN(cell) 379 y = layer(x) 380 381 # Here's how to use the cell to build a stacked RNN: 382 383 cells = [MinimalRNNCell(32), MinimalRNNCell(64)] 384 x = keras.Input((None, 5)) 385 layer = RNN(cells) 386 y = layer(x) 387 ``` 388 """ 389 390 def __init__(self, 391 cell, 392 return_sequences=False, 393 return_state=False, 394 go_backwards=False, 395 stateful=False, 396 unroll=False, 397 time_major=False, 398 **kwargs): 399 if isinstance(cell, (list, tuple)): 400 cell = StackedRNNCells(cell) 401 if not 'call' in dir(cell): 402 raise ValueError('`cell` should have a `call` method. ' 403 'The RNN was passed:', cell) 404 if not 'state_size' in dir(cell): 405 raise ValueError('The RNN cell should have ' 406 'an attribute `state_size` ' 407 '(tuple of integers, ' 408 'one integer per RNN state).') 409 # If True, the output for masked timestep will be zeros, whereas in the 410 # False case, output from previous timestep is returned for masked timestep. 411 self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False) 412 413 if 'input_shape' not in kwargs and ( 414 'input_dim' in kwargs or 'input_length' in kwargs): 415 input_shape = (kwargs.pop('input_length', None), 416 kwargs.pop('input_dim', None)) 417 kwargs['input_shape'] = input_shape 418 419 super(RNN, self).__init__(**kwargs) 420 self.cell = cell 421 self.return_sequences = return_sequences 422 self.return_state = return_state 423 self.go_backwards = go_backwards 424 self.stateful = stateful 425 self.unroll = unroll 426 self.time_major = time_major 427 428 self.supports_masking = True 429 # The input shape is unknown yet, it could have nested tensor inputs, and 430 # the input spec will be the list of specs for nested inputs, the structure 431 # of the input_spec will be the same as the input. 432 self.input_spec = None 433 self.state_spec = None 434 self._states = None 435 self.constants_spec = None 436 self._num_constants = 0 437 438 if stateful: 439 if ds_context.has_strategy(): 440 raise ValueError('RNNs with stateful=True not yet supported with ' 441 'tf.distribute.Strategy.') 442 443 @property 444 def _use_input_spec_as_call_signature(self): 445 if self.unroll: 446 # When the RNN layer is unrolled, the time step shape cannot be unknown. 447 # The input spec does not define the time step (because this layer can be 448 # called with any time step value, as long as it is not None), so it 449 # cannot be used as the call function signature when saving to SavedModel. 450 return False 451 return super(RNN, self)._use_input_spec_as_call_signature 452 453 @property 454 def states(self): 455 if self._states is None: 456 state = nest.map_structure(lambda _: None, self.cell.state_size) 457 return state if nest.is_nested(self.cell.state_size) else [state] 458 return self._states 459 460 @states.setter 461 # Automatic tracking catches "self._states" which adds an extra weight and 462 # breaks HDF5 checkpoints. 463 @trackable.no_automatic_dependency_tracking 464 def states(self, states): 465 self._states = states 466 467 def compute_output_shape(self, input_shape): 468 if isinstance(input_shape, list): 469 input_shape = input_shape[0] 470 # Check whether the input shape contains any nested shapes. It could be 471 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy 472 # inputs. 473 try: 474 input_shape = tensor_shape.TensorShape(input_shape) 475 except (ValueError, TypeError): 476 # A nested tensor input 477 input_shape = nest.flatten(input_shape)[0] 478 479 batch = input_shape[0] 480 time_step = input_shape[1] 481 if self.time_major: 482 batch, time_step = time_step, batch 483 484 if _is_multiple_state(self.cell.state_size): 485 state_size = self.cell.state_size 486 else: 487 state_size = [self.cell.state_size] 488 489 def _get_output_shape(flat_output_size): 490 output_dim = tensor_shape.TensorShape(flat_output_size).as_list() 491 if self.return_sequences: 492 if self.time_major: 493 output_shape = tensor_shape.TensorShape( 494 [time_step, batch] + output_dim) 495 else: 496 output_shape = tensor_shape.TensorShape( 497 [batch, time_step] + output_dim) 498 else: 499 output_shape = tensor_shape.TensorShape([batch] + output_dim) 500 return output_shape 501 502 if getattr(self.cell, 'output_size', None) is not None: 503 # cell.output_size could be nested structure. 504 output_shape = nest.flatten(nest.map_structure( 505 _get_output_shape, self.cell.output_size)) 506 output_shape = output_shape[0] if len(output_shape) == 1 else output_shape 507 else: 508 # Note that state_size[0] could be a tensor_shape or int. 509 output_shape = _get_output_shape(state_size[0]) 510 511 if self.return_state: 512 def _get_state_shape(flat_state): 513 state_shape = [batch] + tensor_shape.TensorShape(flat_state).as_list() 514 return tensor_shape.TensorShape(state_shape) 515 state_shape = nest.map_structure(_get_state_shape, state_size) 516 return generic_utils.to_list(output_shape) + nest.flatten(state_shape) 517 else: 518 return output_shape 519 520 def compute_mask(self, inputs, mask): 521 # Time step masks must be the same for each input. 522 # This is because the mask for an RNN is of size [batch, time_steps, 1], 523 # and specifies which time steps should be skipped, and a time step 524 # must be skipped for all inputs. 525 # TODO(scottzhu): Should we accept multiple different masks? 526 mask = nest.flatten(mask)[0] 527 output_mask = mask if self.return_sequences else None 528 if self.return_state: 529 state_mask = [None for _ in self.states] 530 return [output_mask] + state_mask 531 else: 532 return output_mask 533 534 def build(self, input_shape): 535 if isinstance(input_shape, list): 536 input_shape = input_shape[0] 537 # The input_shape here could be a nest structure. 538 539 # do the tensor_shape to shapes here. The input could be single tensor, or a 540 # nested structure of tensors. 541 def get_input_spec(shape): 542 """Convert input shape to InputSpec.""" 543 if isinstance(shape, tensor_shape.TensorShape): 544 input_spec_shape = shape.as_list() 545 else: 546 input_spec_shape = list(shape) 547 batch_index, time_step_index = (1, 0) if self.time_major else (0, 1) 548 if not self.stateful: 549 input_spec_shape[batch_index] = None 550 input_spec_shape[time_step_index] = None 551 return InputSpec(shape=tuple(input_spec_shape)) 552 553 def get_step_input_shape(shape): 554 if isinstance(shape, tensor_shape.TensorShape): 555 shape = tuple(shape.as_list()) 556 # remove the timestep from the input_shape 557 return shape[1:] if self.time_major else (shape[0],) + shape[2:] 558 559 # Check whether the input shape contains any nested shapes. It could be 560 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy 561 # inputs. 562 try: 563 input_shape = tensor_shape.TensorShape(input_shape) 564 except (ValueError, TypeError): 565 # A nested tensor input 566 pass 567 568 if not nest.is_nested(input_shape): 569 # This indicates the there is only one input. 570 if self.input_spec is not None: 571 self.input_spec[0] = get_input_spec(input_shape) 572 else: 573 self.input_spec = [get_input_spec(input_shape)] 574 step_input_shape = get_step_input_shape(input_shape) 575 else: 576 if self.input_spec is not None: 577 self.input_spec[0] = nest.map_structure(get_input_spec, input_shape) 578 else: 579 self.input_spec = generic_utils.to_list( 580 nest.map_structure(get_input_spec, input_shape)) 581 step_input_shape = nest.map_structure(get_step_input_shape, input_shape) 582 583 # allow cell (if layer) to build before we set or validate state_spec. 584 if isinstance(self.cell, Layer) and not self.cell.built: 585 with backend.name_scope(self.cell.name): 586 self.cell.build(step_input_shape) 587 self.cell.built = True 588 589 # set or validate state_spec 590 if _is_multiple_state(self.cell.state_size): 591 state_size = list(self.cell.state_size) 592 else: 593 state_size = [self.cell.state_size] 594 595 if self.state_spec is not None: 596 # initial_state was passed in call, check compatibility 597 self._validate_state_spec(state_size, self.state_spec) 598 else: 599 self.state_spec = [ 600 InputSpec(shape=[None] + tensor_shape.TensorShape(dim).as_list()) 601 for dim in state_size 602 ] 603 if self.stateful: 604 self.reset_states() 605 self.built = True 606 607 @staticmethod 608 def _validate_state_spec(cell_state_sizes, init_state_specs): 609 """Validate the state spec between the initial_state and the state_size. 610 611 Args: 612 cell_state_sizes: list, the `state_size` attribute from the cell. 613 init_state_specs: list, the `state_spec` from the initial_state that is 614 passed in `call()`. 615 616 Raises: 617 ValueError: When initial state spec is not compatible with the state size. 618 """ 619 validation_error = ValueError( 620 'An `initial_state` was passed that is not compatible with ' 621 '`cell.state_size`. Received `state_spec`={}; ' 622 'however `cell.state_size` is ' 623 '{}'.format(init_state_specs, cell_state_sizes)) 624 flat_cell_state_sizes = nest.flatten(cell_state_sizes) 625 flat_state_specs = nest.flatten(init_state_specs) 626 627 if len(flat_cell_state_sizes) != len(flat_state_specs): 628 raise validation_error 629 for cell_state_spec, cell_state_size in zip(flat_state_specs, 630 flat_cell_state_sizes): 631 if not tensor_shape.TensorShape( 632 # Ignore the first axis for init_state which is for batch 633 cell_state_spec.shape[1:]).is_compatible_with( 634 tensor_shape.TensorShape(cell_state_size)): 635 raise validation_error 636 637 @doc_controls.do_not_doc_inheritable 638 def get_initial_state(self, inputs): 639 get_initial_state_fn = getattr(self.cell, 'get_initial_state', None) 640 641 if nest.is_nested(inputs): 642 # The input are nested sequences. Use the first element in the seq to get 643 # batch size and dtype. 644 inputs = nest.flatten(inputs)[0] 645 646 input_shape = array_ops.shape(inputs) 647 batch_size = input_shape[1] if self.time_major else input_shape[0] 648 dtype = inputs.dtype 649 if get_initial_state_fn: 650 init_state = get_initial_state_fn( 651 inputs=None, batch_size=batch_size, dtype=dtype) 652 else: 653 init_state = _generate_zero_filled_state(batch_size, self.cell.state_size, 654 dtype) 655 # Keras RNN expect the states in a list, even if it's a single state tensor. 656 if not nest.is_nested(init_state): 657 init_state = [init_state] 658 # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple. 659 return list(init_state) 660 661 def __call__(self, inputs, initial_state=None, constants=None, **kwargs): 662 inputs, initial_state, constants = _standardize_args(inputs, 663 initial_state, 664 constants, 665 self._num_constants) 666 667 if initial_state is None and constants is None: 668 return super(RNN, self).__call__(inputs, **kwargs) 669 670 # If any of `initial_state` or `constants` are specified and are Keras 671 # tensors, then add them to the inputs and temporarily modify the 672 # input_spec to include them. 673 674 additional_inputs = [] 675 additional_specs = [] 676 if initial_state is not None: 677 additional_inputs += initial_state 678 self.state_spec = nest.map_structure( 679 lambda s: InputSpec(shape=backend.int_shape(s)), initial_state) 680 additional_specs += self.state_spec 681 if constants is not None: 682 additional_inputs += constants 683 self.constants_spec = [ 684 InputSpec(shape=backend.int_shape(constant)) for constant in constants 685 ] 686 self._num_constants = len(constants) 687 additional_specs += self.constants_spec 688 # additional_inputs can be empty if initial_state or constants are provided 689 # but empty (e.g. the cell is stateless). 690 flat_additional_inputs = nest.flatten(additional_inputs) 691 is_keras_tensor = backend.is_keras_tensor( 692 flat_additional_inputs[0]) if flat_additional_inputs else True 693 for tensor in flat_additional_inputs: 694 if backend.is_keras_tensor(tensor) != is_keras_tensor: 695 raise ValueError('The initial state or constants of an RNN' 696 ' layer cannot be specified with a mix of' 697 ' Keras tensors and non-Keras tensors' 698 ' (a "Keras tensor" is a tensor that was' 699 ' returned by a Keras layer, or by `Input`)') 700 701 if is_keras_tensor: 702 # Compute the full input spec, including state and constants 703 full_input = [inputs] + additional_inputs 704 if self.built: 705 # Keep the input_spec since it has been populated in build() method. 706 full_input_spec = self.input_spec + additional_specs 707 else: 708 # The original input_spec is None since there could be a nested tensor 709 # input. Update the input_spec to match the inputs. 710 full_input_spec = generic_utils.to_list( 711 nest.map_structure(lambda _: None, inputs)) + additional_specs 712 # Perform the call with temporarily replaced input_spec 713 self.input_spec = full_input_spec 714 output = super(RNN, self).__call__(full_input, **kwargs) 715 # Remove the additional_specs from input spec and keep the rest. It is 716 # important to keep since the input spec was populated by build(), and 717 # will be reused in the stateful=True. 718 self.input_spec = self.input_spec[:-len(additional_specs)] 719 return output 720 else: 721 if initial_state is not None: 722 kwargs['initial_state'] = initial_state 723 if constants is not None: 724 kwargs['constants'] = constants 725 return super(RNN, self).__call__(inputs, **kwargs) 726 727 def call(self, 728 inputs, 729 mask=None, 730 training=None, 731 initial_state=None, 732 constants=None): 733 # The input should be dense, padded with zeros. If a ragged input is fed 734 # into the layer, it is padded and the row lengths are used for masking. 735 inputs, row_lengths = backend.convert_inputs_if_ragged(inputs) 736 is_ragged_input = (row_lengths is not None) 737 self._validate_args_if_ragged(is_ragged_input, mask) 738 739 inputs, initial_state, constants = self._process_inputs( 740 inputs, initial_state, constants) 741 742 self._maybe_reset_cell_dropout_mask(self.cell) 743 if isinstance(self.cell, StackedRNNCells): 744 for cell in self.cell.cells: 745 self._maybe_reset_cell_dropout_mask(cell) 746 747 if mask is not None: 748 # Time step masks must be the same for each input. 749 # TODO(scottzhu): Should we accept multiple different masks? 750 mask = nest.flatten(mask)[0] 751 752 if nest.is_nested(inputs): 753 # In the case of nested input, use the first element for shape check. 754 input_shape = backend.int_shape(nest.flatten(inputs)[0]) 755 else: 756 input_shape = backend.int_shape(inputs) 757 timesteps = input_shape[0] if self.time_major else input_shape[1] 758 if self.unroll and timesteps is None: 759 raise ValueError('Cannot unroll a RNN if the ' 760 'time dimension is undefined. \n' 761 '- If using a Sequential model, ' 762 'specify the time dimension by passing ' 763 'an `input_shape` or `batch_input_shape` ' 764 'argument to your first layer. If your ' 765 'first layer is an Embedding, you can ' 766 'also use the `input_length` argument.\n' 767 '- If using the functional API, specify ' 768 'the time dimension by passing a `shape` ' 769 'or `batch_shape` argument to your Input layer.') 770 771 kwargs = {} 772 if generic_utils.has_arg(self.cell.call, 'training'): 773 kwargs['training'] = training 774 775 # TF RNN cells expect single tensor as state instead of list wrapped tensor. 776 is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None 777 # Use the __call__ function for callable objects, eg layers, so that it 778 # will have the proper name scopes for the ops, etc. 779 cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call 780 if constants: 781 if not generic_utils.has_arg(self.cell.call, 'constants'): 782 raise ValueError('RNN cell does not support constants') 783 784 def step(inputs, states): 785 constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type 786 states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type 787 788 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 789 output, new_states = cell_call_fn( 790 inputs, states, constants=constants, **kwargs) 791 if not nest.is_nested(new_states): 792 new_states = [new_states] 793 return output, new_states 794 else: 795 796 def step(inputs, states): 797 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 798 output, new_states = cell_call_fn(inputs, states, **kwargs) 799 if not nest.is_nested(new_states): 800 new_states = [new_states] 801 return output, new_states 802 last_output, outputs, states = backend.rnn( 803 step, 804 inputs, 805 initial_state, 806 constants=constants, 807 go_backwards=self.go_backwards, 808 mask=mask, 809 unroll=self.unroll, 810 input_length=row_lengths if row_lengths is not None else timesteps, 811 time_major=self.time_major, 812 zero_output_for_mask=self.zero_output_for_mask) 813 814 if self.stateful: 815 updates = [ 816 state_ops.assign(self_state, state) for self_state, state in zip( 817 nest.flatten(self.states), nest.flatten(states)) 818 ] 819 self.add_update(updates) 820 821 if self.return_sequences: 822 output = backend.maybe_convert_to_ragged( 823 is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards) 824 else: 825 output = last_output 826 827 if self.return_state: 828 if not isinstance(states, (list, tuple)): 829 states = [states] 830 else: 831 states = list(states) 832 return generic_utils.to_list(output) + states 833 else: 834 return output 835 836 def _process_inputs(self, inputs, initial_state, constants): 837 # input shape: `(samples, time (padded with zeros), input_dim)` 838 # note that the .build() method of subclasses MUST define 839 # self.input_spec and self.state_spec with complete input shapes. 840 if (isinstance(inputs, collections.abc.Sequence) 841 and not isinstance(inputs, tuple)): 842 # get initial_state from full input spec 843 # as they could be copied to multiple GPU. 844 if not self._num_constants: 845 initial_state = inputs[1:] 846 else: 847 initial_state = inputs[1:-self._num_constants] 848 constants = inputs[-self._num_constants:] 849 if len(initial_state) == 0: 850 initial_state = None 851 inputs = inputs[0] 852 853 if self.stateful: 854 if initial_state is not None: 855 # When layer is stateful and initial_state is provided, check if the 856 # recorded state is same as the default value (zeros). Use the recorded 857 # state if it is not same as the default. 858 non_zero_count = math_ops.add_n([math_ops.count_nonzero_v2(s) 859 for s in nest.flatten(self.states)]) 860 # Set strict = True to keep the original structure of the state. 861 initial_state = control_flow_ops.cond(non_zero_count > 0, 862 true_fn=lambda: self.states, 863 false_fn=lambda: initial_state, 864 strict=True) 865 else: 866 initial_state = self.states 867 elif initial_state is None: 868 initial_state = self.get_initial_state(inputs) 869 870 if len(initial_state) != len(self.states): 871 raise ValueError('Layer has ' + str(len(self.states)) + 872 ' states but was passed ' + str(len(initial_state)) + 873 ' initial states.') 874 return inputs, initial_state, constants 875 876 def _validate_args_if_ragged(self, is_ragged_input, mask): 877 if not is_ragged_input: 878 return 879 880 if mask is not None: 881 raise ValueError('The mask that was passed in was ' + str(mask) + 882 ' and cannot be applied to RaggedTensor inputs. Please ' 883 'make sure that there is no mask passed in by upstream ' 884 'layers.') 885 if self.unroll: 886 raise ValueError('The input received contains RaggedTensors and does ' 887 'not support unrolling. Disable unrolling by passing ' 888 '`unroll=False` in the RNN Layer constructor.') 889 890 def _maybe_reset_cell_dropout_mask(self, cell): 891 if isinstance(cell, DropoutRNNCellMixin): 892 cell.reset_dropout_mask() 893 cell.reset_recurrent_dropout_mask() 894 895 def reset_states(self, states=None): 896 """Reset the recorded states for the stateful RNN layer. 897 898 Can only be used when RNN layer is constructed with `stateful` = `True`. 899 Args: 900 states: Numpy arrays that contains the value for the initial state, which 901 will be feed to cell at the first time step. When the value is None, 902 zero filled numpy array will be created based on the cell state size. 903 904 Raises: 905 AttributeError: When the RNN layer is not stateful. 906 ValueError: When the batch size of the RNN layer is unknown. 907 ValueError: When the input numpy array is not compatible with the RNN 908 layer state, either size wise or dtype wise. 909 """ 910 if not self.stateful: 911 raise AttributeError('Layer must be stateful.') 912 spec_shape = None 913 if self.input_spec is not None: 914 spec_shape = nest.flatten(self.input_spec[0])[0].shape 915 if spec_shape is None: 916 # It is possible to have spec shape to be None, eg when construct a RNN 917 # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know 918 # it has 3 dim input, but not its full shape spec before build(). 919 batch_size = None 920 else: 921 batch_size = spec_shape[1] if self.time_major else spec_shape[0] 922 if not batch_size: 923 raise ValueError('If a RNN is stateful, it needs to know ' 924 'its batch size. Specify the batch size ' 925 'of your input tensors: \n' 926 '- If using a Sequential model, ' 927 'specify the batch size by passing ' 928 'a `batch_input_shape` ' 929 'argument to your first layer.\n' 930 '- If using the functional API, specify ' 931 'the batch size by passing a ' 932 '`batch_shape` argument to your Input layer.') 933 # initialize state if None 934 if nest.flatten(self.states)[0] is None: 935 if getattr(self.cell, 'get_initial_state', None): 936 flat_init_state_values = nest.flatten(self.cell.get_initial_state( 937 inputs=None, batch_size=batch_size, 938 dtype=self.dtype or backend.floatx())) 939 else: 940 flat_init_state_values = nest.flatten(_generate_zero_filled_state( 941 batch_size, self.cell.state_size, self.dtype or backend.floatx())) 942 flat_states_variables = nest.map_structure( 943 backend.variable, flat_init_state_values) 944 self.states = nest.pack_sequence_as(self.cell.state_size, 945 flat_states_variables) 946 if not nest.is_nested(self.states): 947 self.states = [self.states] 948 elif states is None: 949 for state, size in zip(nest.flatten(self.states), 950 nest.flatten(self.cell.state_size)): 951 backend.set_value( 952 state, 953 np.zeros([batch_size] + tensor_shape.TensorShape(size).as_list())) 954 else: 955 flat_states = nest.flatten(self.states) 956 flat_input_states = nest.flatten(states) 957 if len(flat_input_states) != len(flat_states): 958 raise ValueError('Layer ' + self.name + ' expects ' + 959 str(len(flat_states)) + ' states, ' 960 'but it received ' + str(len(flat_input_states)) + 961 ' state values. Input received: ' + str(states)) 962 set_value_tuples = [] 963 for i, (value, state) in enumerate(zip(flat_input_states, 964 flat_states)): 965 if value.shape != state.shape: 966 raise ValueError( 967 'State ' + str(i) + ' is incompatible with layer ' + 968 self.name + ': expected shape=' + str( 969 (batch_size, state)) + ', found shape=' + str(value.shape)) 970 set_value_tuples.append((state, value)) 971 backend.batch_set_value(set_value_tuples) 972 973 def get_config(self): 974 config = { 975 'return_sequences': self.return_sequences, 976 'return_state': self.return_state, 977 'go_backwards': self.go_backwards, 978 'stateful': self.stateful, 979 'unroll': self.unroll, 980 'time_major': self.time_major 981 } 982 if self._num_constants: 983 config['num_constants'] = self._num_constants 984 if self.zero_output_for_mask: 985 config['zero_output_for_mask'] = self.zero_output_for_mask 986 987 config['cell'] = generic_utils.serialize_keras_object(self.cell) 988 base_config = super(RNN, self).get_config() 989 return dict(list(base_config.items()) + list(config.items())) 990 991 @classmethod 992 def from_config(cls, config, custom_objects=None): 993 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 994 cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) 995 num_constants = config.pop('num_constants', 0) 996 layer = cls(cell, **config) 997 layer._num_constants = num_constants 998 return layer 999 1000 @property 1001 def _trackable_saved_model_saver(self): 1002 return layer_serialization.RNNSavedModelSaver(self) 1003 1004 1005@keras_export('keras.layers.AbstractRNNCell') 1006class AbstractRNNCell(Layer): 1007 """Abstract object representing an RNN cell. 1008 1009 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 1010 for details about the usage of RNN API. 1011 1012 This is the base class for implementing RNN cells with custom behavior. 1013 1014 Every `RNNCell` must have the properties below and implement `call` with 1015 the signature `(output, next_state) = call(input, state)`. 1016 1017 Examples: 1018 1019 ```python 1020 class MinimalRNNCell(AbstractRNNCell): 1021 1022 def __init__(self, units, **kwargs): 1023 self.units = units 1024 super(MinimalRNNCell, self).__init__(**kwargs) 1025 1026 @property 1027 def state_size(self): 1028 return self.units 1029 1030 def build(self, input_shape): 1031 self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 1032 initializer='uniform', 1033 name='kernel') 1034 self.recurrent_kernel = self.add_weight( 1035 shape=(self.units, self.units), 1036 initializer='uniform', 1037 name='recurrent_kernel') 1038 self.built = True 1039 1040 def call(self, inputs, states): 1041 prev_output = states[0] 1042 h = backend.dot(inputs, self.kernel) 1043 output = h + backend.dot(prev_output, self.recurrent_kernel) 1044 return output, output 1045 ``` 1046 1047 This definition of cell differs from the definition used in the literature. 1048 In the literature, 'cell' refers to an object with a single scalar output. 1049 This definition refers to a horizontal array of such units. 1050 1051 An RNN cell, in the most abstract setting, is anything that has 1052 a state and performs some operation that takes a matrix of inputs. 1053 This operation results in an output matrix with `self.output_size` columns. 1054 If `self.state_size` is an integer, this operation also results in a new 1055 state matrix with `self.state_size` columns. If `self.state_size` is a 1056 (possibly nested tuple of) TensorShape object(s), then it should return a 1057 matching structure of Tensors having shape `[batch_size].concatenate(s)` 1058 for each `s` in `self.batch_size`. 1059 """ 1060 1061 def call(self, inputs, states): 1062 """The function that contains the logic for one RNN step calculation. 1063 1064 Args: 1065 inputs: the input tensor, which is a slide from the overall RNN input by 1066 the time dimension (usually the second dimension). 1067 states: the state tensor from previous step, which has the same shape 1068 as `(batch, state_size)`. In the case of timestep 0, it will be the 1069 initial state user specified, or zero filled tensor otherwise. 1070 1071 Returns: 1072 A tuple of two tensors: 1073 1. output tensor for the current timestep, with size `output_size`. 1074 2. state tensor for next step, which has the shape of `state_size`. 1075 """ 1076 raise NotImplementedError('Abstract method') 1077 1078 @property 1079 def state_size(self): 1080 """size(s) of state(s) used by this cell. 1081 1082 It can be represented by an Integer, a TensorShape or a tuple of Integers 1083 or TensorShapes. 1084 """ 1085 raise NotImplementedError('Abstract method') 1086 1087 @property 1088 def output_size(self): 1089 """Integer or TensorShape: size of outputs produced by this cell.""" 1090 raise NotImplementedError('Abstract method') 1091 1092 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 1093 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 1094 1095 1096@doc_controls.do_not_generate_docs 1097class DropoutRNNCellMixin(object): 1098 """Object that hold dropout related fields for RNN Cell. 1099 1100 This class is not a standalone RNN cell. It suppose to be used with a RNN cell 1101 by multiple inheritance. Any cell that mix with class should have following 1102 fields: 1103 dropout: a float number within range [0, 1). The ratio that the input 1104 tensor need to dropout. 1105 recurrent_dropout: a float number within range [0, 1). The ratio that the 1106 recurrent state weights need to dropout. 1107 This object will create and cache created dropout masks, and reuse them for 1108 the incoming data, so that the same mask is used for every batch input. 1109 """ 1110 1111 def __init__(self, *args, **kwargs): 1112 self._create_non_trackable_mask_cache() 1113 super(DropoutRNNCellMixin, self).__init__(*args, **kwargs) 1114 1115 @trackable.no_automatic_dependency_tracking 1116 def _create_non_trackable_mask_cache(self): 1117 """Create the cache for dropout and recurrent dropout mask. 1118 1119 Note that the following two masks will be used in "graph function" mode, 1120 e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask` 1121 tensors will be generated differently than in the "graph function" case, 1122 and they will be cached. 1123 1124 Also note that in graph mode, we still cache those masks only because the 1125 RNN could be created with `unroll=True`. In that case, the `cell.call()` 1126 function will be invoked multiple times, and we want to ensure same mask 1127 is used every time. 1128 1129 Also the caches are created without tracking. Since they are not picklable 1130 by python when deepcopy, we don't want `layer._obj_reference_counts_dict` 1131 to track it by default. 1132 """ 1133 self._dropout_mask_cache = backend.ContextValueCache( 1134 self._create_dropout_mask) 1135 self._recurrent_dropout_mask_cache = backend.ContextValueCache( 1136 self._create_recurrent_dropout_mask) 1137 1138 def reset_dropout_mask(self): 1139 """Reset the cached dropout masks if any. 1140 1141 This is important for the RNN layer to invoke this in it `call()` method so 1142 that the cached mask is cleared before calling the `cell.call()`. The mask 1143 should be cached across the timestep within the same batch, but shouldn't 1144 be cached between batches. Otherwise it will introduce unreasonable bias 1145 against certain index of data within the batch. 1146 """ 1147 self._dropout_mask_cache.clear() 1148 1149 def reset_recurrent_dropout_mask(self): 1150 """Reset the cached recurrent dropout masks if any. 1151 1152 This is important for the RNN layer to invoke this in it call() method so 1153 that the cached mask is cleared before calling the cell.call(). The mask 1154 should be cached across the timestep within the same batch, but shouldn't 1155 be cached between batches. Otherwise it will introduce unreasonable bias 1156 against certain index of data within the batch. 1157 """ 1158 self._recurrent_dropout_mask_cache.clear() 1159 1160 def _create_dropout_mask(self, inputs, training, count=1): 1161 return _generate_dropout_mask( 1162 array_ops.ones_like(inputs), 1163 self.dropout, 1164 training=training, 1165 count=count) 1166 1167 def _create_recurrent_dropout_mask(self, inputs, training, count=1): 1168 return _generate_dropout_mask( 1169 array_ops.ones_like(inputs), 1170 self.recurrent_dropout, 1171 training=training, 1172 count=count) 1173 1174 def get_dropout_mask_for_cell(self, inputs, training, count=1): 1175 """Get the dropout mask for RNN cell's input. 1176 1177 It will create mask based on context if there isn't any existing cached 1178 mask. If a new mask is generated, it will update the cache in the cell. 1179 1180 Args: 1181 inputs: The input tensor whose shape will be used to generate dropout 1182 mask. 1183 training: Boolean tensor, whether its in training mode, dropout will be 1184 ignored in non-training mode. 1185 count: Int, how many dropout mask will be generated. It is useful for cell 1186 that has internal weights fused together. 1187 Returns: 1188 List of mask tensor, generated or cached mask based on context. 1189 """ 1190 if self.dropout == 0: 1191 return None 1192 init_kwargs = dict(inputs=inputs, training=training, count=count) 1193 return self._dropout_mask_cache.setdefault(kwargs=init_kwargs) 1194 1195 def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1): 1196 """Get the recurrent dropout mask for RNN cell. 1197 1198 It will create mask based on context if there isn't any existing cached 1199 mask. If a new mask is generated, it will update the cache in the cell. 1200 1201 Args: 1202 inputs: The input tensor whose shape will be used to generate dropout 1203 mask. 1204 training: Boolean tensor, whether its in training mode, dropout will be 1205 ignored in non-training mode. 1206 count: Int, how many dropout mask will be generated. It is useful for cell 1207 that has internal weights fused together. 1208 Returns: 1209 List of mask tensor, generated or cached mask based on context. 1210 """ 1211 if self.recurrent_dropout == 0: 1212 return None 1213 init_kwargs = dict(inputs=inputs, training=training, count=count) 1214 return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs) 1215 1216 def __getstate__(self): 1217 # Used for deepcopy. The caching can't be pickled by python, since it will 1218 # contain tensor and graph. 1219 state = super(DropoutRNNCellMixin, self).__getstate__() 1220 state.pop('_dropout_mask_cache', None) 1221 state.pop('_recurrent_dropout_mask_cache', None) 1222 return state 1223 1224 def __setstate__(self, state): 1225 state['_dropout_mask_cache'] = backend.ContextValueCache( 1226 self._create_dropout_mask) 1227 state['_recurrent_dropout_mask_cache'] = backend.ContextValueCache( 1228 self._create_recurrent_dropout_mask) 1229 super(DropoutRNNCellMixin, self).__setstate__(state) 1230 1231 1232@keras_export('keras.layers.SimpleRNNCell') 1233class SimpleRNNCell(DropoutRNNCellMixin, Layer): 1234 """Cell class for SimpleRNN. 1235 1236 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 1237 for details about the usage of RNN API. 1238 1239 This class processes one step within the whole time sequence input, whereas 1240 `tf.keras.layer.SimpleRNN` processes the whole sequence. 1241 1242 Args: 1243 units: Positive integer, dimensionality of the output space. 1244 activation: Activation function to use. 1245 Default: hyperbolic tangent (`tanh`). 1246 If you pass `None`, no activation is applied 1247 (ie. "linear" activation: `a(x) = x`). 1248 use_bias: Boolean, (default `True`), whether the layer uses a bias vector. 1249 kernel_initializer: Initializer for the `kernel` weights matrix, 1250 used for the linear transformation of the inputs. Default: 1251 `glorot_uniform`. 1252 recurrent_initializer: Initializer for the `recurrent_kernel` 1253 weights matrix, used for the linear transformation of the recurrent state. 1254 Default: `orthogonal`. 1255 bias_initializer: Initializer for the bias vector. Default: `zeros`. 1256 kernel_regularizer: Regularizer function applied to the `kernel` weights 1257 matrix. Default: `None`. 1258 recurrent_regularizer: Regularizer function applied to the 1259 `recurrent_kernel` weights matrix. Default: `None`. 1260 bias_regularizer: Regularizer function applied to the bias vector. Default: 1261 `None`. 1262 kernel_constraint: Constraint function applied to the `kernel` weights 1263 matrix. Default: `None`. 1264 recurrent_constraint: Constraint function applied to the `recurrent_kernel` 1265 weights matrix. Default: `None`. 1266 bias_constraint: Constraint function applied to the bias vector. Default: 1267 `None`. 1268 dropout: Float between 0 and 1. Fraction of the units to drop for the linear 1269 transformation of the inputs. Default: 0. 1270 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for 1271 the linear transformation of the recurrent state. Default: 0. 1272 1273 Call arguments: 1274 inputs: A 2D tensor, with shape of `[batch, feature]`. 1275 states: A 2D tensor with shape of `[batch, units]`, which is the state from 1276 the previous time step. For timestep 0, the initial state provided by user 1277 will be feed to cell. 1278 training: Python boolean indicating whether the layer should behave in 1279 training mode or in inference mode. Only relevant when `dropout` or 1280 `recurrent_dropout` is used. 1281 1282 Examples: 1283 1284 ```python 1285 inputs = np.random.random([32, 10, 8]).astype(np.float32) 1286 rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4)) 1287 1288 output = rnn(inputs) # The output has shape `[32, 4]`. 1289 1290 rnn = tf.keras.layers.RNN( 1291 tf.keras.layers.SimpleRNNCell(4), 1292 return_sequences=True, 1293 return_state=True) 1294 1295 # whole_sequence_output has shape `[32, 10, 4]`. 1296 # final_state has shape `[32, 4]`. 1297 whole_sequence_output, final_state = rnn(inputs) 1298 ``` 1299 """ 1300 1301 def __init__(self, 1302 units, 1303 activation='tanh', 1304 use_bias=True, 1305 kernel_initializer='glorot_uniform', 1306 recurrent_initializer='orthogonal', 1307 bias_initializer='zeros', 1308 kernel_regularizer=None, 1309 recurrent_regularizer=None, 1310 bias_regularizer=None, 1311 kernel_constraint=None, 1312 recurrent_constraint=None, 1313 bias_constraint=None, 1314 dropout=0., 1315 recurrent_dropout=0., 1316 **kwargs): 1317 if units < 0: 1318 raise ValueError(f'Received an invalid value for units, expected ' 1319 f'a positive integer, got {units}.') 1320 # By default use cached variable under v2 mode, see b/143699808. 1321 if ops.executing_eagerly_outside_functions(): 1322 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 1323 else: 1324 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 1325 super(SimpleRNNCell, self).__init__(**kwargs) 1326 self.units = units 1327 self.activation = activations.get(activation) 1328 self.use_bias = use_bias 1329 1330 self.kernel_initializer = initializers.get(kernel_initializer) 1331 self.recurrent_initializer = initializers.get(recurrent_initializer) 1332 self.bias_initializer = initializers.get(bias_initializer) 1333 1334 self.kernel_regularizer = regularizers.get(kernel_regularizer) 1335 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 1336 self.bias_regularizer = regularizers.get(bias_regularizer) 1337 1338 self.kernel_constraint = constraints.get(kernel_constraint) 1339 self.recurrent_constraint = constraints.get(recurrent_constraint) 1340 self.bias_constraint = constraints.get(bias_constraint) 1341 1342 self.dropout = min(1., max(0., dropout)) 1343 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 1344 self.state_size = self.units 1345 self.output_size = self.units 1346 1347 @tf_utils.shape_type_conversion 1348 def build(self, input_shape): 1349 default_caching_device = _caching_device(self) 1350 self.kernel = self.add_weight( 1351 shape=(input_shape[-1], self.units), 1352 name='kernel', 1353 initializer=self.kernel_initializer, 1354 regularizer=self.kernel_regularizer, 1355 constraint=self.kernel_constraint, 1356 caching_device=default_caching_device) 1357 self.recurrent_kernel = self.add_weight( 1358 shape=(self.units, self.units), 1359 name='recurrent_kernel', 1360 initializer=self.recurrent_initializer, 1361 regularizer=self.recurrent_regularizer, 1362 constraint=self.recurrent_constraint, 1363 caching_device=default_caching_device) 1364 if self.use_bias: 1365 self.bias = self.add_weight( 1366 shape=(self.units,), 1367 name='bias', 1368 initializer=self.bias_initializer, 1369 regularizer=self.bias_regularizer, 1370 constraint=self.bias_constraint, 1371 caching_device=default_caching_device) 1372 else: 1373 self.bias = None 1374 self.built = True 1375 1376 def call(self, inputs, states, training=None): 1377 prev_output = states[0] if nest.is_nested(states) else states 1378 dp_mask = self.get_dropout_mask_for_cell(inputs, training) 1379 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 1380 prev_output, training) 1381 1382 if dp_mask is not None: 1383 h = backend.dot(inputs * dp_mask, self.kernel) 1384 else: 1385 h = backend.dot(inputs, self.kernel) 1386 if self.bias is not None: 1387 h = backend.bias_add(h, self.bias) 1388 1389 if rec_dp_mask is not None: 1390 prev_output = prev_output * rec_dp_mask 1391 output = h + backend.dot(prev_output, self.recurrent_kernel) 1392 if self.activation is not None: 1393 output = self.activation(output) 1394 1395 new_state = [output] if nest.is_nested(states) else output 1396 return output, new_state 1397 1398 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 1399 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 1400 1401 def get_config(self): 1402 config = { 1403 'units': 1404 self.units, 1405 'activation': 1406 activations.serialize(self.activation), 1407 'use_bias': 1408 self.use_bias, 1409 'kernel_initializer': 1410 initializers.serialize(self.kernel_initializer), 1411 'recurrent_initializer': 1412 initializers.serialize(self.recurrent_initializer), 1413 'bias_initializer': 1414 initializers.serialize(self.bias_initializer), 1415 'kernel_regularizer': 1416 regularizers.serialize(self.kernel_regularizer), 1417 'recurrent_regularizer': 1418 regularizers.serialize(self.recurrent_regularizer), 1419 'bias_regularizer': 1420 regularizers.serialize(self.bias_regularizer), 1421 'kernel_constraint': 1422 constraints.serialize(self.kernel_constraint), 1423 'recurrent_constraint': 1424 constraints.serialize(self.recurrent_constraint), 1425 'bias_constraint': 1426 constraints.serialize(self.bias_constraint), 1427 'dropout': 1428 self.dropout, 1429 'recurrent_dropout': 1430 self.recurrent_dropout 1431 } 1432 config.update(_config_for_enable_caching_device(self)) 1433 base_config = super(SimpleRNNCell, self).get_config() 1434 return dict(list(base_config.items()) + list(config.items())) 1435 1436 1437@keras_export('keras.layers.SimpleRNN') 1438class SimpleRNN(RNN): 1439 """Fully-connected RNN where the output is to be fed back to input. 1440 1441 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 1442 for details about the usage of RNN API. 1443 1444 Args: 1445 units: Positive integer, dimensionality of the output space. 1446 activation: Activation function to use. 1447 Default: hyperbolic tangent (`tanh`). 1448 If you pass None, no activation is applied 1449 (ie. "linear" activation: `a(x) = x`). 1450 use_bias: Boolean, (default `True`), whether the layer uses a bias vector. 1451 kernel_initializer: Initializer for the `kernel` weights matrix, 1452 used for the linear transformation of the inputs. Default: 1453 `glorot_uniform`. 1454 recurrent_initializer: Initializer for the `recurrent_kernel` 1455 weights matrix, used for the linear transformation of the recurrent state. 1456 Default: `orthogonal`. 1457 bias_initializer: Initializer for the bias vector. Default: `zeros`. 1458 kernel_regularizer: Regularizer function applied to the `kernel` weights 1459 matrix. Default: `None`. 1460 recurrent_regularizer: Regularizer function applied to the 1461 `recurrent_kernel` weights matrix. Default: `None`. 1462 bias_regularizer: Regularizer function applied to the bias vector. Default: 1463 `None`. 1464 activity_regularizer: Regularizer function applied to the output of the 1465 layer (its "activation"). Default: `None`. 1466 kernel_constraint: Constraint function applied to the `kernel` weights 1467 matrix. Default: `None`. 1468 recurrent_constraint: Constraint function applied to the `recurrent_kernel` 1469 weights matrix. Default: `None`. 1470 bias_constraint: Constraint function applied to the bias vector. Default: 1471 `None`. 1472 dropout: Float between 0 and 1. 1473 Fraction of the units to drop for the linear transformation of the inputs. 1474 Default: 0. 1475 recurrent_dropout: Float between 0 and 1. 1476 Fraction of the units to drop for the linear transformation of the 1477 recurrent state. Default: 0. 1478 return_sequences: Boolean. Whether to return the last output 1479 in the output sequence, or the full sequence. Default: `False`. 1480 return_state: Boolean. Whether to return the last state 1481 in addition to the output. Default: `False` 1482 go_backwards: Boolean (default False). 1483 If True, process the input sequence backwards and return the 1484 reversed sequence. 1485 stateful: Boolean (default False). If True, the last state 1486 for each sample at index i in a batch will be used as initial 1487 state for the sample of index i in the following batch. 1488 unroll: Boolean (default False). 1489 If True, the network will be unrolled, 1490 else a symbolic loop will be used. 1491 Unrolling can speed-up a RNN, 1492 although it tends to be more memory-intensive. 1493 Unrolling is only suitable for short sequences. 1494 1495 Call arguments: 1496 inputs: A 3D tensor, with shape `[batch, timesteps, feature]`. 1497 mask: Binary tensor of shape `[batch, timesteps]` indicating whether 1498 a given timestep should be masked. An individual `True` entry indicates 1499 that the corresponding timestep should be utilized, while a `False` entry 1500 indicates that the corresponding timestep should be ignored. 1501 training: Python boolean indicating whether the layer should behave in 1502 training mode or in inference mode. This argument is passed to the cell 1503 when calling it. This is only relevant if `dropout` or 1504 `recurrent_dropout` is used. 1505 initial_state: List of initial state tensors to be passed to the first 1506 call of the cell. 1507 1508 Examples: 1509 1510 ```python 1511 inputs = np.random.random([32, 10, 8]).astype(np.float32) 1512 simple_rnn = tf.keras.layers.SimpleRNN(4) 1513 1514 output = simple_rnn(inputs) # The output has shape `[32, 4]`. 1515 1516 simple_rnn = tf.keras.layers.SimpleRNN( 1517 4, return_sequences=True, return_state=True) 1518 1519 # whole_sequence_output has shape `[32, 10, 4]`. 1520 # final_state has shape `[32, 4]`. 1521 whole_sequence_output, final_state = simple_rnn(inputs) 1522 ``` 1523 """ 1524 1525 def __init__(self, 1526 units, 1527 activation='tanh', 1528 use_bias=True, 1529 kernel_initializer='glorot_uniform', 1530 recurrent_initializer='orthogonal', 1531 bias_initializer='zeros', 1532 kernel_regularizer=None, 1533 recurrent_regularizer=None, 1534 bias_regularizer=None, 1535 activity_regularizer=None, 1536 kernel_constraint=None, 1537 recurrent_constraint=None, 1538 bias_constraint=None, 1539 dropout=0., 1540 recurrent_dropout=0., 1541 return_sequences=False, 1542 return_state=False, 1543 go_backwards=False, 1544 stateful=False, 1545 unroll=False, 1546 **kwargs): 1547 if 'implementation' in kwargs: 1548 kwargs.pop('implementation') 1549 logging.warning('The `implementation` argument ' 1550 'in `SimpleRNN` has been deprecated. ' 1551 'Please remove it from your layer call.') 1552 if 'enable_caching_device' in kwargs: 1553 cell_kwargs = {'enable_caching_device': 1554 kwargs.pop('enable_caching_device')} 1555 else: 1556 cell_kwargs = {} 1557 cell = SimpleRNNCell( 1558 units, 1559 activation=activation, 1560 use_bias=use_bias, 1561 kernel_initializer=kernel_initializer, 1562 recurrent_initializer=recurrent_initializer, 1563 bias_initializer=bias_initializer, 1564 kernel_regularizer=kernel_regularizer, 1565 recurrent_regularizer=recurrent_regularizer, 1566 bias_regularizer=bias_regularizer, 1567 kernel_constraint=kernel_constraint, 1568 recurrent_constraint=recurrent_constraint, 1569 bias_constraint=bias_constraint, 1570 dropout=dropout, 1571 recurrent_dropout=recurrent_dropout, 1572 dtype=kwargs.get('dtype'), 1573 trainable=kwargs.get('trainable', True), 1574 **cell_kwargs) 1575 super(SimpleRNN, self).__init__( 1576 cell, 1577 return_sequences=return_sequences, 1578 return_state=return_state, 1579 go_backwards=go_backwards, 1580 stateful=stateful, 1581 unroll=unroll, 1582 **kwargs) 1583 self.activity_regularizer = regularizers.get(activity_regularizer) 1584 self.input_spec = [InputSpec(ndim=3)] 1585 1586 def call(self, inputs, mask=None, training=None, initial_state=None): 1587 return super(SimpleRNN, self).call( 1588 inputs, mask=mask, training=training, initial_state=initial_state) 1589 1590 @property 1591 def units(self): 1592 return self.cell.units 1593 1594 @property 1595 def activation(self): 1596 return self.cell.activation 1597 1598 @property 1599 def use_bias(self): 1600 return self.cell.use_bias 1601 1602 @property 1603 def kernel_initializer(self): 1604 return self.cell.kernel_initializer 1605 1606 @property 1607 def recurrent_initializer(self): 1608 return self.cell.recurrent_initializer 1609 1610 @property 1611 def bias_initializer(self): 1612 return self.cell.bias_initializer 1613 1614 @property 1615 def kernel_regularizer(self): 1616 return self.cell.kernel_regularizer 1617 1618 @property 1619 def recurrent_regularizer(self): 1620 return self.cell.recurrent_regularizer 1621 1622 @property 1623 def bias_regularizer(self): 1624 return self.cell.bias_regularizer 1625 1626 @property 1627 def kernel_constraint(self): 1628 return self.cell.kernel_constraint 1629 1630 @property 1631 def recurrent_constraint(self): 1632 return self.cell.recurrent_constraint 1633 1634 @property 1635 def bias_constraint(self): 1636 return self.cell.bias_constraint 1637 1638 @property 1639 def dropout(self): 1640 return self.cell.dropout 1641 1642 @property 1643 def recurrent_dropout(self): 1644 return self.cell.recurrent_dropout 1645 1646 def get_config(self): 1647 config = { 1648 'units': 1649 self.units, 1650 'activation': 1651 activations.serialize(self.activation), 1652 'use_bias': 1653 self.use_bias, 1654 'kernel_initializer': 1655 initializers.serialize(self.kernel_initializer), 1656 'recurrent_initializer': 1657 initializers.serialize(self.recurrent_initializer), 1658 'bias_initializer': 1659 initializers.serialize(self.bias_initializer), 1660 'kernel_regularizer': 1661 regularizers.serialize(self.kernel_regularizer), 1662 'recurrent_regularizer': 1663 regularizers.serialize(self.recurrent_regularizer), 1664 'bias_regularizer': 1665 regularizers.serialize(self.bias_regularizer), 1666 'activity_regularizer': 1667 regularizers.serialize(self.activity_regularizer), 1668 'kernel_constraint': 1669 constraints.serialize(self.kernel_constraint), 1670 'recurrent_constraint': 1671 constraints.serialize(self.recurrent_constraint), 1672 'bias_constraint': 1673 constraints.serialize(self.bias_constraint), 1674 'dropout': 1675 self.dropout, 1676 'recurrent_dropout': 1677 self.recurrent_dropout 1678 } 1679 base_config = super(SimpleRNN, self).get_config() 1680 config.update(_config_for_enable_caching_device(self.cell)) 1681 del base_config['cell'] 1682 return dict(list(base_config.items()) + list(config.items())) 1683 1684 @classmethod 1685 def from_config(cls, config): 1686 if 'implementation' in config: 1687 config.pop('implementation') 1688 return cls(**config) 1689 1690 1691@keras_export(v1=['keras.layers.GRUCell']) 1692class GRUCell(DropoutRNNCellMixin, Layer): 1693 """Cell class for the GRU layer. 1694 1695 Args: 1696 units: Positive integer, dimensionality of the output space. 1697 activation: Activation function to use. 1698 Default: hyperbolic tangent (`tanh`). 1699 If you pass None, no activation is applied 1700 (ie. "linear" activation: `a(x) = x`). 1701 recurrent_activation: Activation function to use 1702 for the recurrent step. 1703 Default: hard sigmoid (`hard_sigmoid`). 1704 If you pass `None`, no activation is applied 1705 (ie. "linear" activation: `a(x) = x`). 1706 use_bias: Boolean, whether the layer uses a bias vector. 1707 kernel_initializer: Initializer for the `kernel` weights matrix, 1708 used for the linear transformation of the inputs. 1709 recurrent_initializer: Initializer for the `recurrent_kernel` 1710 weights matrix, 1711 used for the linear transformation of the recurrent state. 1712 bias_initializer: Initializer for the bias vector. 1713 kernel_regularizer: Regularizer function applied to 1714 the `kernel` weights matrix. 1715 recurrent_regularizer: Regularizer function applied to 1716 the `recurrent_kernel` weights matrix. 1717 bias_regularizer: Regularizer function applied to the bias vector. 1718 kernel_constraint: Constraint function applied to 1719 the `kernel` weights matrix. 1720 recurrent_constraint: Constraint function applied to 1721 the `recurrent_kernel` weights matrix. 1722 bias_constraint: Constraint function applied to the bias vector. 1723 dropout: Float between 0 and 1. 1724 Fraction of the units to drop for the linear transformation of the inputs. 1725 recurrent_dropout: Float between 0 and 1. 1726 Fraction of the units to drop for 1727 the linear transformation of the recurrent state. 1728 reset_after: GRU convention (whether to apply reset gate after or 1729 before matrix multiplication). False = "before" (default), 1730 True = "after" (CuDNN compatible). 1731 1732 Call arguments: 1733 inputs: A 2D tensor. 1734 states: List of state tensors corresponding to the previous timestep. 1735 training: Python boolean indicating whether the layer should behave in 1736 training mode or in inference mode. Only relevant when `dropout` or 1737 `recurrent_dropout` is used. 1738 """ 1739 1740 def __init__(self, 1741 units, 1742 activation='tanh', 1743 recurrent_activation='hard_sigmoid', 1744 use_bias=True, 1745 kernel_initializer='glorot_uniform', 1746 recurrent_initializer='orthogonal', 1747 bias_initializer='zeros', 1748 kernel_regularizer=None, 1749 recurrent_regularizer=None, 1750 bias_regularizer=None, 1751 kernel_constraint=None, 1752 recurrent_constraint=None, 1753 bias_constraint=None, 1754 dropout=0., 1755 recurrent_dropout=0., 1756 reset_after=False, 1757 **kwargs): 1758 if units < 0: 1759 raise ValueError(f'Received an invalid value for units, expected ' 1760 f'a positive integer, got {units}.') 1761 # By default use cached variable under v2 mode, see b/143699808. 1762 if ops.executing_eagerly_outside_functions(): 1763 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 1764 else: 1765 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 1766 super(GRUCell, self).__init__(**kwargs) 1767 self.units = units 1768 self.activation = activations.get(activation) 1769 self.recurrent_activation = activations.get(recurrent_activation) 1770 self.use_bias = use_bias 1771 1772 self.kernel_initializer = initializers.get(kernel_initializer) 1773 self.recurrent_initializer = initializers.get(recurrent_initializer) 1774 self.bias_initializer = initializers.get(bias_initializer) 1775 1776 self.kernel_regularizer = regularizers.get(kernel_regularizer) 1777 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 1778 self.bias_regularizer = regularizers.get(bias_regularizer) 1779 1780 self.kernel_constraint = constraints.get(kernel_constraint) 1781 self.recurrent_constraint = constraints.get(recurrent_constraint) 1782 self.bias_constraint = constraints.get(bias_constraint) 1783 1784 self.dropout = min(1., max(0., dropout)) 1785 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 1786 1787 implementation = kwargs.pop('implementation', 1) 1788 if self.recurrent_dropout != 0 and implementation != 1: 1789 logging.debug(RECURRENT_DROPOUT_WARNING_MSG) 1790 self.implementation = 1 1791 else: 1792 self.implementation = implementation 1793 self.reset_after = reset_after 1794 self.state_size = self.units 1795 self.output_size = self.units 1796 1797 @tf_utils.shape_type_conversion 1798 def build(self, input_shape): 1799 input_dim = input_shape[-1] 1800 default_caching_device = _caching_device(self) 1801 self.kernel = self.add_weight( 1802 shape=(input_dim, self.units * 3), 1803 name='kernel', 1804 initializer=self.kernel_initializer, 1805 regularizer=self.kernel_regularizer, 1806 constraint=self.kernel_constraint, 1807 caching_device=default_caching_device) 1808 self.recurrent_kernel = self.add_weight( 1809 shape=(self.units, self.units * 3), 1810 name='recurrent_kernel', 1811 initializer=self.recurrent_initializer, 1812 regularizer=self.recurrent_regularizer, 1813 constraint=self.recurrent_constraint, 1814 caching_device=default_caching_device) 1815 1816 if self.use_bias: 1817 if not self.reset_after: 1818 bias_shape = (3 * self.units,) 1819 else: 1820 # separate biases for input and recurrent kernels 1821 # Note: the shape is intentionally different from CuDNNGRU biases 1822 # `(2 * 3 * self.units,)`, so that we can distinguish the classes 1823 # when loading and converting saved weights. 1824 bias_shape = (2, 3 * self.units) 1825 self.bias = self.add_weight(shape=bias_shape, 1826 name='bias', 1827 initializer=self.bias_initializer, 1828 regularizer=self.bias_regularizer, 1829 constraint=self.bias_constraint, 1830 caching_device=default_caching_device) 1831 else: 1832 self.bias = None 1833 self.built = True 1834 1835 def call(self, inputs, states, training=None): 1836 h_tm1 = states[0] if nest.is_nested(states) else states # previous memory 1837 1838 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) 1839 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 1840 h_tm1, training, count=3) 1841 1842 if self.use_bias: 1843 if not self.reset_after: 1844 input_bias, recurrent_bias = self.bias, None 1845 else: 1846 input_bias, recurrent_bias = array_ops.unstack(self.bias) 1847 1848 if self.implementation == 1: 1849 if 0. < self.dropout < 1.: 1850 inputs_z = inputs * dp_mask[0] 1851 inputs_r = inputs * dp_mask[1] 1852 inputs_h = inputs * dp_mask[2] 1853 else: 1854 inputs_z = inputs 1855 inputs_r = inputs 1856 inputs_h = inputs 1857 1858 x_z = backend.dot(inputs_z, self.kernel[:, :self.units]) 1859 x_r = backend.dot(inputs_r, self.kernel[:, self.units:self.units * 2]) 1860 x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2:]) 1861 1862 if self.use_bias: 1863 x_z = backend.bias_add(x_z, input_bias[:self.units]) 1864 x_r = backend.bias_add(x_r, input_bias[self.units: self.units * 2]) 1865 x_h = backend.bias_add(x_h, input_bias[self.units * 2:]) 1866 1867 if 0. < self.recurrent_dropout < 1.: 1868 h_tm1_z = h_tm1 * rec_dp_mask[0] 1869 h_tm1_r = h_tm1 * rec_dp_mask[1] 1870 h_tm1_h = h_tm1 * rec_dp_mask[2] 1871 else: 1872 h_tm1_z = h_tm1 1873 h_tm1_r = h_tm1 1874 h_tm1_h = h_tm1 1875 1876 recurrent_z = backend.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]) 1877 recurrent_r = backend.dot( 1878 h_tm1_r, self.recurrent_kernel[:, self.units:self.units * 2]) 1879 if self.reset_after and self.use_bias: 1880 recurrent_z = backend.bias_add(recurrent_z, recurrent_bias[:self.units]) 1881 recurrent_r = backend.bias_add( 1882 recurrent_r, recurrent_bias[self.units:self.units * 2]) 1883 1884 z = self.recurrent_activation(x_z + recurrent_z) 1885 r = self.recurrent_activation(x_r + recurrent_r) 1886 1887 # reset gate applied after/before matrix multiplication 1888 if self.reset_after: 1889 recurrent_h = backend.dot( 1890 h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) 1891 if self.use_bias: 1892 recurrent_h = backend.bias_add( 1893 recurrent_h, recurrent_bias[self.units * 2:]) 1894 recurrent_h = r * recurrent_h 1895 else: 1896 recurrent_h = backend.dot( 1897 r * h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) 1898 1899 hh = self.activation(x_h + recurrent_h) 1900 else: 1901 if 0. < self.dropout < 1.: 1902 inputs = inputs * dp_mask[0] 1903 1904 # inputs projected by all gate matrices at once 1905 matrix_x = backend.dot(inputs, self.kernel) 1906 if self.use_bias: 1907 # biases: bias_z_i, bias_r_i, bias_h_i 1908 matrix_x = backend.bias_add(matrix_x, input_bias) 1909 1910 x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1) 1911 1912 if self.reset_after: 1913 # hidden state projected by all gate matrices at once 1914 matrix_inner = backend.dot(h_tm1, self.recurrent_kernel) 1915 if self.use_bias: 1916 matrix_inner = backend.bias_add(matrix_inner, recurrent_bias) 1917 else: 1918 # hidden state projected separately for update/reset and new 1919 matrix_inner = backend.dot( 1920 h_tm1, self.recurrent_kernel[:, :2 * self.units]) 1921 1922 recurrent_z, recurrent_r, recurrent_h = array_ops.split( 1923 matrix_inner, [self.units, self.units, -1], axis=-1) 1924 1925 z = self.recurrent_activation(x_z + recurrent_z) 1926 r = self.recurrent_activation(x_r + recurrent_r) 1927 1928 if self.reset_after: 1929 recurrent_h = r * recurrent_h 1930 else: 1931 recurrent_h = backend.dot( 1932 r * h_tm1, self.recurrent_kernel[:, 2 * self.units:]) 1933 1934 hh = self.activation(x_h + recurrent_h) 1935 # previous and candidate state mixed by update gate 1936 h = z * h_tm1 + (1 - z) * hh 1937 new_state = [h] if nest.is_nested(states) else h 1938 return h, new_state 1939 1940 def get_config(self): 1941 config = { 1942 'units': self.units, 1943 'activation': activations.serialize(self.activation), 1944 'recurrent_activation': 1945 activations.serialize(self.recurrent_activation), 1946 'use_bias': self.use_bias, 1947 'kernel_initializer': initializers.serialize(self.kernel_initializer), 1948 'recurrent_initializer': 1949 initializers.serialize(self.recurrent_initializer), 1950 'bias_initializer': initializers.serialize(self.bias_initializer), 1951 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 1952 'recurrent_regularizer': 1953 regularizers.serialize(self.recurrent_regularizer), 1954 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 1955 'kernel_constraint': constraints.serialize(self.kernel_constraint), 1956 'recurrent_constraint': 1957 constraints.serialize(self.recurrent_constraint), 1958 'bias_constraint': constraints.serialize(self.bias_constraint), 1959 'dropout': self.dropout, 1960 'recurrent_dropout': self.recurrent_dropout, 1961 'implementation': self.implementation, 1962 'reset_after': self.reset_after 1963 } 1964 config.update(_config_for_enable_caching_device(self)) 1965 base_config = super(GRUCell, self).get_config() 1966 return dict(list(base_config.items()) + list(config.items())) 1967 1968 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 1969 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 1970 1971 1972@keras_export(v1=['keras.layers.GRU']) 1973class GRU(RNN): 1974 """Gated Recurrent Unit - Cho et al. 2014. 1975 1976 There are two variants. The default one is based on 1406.1078v3 and 1977 has reset gate applied to hidden state before matrix multiplication. The 1978 other one is based on original 1406.1078v1 and has the order reversed. 1979 1980 The second variant is compatible with CuDNNGRU (GPU-only) and allows 1981 inference on CPU. Thus it has separate biases for `kernel` and 1982 `recurrent_kernel`. Use `'reset_after'=True` and 1983 `recurrent_activation='sigmoid'`. 1984 1985 Args: 1986 units: Positive integer, dimensionality of the output space. 1987 activation: Activation function to use. 1988 Default: hyperbolic tangent (`tanh`). 1989 If you pass `None`, no activation is applied 1990 (ie. "linear" activation: `a(x) = x`). 1991 recurrent_activation: Activation function to use 1992 for the recurrent step. 1993 Default: hard sigmoid (`hard_sigmoid`). 1994 If you pass `None`, no activation is applied 1995 (ie. "linear" activation: `a(x) = x`). 1996 use_bias: Boolean, whether the layer uses a bias vector. 1997 kernel_initializer: Initializer for the `kernel` weights matrix, 1998 used for the linear transformation of the inputs. 1999 recurrent_initializer: Initializer for the `recurrent_kernel` 2000 weights matrix, used for the linear transformation of the recurrent state. 2001 bias_initializer: Initializer for the bias vector. 2002 kernel_regularizer: Regularizer function applied to 2003 the `kernel` weights matrix. 2004 recurrent_regularizer: Regularizer function applied to 2005 the `recurrent_kernel` weights matrix. 2006 bias_regularizer: Regularizer function applied to the bias vector. 2007 activity_regularizer: Regularizer function applied to 2008 the output of the layer (its "activation").. 2009 kernel_constraint: Constraint function applied to 2010 the `kernel` weights matrix. 2011 recurrent_constraint: Constraint function applied to 2012 the `recurrent_kernel` weights matrix. 2013 bias_constraint: Constraint function applied to the bias vector. 2014 dropout: Float between 0 and 1. 2015 Fraction of the units to drop for 2016 the linear transformation of the inputs. 2017 recurrent_dropout: Float between 0 and 1. 2018 Fraction of the units to drop for 2019 the linear transformation of the recurrent state. 2020 return_sequences: Boolean. Whether to return the last output 2021 in the output sequence, or the full sequence. 2022 return_state: Boolean. Whether to return the last state 2023 in addition to the output. 2024 go_backwards: Boolean (default False). 2025 If True, process the input sequence backwards and return the 2026 reversed sequence. 2027 stateful: Boolean (default False). If True, the last state 2028 for each sample at index i in a batch will be used as initial 2029 state for the sample of index i in the following batch. 2030 unroll: Boolean (default False). 2031 If True, the network will be unrolled, 2032 else a symbolic loop will be used. 2033 Unrolling can speed-up a RNN, 2034 although it tends to be more memory-intensive. 2035 Unrolling is only suitable for short sequences. 2036 time_major: The shape format of the `inputs` and `outputs` tensors. 2037 If True, the inputs and outputs will be in shape 2038 `(timesteps, batch, ...)`, whereas in the False case, it will be 2039 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 2040 efficient because it avoids transposes at the beginning and end of the 2041 RNN calculation. However, most TensorFlow data is batch-major, so by 2042 default this function accepts input and emits output in batch-major 2043 form. 2044 reset_after: GRU convention (whether to apply reset gate after or 2045 before matrix multiplication). False = "before" (default), 2046 True = "after" (CuDNN compatible). 2047 2048 Call arguments: 2049 inputs: A 3D tensor. 2050 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 2051 a given timestep should be masked. An individual `True` entry indicates 2052 that the corresponding timestep should be utilized, while a `False` 2053 entry indicates that the corresponding timestep should be ignored. 2054 training: Python boolean indicating whether the layer should behave in 2055 training mode or in inference mode. This argument is passed to the cell 2056 when calling it. This is only relevant if `dropout` or 2057 `recurrent_dropout` is used. 2058 initial_state: List of initial state tensors to be passed to the first 2059 call of the cell. 2060 """ 2061 2062 def __init__(self, 2063 units, 2064 activation='tanh', 2065 recurrent_activation='hard_sigmoid', 2066 use_bias=True, 2067 kernel_initializer='glorot_uniform', 2068 recurrent_initializer='orthogonal', 2069 bias_initializer='zeros', 2070 kernel_regularizer=None, 2071 recurrent_regularizer=None, 2072 bias_regularizer=None, 2073 activity_regularizer=None, 2074 kernel_constraint=None, 2075 recurrent_constraint=None, 2076 bias_constraint=None, 2077 dropout=0., 2078 recurrent_dropout=0., 2079 return_sequences=False, 2080 return_state=False, 2081 go_backwards=False, 2082 stateful=False, 2083 unroll=False, 2084 reset_after=False, 2085 **kwargs): 2086 implementation = kwargs.pop('implementation', 1) 2087 if implementation == 0: 2088 logging.warning('`implementation=0` has been deprecated, ' 2089 'and now defaults to `implementation=1`.' 2090 'Please update your layer call.') 2091 if 'enable_caching_device' in kwargs: 2092 cell_kwargs = {'enable_caching_device': 2093 kwargs.pop('enable_caching_device')} 2094 else: 2095 cell_kwargs = {} 2096 cell = GRUCell( 2097 units, 2098 activation=activation, 2099 recurrent_activation=recurrent_activation, 2100 use_bias=use_bias, 2101 kernel_initializer=kernel_initializer, 2102 recurrent_initializer=recurrent_initializer, 2103 bias_initializer=bias_initializer, 2104 kernel_regularizer=kernel_regularizer, 2105 recurrent_regularizer=recurrent_regularizer, 2106 bias_regularizer=bias_regularizer, 2107 kernel_constraint=kernel_constraint, 2108 recurrent_constraint=recurrent_constraint, 2109 bias_constraint=bias_constraint, 2110 dropout=dropout, 2111 recurrent_dropout=recurrent_dropout, 2112 implementation=implementation, 2113 reset_after=reset_after, 2114 dtype=kwargs.get('dtype'), 2115 trainable=kwargs.get('trainable', True), 2116 **cell_kwargs) 2117 super(GRU, self).__init__( 2118 cell, 2119 return_sequences=return_sequences, 2120 return_state=return_state, 2121 go_backwards=go_backwards, 2122 stateful=stateful, 2123 unroll=unroll, 2124 **kwargs) 2125 self.activity_regularizer = regularizers.get(activity_regularizer) 2126 self.input_spec = [InputSpec(ndim=3)] 2127 2128 def call(self, inputs, mask=None, training=None, initial_state=None): 2129 return super(GRU, self).call( 2130 inputs, mask=mask, training=training, initial_state=initial_state) 2131 2132 @property 2133 def units(self): 2134 return self.cell.units 2135 2136 @property 2137 def activation(self): 2138 return self.cell.activation 2139 2140 @property 2141 def recurrent_activation(self): 2142 return self.cell.recurrent_activation 2143 2144 @property 2145 def use_bias(self): 2146 return self.cell.use_bias 2147 2148 @property 2149 def kernel_initializer(self): 2150 return self.cell.kernel_initializer 2151 2152 @property 2153 def recurrent_initializer(self): 2154 return self.cell.recurrent_initializer 2155 2156 @property 2157 def bias_initializer(self): 2158 return self.cell.bias_initializer 2159 2160 @property 2161 def kernel_regularizer(self): 2162 return self.cell.kernel_regularizer 2163 2164 @property 2165 def recurrent_regularizer(self): 2166 return self.cell.recurrent_regularizer 2167 2168 @property 2169 def bias_regularizer(self): 2170 return self.cell.bias_regularizer 2171 2172 @property 2173 def kernel_constraint(self): 2174 return self.cell.kernel_constraint 2175 2176 @property 2177 def recurrent_constraint(self): 2178 return self.cell.recurrent_constraint 2179 2180 @property 2181 def bias_constraint(self): 2182 return self.cell.bias_constraint 2183 2184 @property 2185 def dropout(self): 2186 return self.cell.dropout 2187 2188 @property 2189 def recurrent_dropout(self): 2190 return self.cell.recurrent_dropout 2191 2192 @property 2193 def implementation(self): 2194 return self.cell.implementation 2195 2196 @property 2197 def reset_after(self): 2198 return self.cell.reset_after 2199 2200 def get_config(self): 2201 config = { 2202 'units': 2203 self.units, 2204 'activation': 2205 activations.serialize(self.activation), 2206 'recurrent_activation': 2207 activations.serialize(self.recurrent_activation), 2208 'use_bias': 2209 self.use_bias, 2210 'kernel_initializer': 2211 initializers.serialize(self.kernel_initializer), 2212 'recurrent_initializer': 2213 initializers.serialize(self.recurrent_initializer), 2214 'bias_initializer': 2215 initializers.serialize(self.bias_initializer), 2216 'kernel_regularizer': 2217 regularizers.serialize(self.kernel_regularizer), 2218 'recurrent_regularizer': 2219 regularizers.serialize(self.recurrent_regularizer), 2220 'bias_regularizer': 2221 regularizers.serialize(self.bias_regularizer), 2222 'activity_regularizer': 2223 regularizers.serialize(self.activity_regularizer), 2224 'kernel_constraint': 2225 constraints.serialize(self.kernel_constraint), 2226 'recurrent_constraint': 2227 constraints.serialize(self.recurrent_constraint), 2228 'bias_constraint': 2229 constraints.serialize(self.bias_constraint), 2230 'dropout': 2231 self.dropout, 2232 'recurrent_dropout': 2233 self.recurrent_dropout, 2234 'implementation': 2235 self.implementation, 2236 'reset_after': 2237 self.reset_after 2238 } 2239 config.update(_config_for_enable_caching_device(self.cell)) 2240 base_config = super(GRU, self).get_config() 2241 del base_config['cell'] 2242 return dict(list(base_config.items()) + list(config.items())) 2243 2244 @classmethod 2245 def from_config(cls, config): 2246 if 'implementation' in config and config['implementation'] == 0: 2247 config['implementation'] = 1 2248 return cls(**config) 2249 2250 2251@keras_export(v1=['keras.layers.LSTMCell']) 2252class LSTMCell(DropoutRNNCellMixin, Layer): 2253 """Cell class for the LSTM layer. 2254 2255 Args: 2256 units: Positive integer, dimensionality of the output space. 2257 activation: Activation function to use. 2258 Default: hyperbolic tangent (`tanh`). 2259 If you pass `None`, no activation is applied 2260 (ie. "linear" activation: `a(x) = x`). 2261 recurrent_activation: Activation function to use 2262 for the recurrent step. 2263 Default: hard sigmoid (`hard_sigmoid`). 2264 If you pass `None`, no activation is applied 2265 (ie. "linear" activation: `a(x) = x`). 2266 use_bias: Boolean, whether the layer uses a bias vector. 2267 kernel_initializer: Initializer for the `kernel` weights matrix, 2268 used for the linear transformation of the inputs. 2269 recurrent_initializer: Initializer for the `recurrent_kernel` 2270 weights matrix, 2271 used for the linear transformation of the recurrent state. 2272 bias_initializer: Initializer for the bias vector. 2273 unit_forget_bias: Boolean. 2274 If True, add 1 to the bias of the forget gate at initialization. 2275 Setting it to true will also force `bias_initializer="zeros"`. 2276 This is recommended in [Jozefowicz et al., 2015]( 2277 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 2278 kernel_regularizer: Regularizer function applied to 2279 the `kernel` weights matrix. 2280 recurrent_regularizer: Regularizer function applied to 2281 the `recurrent_kernel` weights matrix. 2282 bias_regularizer: Regularizer function applied to the bias vector. 2283 kernel_constraint: Constraint function applied to 2284 the `kernel` weights matrix. 2285 recurrent_constraint: Constraint function applied to 2286 the `recurrent_kernel` weights matrix. 2287 bias_constraint: Constraint function applied to the bias vector. 2288 dropout: Float between 0 and 1. 2289 Fraction of the units to drop for 2290 the linear transformation of the inputs. 2291 recurrent_dropout: Float between 0 and 1. 2292 Fraction of the units to drop for 2293 the linear transformation of the recurrent state. 2294 2295 Call arguments: 2296 inputs: A 2D tensor. 2297 states: List of state tensors corresponding to the previous timestep. 2298 training: Python boolean indicating whether the layer should behave in 2299 training mode or in inference mode. Only relevant when `dropout` or 2300 `recurrent_dropout` is used. 2301 """ 2302 2303 def __init__(self, 2304 units, 2305 activation='tanh', 2306 recurrent_activation='hard_sigmoid', 2307 use_bias=True, 2308 kernel_initializer='glorot_uniform', 2309 recurrent_initializer='orthogonal', 2310 bias_initializer='zeros', 2311 unit_forget_bias=True, 2312 kernel_regularizer=None, 2313 recurrent_regularizer=None, 2314 bias_regularizer=None, 2315 kernel_constraint=None, 2316 recurrent_constraint=None, 2317 bias_constraint=None, 2318 dropout=0., 2319 recurrent_dropout=0., 2320 **kwargs): 2321 if units < 0: 2322 raise ValueError(f'Received an invalid value for units, expected ' 2323 f'a positive integer, got {units}.') 2324 # By default use cached variable under v2 mode, see b/143699808. 2325 if ops.executing_eagerly_outside_functions(): 2326 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 2327 else: 2328 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 2329 super(LSTMCell, self).__init__(**kwargs) 2330 self.units = units 2331 self.activation = activations.get(activation) 2332 self.recurrent_activation = activations.get(recurrent_activation) 2333 self.use_bias = use_bias 2334 2335 self.kernel_initializer = initializers.get(kernel_initializer) 2336 self.recurrent_initializer = initializers.get(recurrent_initializer) 2337 self.bias_initializer = initializers.get(bias_initializer) 2338 self.unit_forget_bias = unit_forget_bias 2339 2340 self.kernel_regularizer = regularizers.get(kernel_regularizer) 2341 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 2342 self.bias_regularizer = regularizers.get(bias_regularizer) 2343 2344 self.kernel_constraint = constraints.get(kernel_constraint) 2345 self.recurrent_constraint = constraints.get(recurrent_constraint) 2346 self.bias_constraint = constraints.get(bias_constraint) 2347 2348 self.dropout = min(1., max(0., dropout)) 2349 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 2350 implementation = kwargs.pop('implementation', 1) 2351 if self.recurrent_dropout != 0 and implementation != 1: 2352 logging.debug(RECURRENT_DROPOUT_WARNING_MSG) 2353 self.implementation = 1 2354 else: 2355 self.implementation = implementation 2356 self.state_size = [self.units, self.units] 2357 self.output_size = self.units 2358 2359 @tf_utils.shape_type_conversion 2360 def build(self, input_shape): 2361 default_caching_device = _caching_device(self) 2362 input_dim = input_shape[-1] 2363 self.kernel = self.add_weight( 2364 shape=(input_dim, self.units * 4), 2365 name='kernel', 2366 initializer=self.kernel_initializer, 2367 regularizer=self.kernel_regularizer, 2368 constraint=self.kernel_constraint, 2369 caching_device=default_caching_device) 2370 self.recurrent_kernel = self.add_weight( 2371 shape=(self.units, self.units * 4), 2372 name='recurrent_kernel', 2373 initializer=self.recurrent_initializer, 2374 regularizer=self.recurrent_regularizer, 2375 constraint=self.recurrent_constraint, 2376 caching_device=default_caching_device) 2377 2378 if self.use_bias: 2379 if self.unit_forget_bias: 2380 2381 def bias_initializer(_, *args, **kwargs): 2382 return backend.concatenate([ 2383 self.bias_initializer((self.units,), *args, **kwargs), 2384 initializers.get('ones')((self.units,), *args, **kwargs), 2385 self.bias_initializer((self.units * 2,), *args, **kwargs), 2386 ]) 2387 else: 2388 bias_initializer = self.bias_initializer 2389 self.bias = self.add_weight( 2390 shape=(self.units * 4,), 2391 name='bias', 2392 initializer=bias_initializer, 2393 regularizer=self.bias_regularizer, 2394 constraint=self.bias_constraint, 2395 caching_device=default_caching_device) 2396 else: 2397 self.bias = None 2398 self.built = True 2399 2400 def _compute_carry_and_output(self, x, h_tm1, c_tm1): 2401 """Computes carry and output using split kernels.""" 2402 x_i, x_f, x_c, x_o = x 2403 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 2404 i = self.recurrent_activation( 2405 x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) 2406 f = self.recurrent_activation(x_f + backend.dot( 2407 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])) 2408 c = f * c_tm1 + i * self.activation(x_c + backend.dot( 2409 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) 2410 o = self.recurrent_activation( 2411 x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) 2412 return c, o 2413 2414 def _compute_carry_and_output_fused(self, z, c_tm1): 2415 """Computes carry and output using fused kernels.""" 2416 z0, z1, z2, z3 = z 2417 i = self.recurrent_activation(z0) 2418 f = self.recurrent_activation(z1) 2419 c = f * c_tm1 + i * self.activation(z2) 2420 o = self.recurrent_activation(z3) 2421 return c, o 2422 2423 def call(self, inputs, states, training=None): 2424 h_tm1 = states[0] # previous memory state 2425 c_tm1 = states[1] # previous carry state 2426 2427 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) 2428 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 2429 h_tm1, training, count=4) 2430 2431 if self.implementation == 1: 2432 if 0 < self.dropout < 1.: 2433 inputs_i = inputs * dp_mask[0] 2434 inputs_f = inputs * dp_mask[1] 2435 inputs_c = inputs * dp_mask[2] 2436 inputs_o = inputs * dp_mask[3] 2437 else: 2438 inputs_i = inputs 2439 inputs_f = inputs 2440 inputs_c = inputs 2441 inputs_o = inputs 2442 k_i, k_f, k_c, k_o = array_ops.split( 2443 self.kernel, num_or_size_splits=4, axis=1) 2444 x_i = backend.dot(inputs_i, k_i) 2445 x_f = backend.dot(inputs_f, k_f) 2446 x_c = backend.dot(inputs_c, k_c) 2447 x_o = backend.dot(inputs_o, k_o) 2448 if self.use_bias: 2449 b_i, b_f, b_c, b_o = array_ops.split( 2450 self.bias, num_or_size_splits=4, axis=0) 2451 x_i = backend.bias_add(x_i, b_i) 2452 x_f = backend.bias_add(x_f, b_f) 2453 x_c = backend.bias_add(x_c, b_c) 2454 x_o = backend.bias_add(x_o, b_o) 2455 2456 if 0 < self.recurrent_dropout < 1.: 2457 h_tm1_i = h_tm1 * rec_dp_mask[0] 2458 h_tm1_f = h_tm1 * rec_dp_mask[1] 2459 h_tm1_c = h_tm1 * rec_dp_mask[2] 2460 h_tm1_o = h_tm1 * rec_dp_mask[3] 2461 else: 2462 h_tm1_i = h_tm1 2463 h_tm1_f = h_tm1 2464 h_tm1_c = h_tm1 2465 h_tm1_o = h_tm1 2466 x = (x_i, x_f, x_c, x_o) 2467 h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) 2468 c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) 2469 else: 2470 if 0. < self.dropout < 1.: 2471 inputs = inputs * dp_mask[0] 2472 z = backend.dot(inputs, self.kernel) 2473 z += backend.dot(h_tm1, self.recurrent_kernel) 2474 if self.use_bias: 2475 z = backend.bias_add(z, self.bias) 2476 2477 z = array_ops.split(z, num_or_size_splits=4, axis=1) 2478 c, o = self._compute_carry_and_output_fused(z, c_tm1) 2479 2480 h = o * self.activation(c) 2481 return h, [h, c] 2482 2483 def get_config(self): 2484 config = { 2485 'units': 2486 self.units, 2487 'activation': 2488 activations.serialize(self.activation), 2489 'recurrent_activation': 2490 activations.serialize(self.recurrent_activation), 2491 'use_bias': 2492 self.use_bias, 2493 'kernel_initializer': 2494 initializers.serialize(self.kernel_initializer), 2495 'recurrent_initializer': 2496 initializers.serialize(self.recurrent_initializer), 2497 'bias_initializer': 2498 initializers.serialize(self.bias_initializer), 2499 'unit_forget_bias': 2500 self.unit_forget_bias, 2501 'kernel_regularizer': 2502 regularizers.serialize(self.kernel_regularizer), 2503 'recurrent_regularizer': 2504 regularizers.serialize(self.recurrent_regularizer), 2505 'bias_regularizer': 2506 regularizers.serialize(self.bias_regularizer), 2507 'kernel_constraint': 2508 constraints.serialize(self.kernel_constraint), 2509 'recurrent_constraint': 2510 constraints.serialize(self.recurrent_constraint), 2511 'bias_constraint': 2512 constraints.serialize(self.bias_constraint), 2513 'dropout': 2514 self.dropout, 2515 'recurrent_dropout': 2516 self.recurrent_dropout, 2517 'implementation': 2518 self.implementation 2519 } 2520 config.update(_config_for_enable_caching_device(self)) 2521 base_config = super(LSTMCell, self).get_config() 2522 return dict(list(base_config.items()) + list(config.items())) 2523 2524 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 2525 return list(_generate_zero_filled_state_for_cell( 2526 self, inputs, batch_size, dtype)) 2527 2528 2529@keras_export('keras.experimental.PeepholeLSTMCell') 2530class PeepholeLSTMCell(LSTMCell): 2531 """Equivalent to LSTMCell class but adds peephole connections. 2532 2533 Peephole connections allow the gates to utilize the previous internal state as 2534 well as the previous hidden state (which is what LSTMCell is limited to). 2535 This allows PeepholeLSTMCell to better learn precise timings over LSTMCell. 2536 2537 From [Gers et al., 2002]( 2538 http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf): 2539 2540 "We find that LSTM augmented by 'peephole connections' from its internal 2541 cells to its multiplicative gates can learn the fine distinction between 2542 sequences of spikes spaced either 50 or 49 time steps apart without the help 2543 of any short training exemplars." 2544 2545 The peephole implementation is based on: 2546 2547 [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf) 2548 2549 Example: 2550 2551 ```python 2552 # Create 2 PeepholeLSTMCells 2553 peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]] 2554 # Create a layer composed sequentially of the peephole LSTM cells. 2555 layer = RNN(peephole_lstm_cells) 2556 input = keras.Input((timesteps, input_dim)) 2557 output = layer(input) 2558 ``` 2559 """ 2560 2561 def __init__(self, 2562 units, 2563 activation='tanh', 2564 recurrent_activation='hard_sigmoid', 2565 use_bias=True, 2566 kernel_initializer='glorot_uniform', 2567 recurrent_initializer='orthogonal', 2568 bias_initializer='zeros', 2569 unit_forget_bias=True, 2570 kernel_regularizer=None, 2571 recurrent_regularizer=None, 2572 bias_regularizer=None, 2573 kernel_constraint=None, 2574 recurrent_constraint=None, 2575 bias_constraint=None, 2576 dropout=0., 2577 recurrent_dropout=0., 2578 **kwargs): 2579 warnings.warn('`tf.keras.experimental.PeepholeLSTMCell` is deprecated ' 2580 'and will be removed in a future version. ' 2581 'Please use tensorflow_addons.rnn.PeepholeLSTMCell ' 2582 'instead.') 2583 super(PeepholeLSTMCell, self).__init__( 2584 units=units, 2585 activation=activation, 2586 recurrent_activation=recurrent_activation, 2587 use_bias=use_bias, 2588 kernel_initializer=kernel_initializer, 2589 recurrent_initializer=recurrent_initializer, 2590 bias_initializer=bias_initializer, 2591 unit_forget_bias=unit_forget_bias, 2592 kernel_regularizer=kernel_regularizer, 2593 recurrent_regularizer=recurrent_regularizer, 2594 bias_regularizer=bias_regularizer, 2595 kernel_constraint=kernel_constraint, 2596 recurrent_constraint=recurrent_constraint, 2597 bias_constraint=bias_constraint, 2598 dropout=dropout, 2599 recurrent_dropout=recurrent_dropout, 2600 implementation=kwargs.pop('implementation', 1), 2601 **kwargs) 2602 2603 def build(self, input_shape): 2604 super(PeepholeLSTMCell, self).build(input_shape) 2605 # The following are the weight matrices for the peephole connections. These 2606 # are multiplied with the previous internal state during the computation of 2607 # carry and output. 2608 self.input_gate_peephole_weights = self.add_weight( 2609 shape=(self.units,), 2610 name='input_gate_peephole_weights', 2611 initializer=self.kernel_initializer) 2612 self.forget_gate_peephole_weights = self.add_weight( 2613 shape=(self.units,), 2614 name='forget_gate_peephole_weights', 2615 initializer=self.kernel_initializer) 2616 self.output_gate_peephole_weights = self.add_weight( 2617 shape=(self.units,), 2618 name='output_gate_peephole_weights', 2619 initializer=self.kernel_initializer) 2620 2621 def _compute_carry_and_output(self, x, h_tm1, c_tm1): 2622 x_i, x_f, x_c, x_o = x 2623 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 2624 i = self.recurrent_activation( 2625 x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) + 2626 self.input_gate_peephole_weights * c_tm1) 2627 f = self.recurrent_activation(x_f + backend.dot( 2628 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) + 2629 self.forget_gate_peephole_weights * c_tm1) 2630 c = f * c_tm1 + i * self.activation(x_c + backend.dot( 2631 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) 2632 o = self.recurrent_activation( 2633 x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) + 2634 self.output_gate_peephole_weights * c) 2635 return c, o 2636 2637 def _compute_carry_and_output_fused(self, z, c_tm1): 2638 z0, z1, z2, z3 = z 2639 i = self.recurrent_activation(z0 + 2640 self.input_gate_peephole_weights * c_tm1) 2641 f = self.recurrent_activation(z1 + 2642 self.forget_gate_peephole_weights * c_tm1) 2643 c = f * c_tm1 + i * self.activation(z2) 2644 o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c) 2645 return c, o 2646 2647 2648@keras_export(v1=['keras.layers.LSTM']) 2649class LSTM(RNN): 2650 """Long Short-Term Memory layer - Hochreiter 1997. 2651 2652 Note that this cell is not optimized for performance on GPU. Please use 2653 `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU. 2654 2655 Args: 2656 units: Positive integer, dimensionality of the output space. 2657 activation: Activation function to use. 2658 Default: hyperbolic tangent (`tanh`). 2659 If you pass `None`, no activation is applied 2660 (ie. "linear" activation: `a(x) = x`). 2661 recurrent_activation: Activation function to use 2662 for the recurrent step. 2663 Default: hard sigmoid (`hard_sigmoid`). 2664 If you pass `None`, no activation is applied 2665 (ie. "linear" activation: `a(x) = x`). 2666 use_bias: Boolean, whether the layer uses a bias vector. 2667 kernel_initializer: Initializer for the `kernel` weights matrix, 2668 used for the linear transformation of the inputs.. 2669 recurrent_initializer: Initializer for the `recurrent_kernel` 2670 weights matrix, 2671 used for the linear transformation of the recurrent state. 2672 bias_initializer: Initializer for the bias vector. 2673 unit_forget_bias: Boolean. 2674 If True, add 1 to the bias of the forget gate at initialization. 2675 Setting it to true will also force `bias_initializer="zeros"`. 2676 This is recommended in [Jozefowicz et al., 2015]( 2677 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf). 2678 kernel_regularizer: Regularizer function applied to 2679 the `kernel` weights matrix. 2680 recurrent_regularizer: Regularizer function applied to 2681 the `recurrent_kernel` weights matrix. 2682 bias_regularizer: Regularizer function applied to the bias vector. 2683 activity_regularizer: Regularizer function applied to 2684 the output of the layer (its "activation"). 2685 kernel_constraint: Constraint function applied to 2686 the `kernel` weights matrix. 2687 recurrent_constraint: Constraint function applied to 2688 the `recurrent_kernel` weights matrix. 2689 bias_constraint: Constraint function applied to the bias vector. 2690 dropout: Float between 0 and 1. 2691 Fraction of the units to drop for 2692 the linear transformation of the inputs. 2693 recurrent_dropout: Float between 0 and 1. 2694 Fraction of the units to drop for 2695 the linear transformation of the recurrent state. 2696 return_sequences: Boolean. Whether to return the last output. 2697 in the output sequence, or the full sequence. 2698 return_state: Boolean. Whether to return the last state 2699 in addition to the output. 2700 go_backwards: Boolean (default False). 2701 If True, process the input sequence backwards and return the 2702 reversed sequence. 2703 stateful: Boolean (default False). If True, the last state 2704 for each sample at index i in a batch will be used as initial 2705 state for the sample of index i in the following batch. 2706 unroll: Boolean (default False). 2707 If True, the network will be unrolled, 2708 else a symbolic loop will be used. 2709 Unrolling can speed-up a RNN, 2710 although it tends to be more memory-intensive. 2711 Unrolling is only suitable for short sequences. 2712 time_major: The shape format of the `inputs` and `outputs` tensors. 2713 If True, the inputs and outputs will be in shape 2714 `(timesteps, batch, ...)`, whereas in the False case, it will be 2715 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 2716 efficient because it avoids transposes at the beginning and end of the 2717 RNN calculation. However, most TensorFlow data is batch-major, so by 2718 default this function accepts input and emits output in batch-major 2719 form. 2720 2721 Call arguments: 2722 inputs: A 3D tensor. 2723 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 2724 a given timestep should be masked. An individual `True` entry indicates 2725 that the corresponding timestep should be utilized, while a `False` 2726 entry indicates that the corresponding timestep should be ignored. 2727 training: Python boolean indicating whether the layer should behave in 2728 training mode or in inference mode. This argument is passed to the cell 2729 when calling it. This is only relevant if `dropout` or 2730 `recurrent_dropout` is used. 2731 initial_state: List of initial state tensors to be passed to the first 2732 call of the cell. 2733 """ 2734 2735 def __init__(self, 2736 units, 2737 activation='tanh', 2738 recurrent_activation='hard_sigmoid', 2739 use_bias=True, 2740 kernel_initializer='glorot_uniform', 2741 recurrent_initializer='orthogonal', 2742 bias_initializer='zeros', 2743 unit_forget_bias=True, 2744 kernel_regularizer=None, 2745 recurrent_regularizer=None, 2746 bias_regularizer=None, 2747 activity_regularizer=None, 2748 kernel_constraint=None, 2749 recurrent_constraint=None, 2750 bias_constraint=None, 2751 dropout=0., 2752 recurrent_dropout=0., 2753 return_sequences=False, 2754 return_state=False, 2755 go_backwards=False, 2756 stateful=False, 2757 unroll=False, 2758 **kwargs): 2759 implementation = kwargs.pop('implementation', 1) 2760 if implementation == 0: 2761 logging.warning('`implementation=0` has been deprecated, ' 2762 'and now defaults to `implementation=1`.' 2763 'Please update your layer call.') 2764 if 'enable_caching_device' in kwargs: 2765 cell_kwargs = {'enable_caching_device': 2766 kwargs.pop('enable_caching_device')} 2767 else: 2768 cell_kwargs = {} 2769 cell = LSTMCell( 2770 units, 2771 activation=activation, 2772 recurrent_activation=recurrent_activation, 2773 use_bias=use_bias, 2774 kernel_initializer=kernel_initializer, 2775 recurrent_initializer=recurrent_initializer, 2776 unit_forget_bias=unit_forget_bias, 2777 bias_initializer=bias_initializer, 2778 kernel_regularizer=kernel_regularizer, 2779 recurrent_regularizer=recurrent_regularizer, 2780 bias_regularizer=bias_regularizer, 2781 kernel_constraint=kernel_constraint, 2782 recurrent_constraint=recurrent_constraint, 2783 bias_constraint=bias_constraint, 2784 dropout=dropout, 2785 recurrent_dropout=recurrent_dropout, 2786 implementation=implementation, 2787 dtype=kwargs.get('dtype'), 2788 trainable=kwargs.get('trainable', True), 2789 **cell_kwargs) 2790 super(LSTM, self).__init__( 2791 cell, 2792 return_sequences=return_sequences, 2793 return_state=return_state, 2794 go_backwards=go_backwards, 2795 stateful=stateful, 2796 unroll=unroll, 2797 **kwargs) 2798 self.activity_regularizer = regularizers.get(activity_regularizer) 2799 self.input_spec = [InputSpec(ndim=3)] 2800 2801 def call(self, inputs, mask=None, training=None, initial_state=None): 2802 return super(LSTM, self).call( 2803 inputs, mask=mask, training=training, initial_state=initial_state) 2804 2805 @property 2806 def units(self): 2807 return self.cell.units 2808 2809 @property 2810 def activation(self): 2811 return self.cell.activation 2812 2813 @property 2814 def recurrent_activation(self): 2815 return self.cell.recurrent_activation 2816 2817 @property 2818 def use_bias(self): 2819 return self.cell.use_bias 2820 2821 @property 2822 def kernel_initializer(self): 2823 return self.cell.kernel_initializer 2824 2825 @property 2826 def recurrent_initializer(self): 2827 return self.cell.recurrent_initializer 2828 2829 @property 2830 def bias_initializer(self): 2831 return self.cell.bias_initializer 2832 2833 @property 2834 def unit_forget_bias(self): 2835 return self.cell.unit_forget_bias 2836 2837 @property 2838 def kernel_regularizer(self): 2839 return self.cell.kernel_regularizer 2840 2841 @property 2842 def recurrent_regularizer(self): 2843 return self.cell.recurrent_regularizer 2844 2845 @property 2846 def bias_regularizer(self): 2847 return self.cell.bias_regularizer 2848 2849 @property 2850 def kernel_constraint(self): 2851 return self.cell.kernel_constraint 2852 2853 @property 2854 def recurrent_constraint(self): 2855 return self.cell.recurrent_constraint 2856 2857 @property 2858 def bias_constraint(self): 2859 return self.cell.bias_constraint 2860 2861 @property 2862 def dropout(self): 2863 return self.cell.dropout 2864 2865 @property 2866 def recurrent_dropout(self): 2867 return self.cell.recurrent_dropout 2868 2869 @property 2870 def implementation(self): 2871 return self.cell.implementation 2872 2873 def get_config(self): 2874 config = { 2875 'units': 2876 self.units, 2877 'activation': 2878 activations.serialize(self.activation), 2879 'recurrent_activation': 2880 activations.serialize(self.recurrent_activation), 2881 'use_bias': 2882 self.use_bias, 2883 'kernel_initializer': 2884 initializers.serialize(self.kernel_initializer), 2885 'recurrent_initializer': 2886 initializers.serialize(self.recurrent_initializer), 2887 'bias_initializer': 2888 initializers.serialize(self.bias_initializer), 2889 'unit_forget_bias': 2890 self.unit_forget_bias, 2891 'kernel_regularizer': 2892 regularizers.serialize(self.kernel_regularizer), 2893 'recurrent_regularizer': 2894 regularizers.serialize(self.recurrent_regularizer), 2895 'bias_regularizer': 2896 regularizers.serialize(self.bias_regularizer), 2897 'activity_regularizer': 2898 regularizers.serialize(self.activity_regularizer), 2899 'kernel_constraint': 2900 constraints.serialize(self.kernel_constraint), 2901 'recurrent_constraint': 2902 constraints.serialize(self.recurrent_constraint), 2903 'bias_constraint': 2904 constraints.serialize(self.bias_constraint), 2905 'dropout': 2906 self.dropout, 2907 'recurrent_dropout': 2908 self.recurrent_dropout, 2909 'implementation': 2910 self.implementation 2911 } 2912 config.update(_config_for_enable_caching_device(self.cell)) 2913 base_config = super(LSTM, self).get_config() 2914 del base_config['cell'] 2915 return dict(list(base_config.items()) + list(config.items())) 2916 2917 @classmethod 2918 def from_config(cls, config): 2919 if 'implementation' in config and config['implementation'] == 0: 2920 config['implementation'] = 1 2921 return cls(**config) 2922 2923 2924def _generate_dropout_mask(ones, rate, training=None, count=1): 2925 def dropped_inputs(): 2926 return backend.dropout(ones, rate) 2927 2928 if count > 1: 2929 return [ 2930 backend.in_train_phase(dropped_inputs, ones, training=training) 2931 for _ in range(count) 2932 ] 2933 return backend.in_train_phase(dropped_inputs, ones, training=training) 2934 2935 2936def _standardize_args(inputs, initial_state, constants, num_constants): 2937 """Standardizes `__call__` to a single list of tensor inputs. 2938 2939 When running a model loaded from a file, the input tensors 2940 `initial_state` and `constants` can be passed to `RNN.__call__()` as part 2941 of `inputs` instead of by the dedicated keyword arguments. This method 2942 makes sure the arguments are separated and that `initial_state` and 2943 `constants` are lists of tensors (or None). 2944 2945 Args: 2946 inputs: Tensor or list/tuple of tensors. which may include constants 2947 and initial states. In that case `num_constant` must be specified. 2948 initial_state: Tensor or list of tensors or None, initial states. 2949 constants: Tensor or list of tensors or None, constant tensors. 2950 num_constants: Expected number of constants (if constants are passed as 2951 part of the `inputs` list. 2952 2953 Returns: 2954 inputs: Single tensor or tuple of tensors. 2955 initial_state: List of tensors or None. 2956 constants: List of tensors or None. 2957 """ 2958 if isinstance(inputs, list): 2959 # There are several situations here: 2960 # In the graph mode, __call__ will be only called once. The initial_state 2961 # and constants could be in inputs (from file loading). 2962 # In the eager mode, __call__ will be called twice, once during 2963 # rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be 2964 # model.fit/train_on_batch/predict with real np data. In the second case, 2965 # the inputs will contain initial_state and constants as eager tensor. 2966 # 2967 # For either case, the real input is the first item in the list, which 2968 # could be a nested structure itself. Then followed by initial_states, which 2969 # could be a list of items, or list of list if the initial_state is complex 2970 # structure, and finally followed by constants which is a flat list. 2971 assert initial_state is None and constants is None 2972 if num_constants: 2973 constants = inputs[-num_constants:] 2974 inputs = inputs[:-num_constants] 2975 if len(inputs) > 1: 2976 initial_state = inputs[1:] 2977 inputs = inputs[:1] 2978 2979 if len(inputs) > 1: 2980 inputs = tuple(inputs) 2981 else: 2982 inputs = inputs[0] 2983 2984 def to_list_or_none(x): 2985 if x is None or isinstance(x, list): 2986 return x 2987 if isinstance(x, tuple): 2988 return list(x) 2989 return [x] 2990 2991 initial_state = to_list_or_none(initial_state) 2992 constants = to_list_or_none(constants) 2993 2994 return inputs, initial_state, constants 2995 2996 2997def _is_multiple_state(state_size): 2998 """Check whether the state_size contains multiple states.""" 2999 return (hasattr(state_size, '__len__') and 3000 not isinstance(state_size, tensor_shape.TensorShape)) 3001 3002 3003def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype): 3004 if inputs is not None: 3005 batch_size = array_ops.shape(inputs)[0] 3006 dtype = inputs.dtype 3007 return _generate_zero_filled_state(batch_size, cell.state_size, dtype) 3008 3009 3010def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): 3011 """Generate a zero filled tensor with shape [batch_size, state_size].""" 3012 if batch_size_tensor is None or dtype is None: 3013 raise ValueError( 3014 'batch_size and dtype cannot be None while constructing initial state: ' 3015 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype)) 3016 3017 def create_zeros(unnested_state_size): 3018 flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list() 3019 init_state_size = [batch_size_tensor] + flat_dims 3020 return array_ops.zeros(init_state_size, dtype=dtype) 3021 3022 if nest.is_nested(state_size): 3023 return nest.map_structure(create_zeros, state_size) 3024 else: 3025 return create_zeros(state_size) 3026 3027 3028def _caching_device(rnn_cell): 3029 """Returns the caching device for the RNN variable. 3030 3031 This is useful for distributed training, when variable is not located as same 3032 device as the training worker. By enabling the device cache, this allows 3033 worker to read the variable once and cache locally, rather than read it every 3034 time step from remote when it is needed. 3035 3036 Note that this is assuming the variable that cell needs for each time step is 3037 having the same value in the forward path, and only gets updated in the 3038 backprop. It is true for all the default cells (SimpleRNN, GRU, LSTM). If the 3039 cell body relies on any variable that gets updated every time step, then 3040 caching device will cause it to read the stall value. 3041 3042 Args: 3043 rnn_cell: the rnn cell instance. 3044 """ 3045 if context.executing_eagerly(): 3046 # caching_device is not supported in eager mode. 3047 return None 3048 if not getattr(rnn_cell, '_enable_caching_device', False): 3049 return None 3050 # Don't set a caching device when running in a loop, since it is possible that 3051 # train steps could be wrapped in a tf.while_loop. In that scenario caching 3052 # prevents forward computations in loop iterations from re-reading the 3053 # updated weights. 3054 if control_flow_util.IsInWhileLoop(ops.get_default_graph()): 3055 logging.warning( 3056 'Variable read device caching has been disabled because the ' 3057 'RNN is in tf.while_loop loop context, which will cause ' 3058 'reading stalled value in forward path. This could slow down ' 3059 'the training due to duplicated variable reads. Please ' 3060 'consider updating your code to remove tf.while_loop if possible.') 3061 return None 3062 if (rnn_cell._dtype_policy.compute_dtype != 3063 rnn_cell._dtype_policy.variable_dtype): 3064 logging.warning( 3065 'Variable read device caching has been disabled since it ' 3066 'doesn\'t work with the mixed precision API. This is ' 3067 'likely to cause a slowdown for RNN training due to ' 3068 'duplicated read of variable for each timestep, which ' 3069 'will be significant in a multi remote worker setting. ' 3070 'Please consider disabling mixed precision API if ' 3071 'the performance has been affected.') 3072 return None 3073 # Cache the value on the device that access the variable. 3074 return lambda op: op.device 3075 3076 3077def _config_for_enable_caching_device(rnn_cell): 3078 """Return the dict config for RNN cell wrt to enable_caching_device field. 3079 3080 Since enable_caching_device is a internal implementation detail for speed up 3081 the RNN variable read when running on the multi remote worker setting, we 3082 don't want this config to be serialized constantly in the JSON. We will only 3083 serialize this field when a none default value is used to create the cell. 3084 Args: 3085 rnn_cell: the RNN cell for serialize. 3086 3087 Returns: 3088 A dict which contains the JSON config for enable_caching_device value or 3089 empty dict if the enable_caching_device value is same as the default value. 3090 """ 3091 default_enable_caching_device = ops.executing_eagerly_outside_functions() 3092 if rnn_cell._enable_caching_device != default_enable_caching_device: 3093 return {'enable_caching_device': rnn_cell._enable_caching_device} 3094 return {} 3095