1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-classes-have-attributes
17"""Recurrent layers and their base classes."""
18
19import collections
20import warnings
21
22import numpy as np
23
24from tensorflow.python.distribute import distribution_strategy_context as ds_context
25from tensorflow.python.eager import context
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.keras import activations
29from tensorflow.python.keras import backend
30from tensorflow.python.keras import constraints
31from tensorflow.python.keras import initializers
32from tensorflow.python.keras import regularizers
33from tensorflow.python.keras.engine.base_layer import Layer
34from tensorflow.python.keras.engine.input_spec import InputSpec
35from tensorflow.python.keras.saving.saved_model import layer_serialization
36from tensorflow.python.keras.utils import control_flow_util
37from tensorflow.python.keras.utils import generic_utils
38from tensorflow.python.keras.utils import tf_utils
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops import state_ops
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.trackable import base as trackable
45from tensorflow.python.util import nest
46from tensorflow.python.util.tf_export import keras_export
47from tensorflow.tools.docs import doc_controls
48
49
50RECURRENT_DROPOUT_WARNING_MSG = (
51    'RNN `implementation=2` is not supported when `recurrent_dropout` is set. '
52    'Using `implementation=1`.')
53
54
55@keras_export('keras.layers.StackedRNNCells')
56class StackedRNNCells(Layer):
57  """Wrapper allowing a stack of RNN cells to behave as a single cell.
58
59  Used to implement efficient stacked RNNs.
60
61  Args:
62    cells: List of RNN cell instances.
63
64  Examples:
65
66  ```python
67  batch_size = 3
68  sentence_max_length = 5
69  n_features = 2
70  new_shape = (batch_size, sentence_max_length, n_features)
  x = tf.constant(np.reshape(np.arange(30), new_shape), dtype=tf.float32)
72
73  rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
74  stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
75  lstm_layer = tf.keras.layers.RNN(stacked_lstm)
76
77  result = lstm_layer(x)
78  ```
79  """
80
81  def __init__(self, cells, **kwargs):
82    for cell in cells:
83      if not 'call' in dir(cell):
84        raise ValueError('All cells must have a `call` method. '
85                         'received cells:', cells)
86      if not 'state_size' in dir(cell):
87        raise ValueError('All cells must have a '
88                         '`state_size` attribute. '
89                         'received cells:', cells)
90    self.cells = cells
    # reverse_state_order determines whether the state sizes are listed in the
    # reverse order of the cells' states. Users might want to set this to True
    # to keep the existing behavior. This only matters when using
    # RNN(return_state=True), since the states are returned in the same order
    # as state_size.
95    self.reverse_state_order = kwargs.pop('reverse_state_order', False)
96    if self.reverse_state_order:
97      logging.warning('reverse_state_order=True in StackedRNNCells will soon '
98                      'be deprecated. Please update the code to work with the '
99                      'natural order of states if you rely on the RNN states, '
100                      'eg RNN(return_state=True).')
101    super(StackedRNNCells, self).__init__(**kwargs)
102
103  @property
104  def state_size(self):
105    return tuple(c.state_size for c in
106                 (self.cells[::-1] if self.reverse_state_order else self.cells))
107
108  @property
109  def output_size(self):
110    if getattr(self.cells[-1], 'output_size', None) is not None:
111      return self.cells[-1].output_size
112    elif _is_multiple_state(self.cells[-1].state_size):
113      return self.cells[-1].state_size[0]
114    else:
115      return self.cells[-1].state_size
116
117  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
118    initial_states = []
119    for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
120      get_initial_state_fn = getattr(cell, 'get_initial_state', None)
121      if get_initial_state_fn:
122        initial_states.append(get_initial_state_fn(
123            inputs=inputs, batch_size=batch_size, dtype=dtype))
124      else:
125        initial_states.append(_generate_zero_filled_state_for_cell(
126            cell, inputs, batch_size, dtype))
127
128    return tuple(initial_states)
129
130  def call(self, inputs, states, constants=None, training=None, **kwargs):
131    # Recover per-cell states.
132    state_size = (self.state_size[::-1]
133                  if self.reverse_state_order else self.state_size)
134    nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
135
136    # Call the cells in order and store the returned states.
137    new_nested_states = []
138    for cell, states in zip(self.cells, nested_states):
139      states = states if nest.is_nested(states) else [states]
140      # TF cell does not wrap the state into list when there is only one state.
141      is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
142      states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
143      if generic_utils.has_arg(cell.call, 'training'):
144        kwargs['training'] = training
145      else:
146        kwargs.pop('training', None)
147      # Use the __call__ function for callable objects, eg layers, so that it
148      # will have the proper name scopes for the ops, etc.
149      cell_call_fn = cell.__call__ if callable(cell) else cell.call
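      # The output of each cell is fed as the input to the next cell in the
      # stack, so `inputs` is overwritten on every iteration of this loop.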
150      if generic_utils.has_arg(cell.call, 'constants'):
151        inputs, states = cell_call_fn(inputs, states,
152                                      constants=constants, **kwargs)
153      else:
154        inputs, states = cell_call_fn(inputs, states, **kwargs)
155      new_nested_states.append(states)
156
157    return inputs, nest.pack_sequence_as(state_size,
158                                         nest.flatten(new_nested_states))
159
160  @tf_utils.shape_type_conversion
161  def build(self, input_shape):
162    if isinstance(input_shape, list):
163      input_shape = input_shape[0]
164    for cell in self.cells:
165      if isinstance(cell, Layer) and not cell.built:
166        with backend.name_scope(cell.name):
167          cell.build(input_shape)
168          cell.built = True
169      if getattr(cell, 'output_size', None) is not None:
170        output_dim = cell.output_size
171      elif _is_multiple_state(cell.state_size):
172        output_dim = cell.state_size[0]
173      else:
174        output_dim = cell.state_size
175      input_shape = tuple([input_shape[0]] +
176                          tensor_shape.TensorShape(output_dim).as_list())
177    self.built = True
178
179  def get_config(self):
180    cells = []
181    for cell in self.cells:
182      cells.append(generic_utils.serialize_keras_object(cell))
183    config = {'cells': cells}
184    base_config = super(StackedRNNCells, self).get_config()
185    return dict(list(base_config.items()) + list(config.items()))
186
187  @classmethod
188  def from_config(cls, config, custom_objects=None):
189    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
190    cells = []
191    for cell_config in config.pop('cells'):
192      cells.append(
193          deserialize_layer(cell_config, custom_objects=custom_objects))
194    return cls(cells, **config)
195
196
197@keras_export('keras.layers.RNN')
198class RNN(Layer):
199  """Base class for recurrent layers.
200
201  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
202  for details about the usage of RNN API.
203
204  Args:
    cell: An RNN cell instance or a list of RNN cell instances.
      An RNN cell is a class that has:
207      - A `call(input_at_t, states_at_t)` method, returning
208        `(output_at_t, states_at_t_plus_1)`. The call method of the
209        cell can also take the optional argument `constants`, see
210        section "Note on passing external constants" below.
211      - A `state_size` attribute. This can be a single integer
212        (single state) in which case it is the size of the recurrent
213        state. This can also be a list/tuple of integers (one size per state).
214        The `state_size` can also be TensorShape or tuple/list of
215        TensorShape, to represent high dimension state.
      - An `output_size` attribute. This can be a single integer or a
        TensorShape, which represents the shape of the output. For backward
        compatibility, if this attribute is not available for the cell, the
        value will be inferred from the first element of the `state_size`.
221      - A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
222        method that creates a tensor meant to be fed to `call()` as the
223        initial state, if the user didn't specify any initial state via other
224        means. The returned initial state should have a shape of
225        [batch_size, cell.state_size]. The cell might choose to create a
226        tensor full of zeros, or full of other values based on the cell's
227        implementation.
        `inputs` is the input tensor to the RNN layer, which carries the
        batch size as its shape[0] and also the dtype. Note that shape[0]
        might be `None` during graph construction. Either `inputs` or the
        pair of `batch_size` and `dtype` is provided.
        `batch_size` is a scalar tensor that represents the batch size
        of the inputs. `dtype` is a `tf.DType` that represents the dtype of
        the inputs.
        For backward compatibility, if this method is not implemented
        by the cell, the RNN layer will create a zero-filled tensor with
        shape `[batch_size, cell.state_size]`.
238      In the case that `cell` is a list of RNN cell instances, the cells
239      will be stacked on top of each other in the RNN, resulting in an
240      efficient stacked RNN.
241    return_sequences: Boolean (default `False`). Whether to return the last
242      output in the output sequence, or the full sequence.
243    return_state: Boolean (default `False`). Whether to return the last state
244      in addition to the output.
245    go_backwards: Boolean (default `False`).
246      If True, process the input sequence backwards and return the
247      reversed sequence.
248    stateful: Boolean (default `False`). If True, the last state
249      for each sample at index i in a batch will be used as initial
250      state for the sample of index i in the following batch.
251    unroll: Boolean (default `False`).
252      If True, the network will be unrolled, else a symbolic loop will be used.
      Unrolling can speed up an RNN, although it tends to be more
254      memory-intensive. Unrolling is only suitable for short sequences.
255    time_major: The shape format of the `inputs` and `outputs` tensors.
256      If True, the inputs and outputs will be in shape
257      `(timesteps, batch, ...)`, whereas in the False case, it will be
258      `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
259      efficient because it avoids transposes at the beginning and end of the
260      RNN calculation. However, most TensorFlow data is batch-major, so by
261      default this function accepts input and emits output in batch-major
262      form.
263    zero_output_for_mask: Boolean (default `False`).
      Whether the output should use zeros for the masked timesteps. Note that
      this field is only used when `return_sequences` is True and a mask is
      provided. It can be useful if you want to reuse the raw output sequence
      of the RNN without interference from the masked timesteps, e.g., when
      merging bidirectional RNNs.
269
270  Call arguments:
271    inputs: Input tensor.
272    mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether
273      a given timestep should be masked. An individual `True` entry indicates
274      that the corresponding timestep should be utilized, while a `False`
275      entry indicates that the corresponding timestep should be ignored.
276    training: Python boolean indicating whether the layer should behave in
277      training mode or in inference mode. This argument is passed to the cell
278      when calling it. This is for use with cells that use dropout.
279    initial_state: List of initial state tensors to be passed to the first
280      call of the cell.
281    constants: List of constant tensors to be passed to the cell at each
282      timestep.
283
284  Input shape:
285    N-D tensor with shape `[batch_size, timesteps, ...]` or
286    `[timesteps, batch_size, ...]` when time_major is True.
287
288  Output shape:
289    - If `return_state`: a list of tensors. The first tensor is
290      the output. The remaining tensors are the last states,
291      each with shape `[batch_size, state_size]`, where `state_size` could
292      be a high dimension tensor shape.
293    - If `return_sequences`: N-D tensor with shape
294      `[batch_size, timesteps, output_size]`, where `output_size` could
295      be a high dimension tensor shape, or
296      `[timesteps, batch_size, output_size]` when `time_major` is True.
297    - Else, N-D tensor with shape `[batch_size, output_size]`, where
298      `output_size` could be a high dimension tensor shape.
299
300  Masking:
301    This layer supports masking for input data with a variable number
302    of timesteps. To introduce masks to your data,
303    use an [tf.keras.layers.Embedding] layer with the `mask_zero` parameter
304    set to `True`.
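    For illustration, a minimal sketch of a mask-producing input pipeline
    (assuming `import tensorflow as tf`):

    ```python
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(input_dim=1000, output_dim=16,
                                  mask_zero=True),
        tf.keras.layers.SimpleRNN(32),  # masked timesteps are skipped
    ])
    ```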
305
306  Note on using statefulness in RNNs:
307    You can set RNN layers to be 'stateful', which means that the states
308    computed for the samples in one batch will be reused as initial states
309    for the samples in the next batch. This assumes a one-to-one mapping
310    between samples in different successive batches.
311
    To enable statefulness:
      - Specify `stateful=True` in the layer constructor.
      - Specify a fixed batch size for your model by passing
        `batch_input_shape=(...)` to the first layer of a sequential model, or
        `batch_shape=(...)` to all `Input` layers of a functional model.
        This is the expected shape of your inputs
        *including the batch size*.
        It should be a tuple of integers, e.g. `(32, 10, 100)`.
      - Specify `shuffle=False` when calling `fit()`.
323
324    To reset the states of your model, call `.reset_states()` on either
325    a specific layer, or on your entire model.
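    For illustration, a minimal sketch of a stateful configuration (assuming
    `import tensorflow as tf`):

    ```python
    model = tf.keras.Sequential([
        tf.keras.layers.SimpleRNN(32, stateful=True,
                                  batch_input_shape=(8, 10, 4)),
    ])
    model.compile(optimizer='sgd', loss='mse')
    # Train with a fixed batch size and without shuffling, so that sample i of
    # one batch follows sample i of the previous batch:
    # model.fit(x, y, batch_size=8, shuffle=False)
    model.reset_states()  # clear the recorded states between sequences
    ```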
326
327  Note on specifying the initial state of RNNs:
328    You can specify the initial state of RNN layers symbolically by
329    calling them with the keyword argument `initial_state`. The value of
330    `initial_state` should be a tensor or list of tensors representing
331    the initial state of the RNN layer.
332
333    You can specify the initial state of RNN layers numerically by
334    calling `reset_states` with the keyword argument `states`. The value of
335    `states` should be a numpy array or list of numpy arrays representing
336    the initial state of the RNN layer.
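    For illustration, a minimal sketch of passing a symbolic `initial_state`
    (assuming `import tensorflow as tf`):

    ```python
    encoder_inputs = tf.keras.Input((None, 4))
    decoder_inputs = tf.keras.Input((None, 4))
    encoder_out, state_h, state_c = tf.keras.layers.LSTM(
        8, return_state=True)(encoder_inputs)
    decoder_out = tf.keras.layers.LSTM(8)(
        decoder_inputs, initial_state=[state_h, state_c])
    ```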
337
338  Note on passing external constants to RNNs:
    You can pass "external" constants to the cell using the `constants`
    keyword argument of the `RNN.__call__` (as well as `RNN.call`) method. This
341    requires that the `cell.call` method accepts the same keyword argument
342    `constants`. Such constants can be used to condition the cell
343    transformation on additional static inputs (not changing over time),
344    a.k.a. an attention mechanism.
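    For illustration, a minimal sketch of a hypothetical cell whose `call`
    accepts `constants` (assuming `import tensorflow as tf`):

    ```python
    class CellWithConstants(tf.keras.layers.Layer):

      def __init__(self, units, **kwargs):
        super(CellWithConstants, self).__init__(**kwargs)
        self.units = units
        self.state_size = units

      def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units), name='kernel')
        self.built = True

      def call(self, inputs, states, constants):
        [context] = constants  # static tensor, identical at every timestep
        output = tf.matmul(inputs, self.kernel) + states[0] + context
        return output, [output]

    x = tf.keras.Input((None, 5))
    context = tf.keras.Input((8,))
    y = tf.keras.layers.RNN(CellWithConstants(8))(x, constants=context)
    ```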
345
346  Examples:
347
348  ```python
349  # First, let's define a RNN Cell, as a layer subclass.
350
351  class MinimalRNNCell(keras.layers.Layer):
352
353      def __init__(self, units, **kwargs):
354          self.units = units
355          self.state_size = units
356          super(MinimalRNNCell, self).__init__(**kwargs)
357
358      def build(self, input_shape):
359          self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
360                                        initializer='uniform',
361                                        name='kernel')
362          self.recurrent_kernel = self.add_weight(
363              shape=(self.units, self.units),
364              initializer='uniform',
365              name='recurrent_kernel')
366          self.built = True
367
368      def call(self, inputs, states):
369          prev_output = states[0]
370          h = backend.dot(inputs, self.kernel)
371          output = h + backend.dot(prev_output, self.recurrent_kernel)
372          return output, [output]
373
374  # Let's use this cell in a RNN layer:
375
376  cell = MinimalRNNCell(32)
377  x = keras.Input((None, 5))
378  layer = RNN(cell)
379  y = layer(x)
380
381  # Here's how to use the cell to build a stacked RNN:
382
383  cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
384  x = keras.Input((None, 5))
385  layer = RNN(cells)
386  y = layer(x)
387  ```
388  """
389
390  def __init__(self,
391               cell,
392               return_sequences=False,
393               return_state=False,
394               go_backwards=False,
395               stateful=False,
396               unroll=False,
397               time_major=False,
398               **kwargs):
399    if isinstance(cell, (list, tuple)):
400      cell = StackedRNNCells(cell)
401    if not 'call' in dir(cell):
402      raise ValueError('`cell` should have a `call` method. '
403                       'The RNN was passed:', cell)
404    if not 'state_size' in dir(cell):
405      raise ValueError('The RNN cell should have '
406                       'an attribute `state_size` '
407                       '(tuple of integers, '
408                       'one integer per RNN state).')
    # If True, the output for masked timesteps will be zeros, whereas if False,
    # the output from the previous timestep is returned for masked timesteps.
411    self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
412
413    if 'input_shape' not in kwargs and (
414        'input_dim' in kwargs or 'input_length' in kwargs):
415      input_shape = (kwargs.pop('input_length', None),
416                     kwargs.pop('input_dim', None))
417      kwargs['input_shape'] = input_shape
418
419    super(RNN, self).__init__(**kwargs)
420    self.cell = cell
421    self.return_sequences = return_sequences
422    self.return_state = return_state
423    self.go_backwards = go_backwards
424    self.stateful = stateful
425    self.unroll = unroll
426    self.time_major = time_major
427
428    self.supports_masking = True
    # The input shape is not known yet; the input could be a nested structure
    # of tensors, in which case the input_spec will be a list of specs for the
    # nested inputs, with the same structure as the input.
432    self.input_spec = None
433    self.state_spec = None
434    self._states = None
435    self.constants_spec = None
436    self._num_constants = 0
437
438    if stateful:
439      if ds_context.has_strategy():
440        raise ValueError('RNNs with stateful=True not yet supported with '
441                         'tf.distribute.Strategy.')
442
443  @property
444  def _use_input_spec_as_call_signature(self):
445    if self.unroll:
446      # When the RNN layer is unrolled, the time step shape cannot be unknown.
447      # The input spec does not define the time step (because this layer can be
448      # called with any time step value, as long as it is not None), so it
449      # cannot be used as the call function signature when saving to SavedModel.
450      return False
451    return super(RNN, self)._use_input_spec_as_call_signature
452
453  @property
454  def states(self):
455    if self._states is None:
456      state = nest.map_structure(lambda _: None, self.cell.state_size)
457      return state if nest.is_nested(self.cell.state_size) else [state]
458    return self._states
459
460  @states.setter
461  # Automatic tracking catches "self._states" which adds an extra weight and
462  # breaks HDF5 checkpoints.
463  @trackable.no_automatic_dependency_tracking
464  def states(self, states):
465    self._states = states
466
467  def compute_output_shape(self, input_shape):
468    if isinstance(input_shape, list):
469      input_shape = input_shape[0]
470    # Check whether the input shape contains any nested shapes. It could be
471    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
472    # inputs.
473    try:
474      input_shape = tensor_shape.TensorShape(input_shape)
475    except (ValueError, TypeError):
476      # A nested tensor input
477      input_shape = nest.flatten(input_shape)[0]
478
479    batch = input_shape[0]
480    time_step = input_shape[1]
481    if self.time_major:
482      batch, time_step = time_step, batch
483
484    if _is_multiple_state(self.cell.state_size):
485      state_size = self.cell.state_size
486    else:
487      state_size = [self.cell.state_size]
488
489    def _get_output_shape(flat_output_size):
490      output_dim = tensor_shape.TensorShape(flat_output_size).as_list()
491      if self.return_sequences:
492        if self.time_major:
493          output_shape = tensor_shape.TensorShape(
494              [time_step, batch] + output_dim)
495        else:
496          output_shape = tensor_shape.TensorShape(
497              [batch, time_step] + output_dim)
498      else:
499        output_shape = tensor_shape.TensorShape([batch] + output_dim)
500      return output_shape
501
502    if getattr(self.cell, 'output_size', None) is not None:
503      # cell.output_size could be nested structure.
504      output_shape = nest.flatten(nest.map_structure(
505          _get_output_shape, self.cell.output_size))
506      output_shape = output_shape[0] if len(output_shape) == 1 else output_shape
507    else:
508      # Note that state_size[0] could be a tensor_shape or int.
509      output_shape = _get_output_shape(state_size[0])
510
511    if self.return_state:
512      def _get_state_shape(flat_state):
513        state_shape = [batch] + tensor_shape.TensorShape(flat_state).as_list()
514        return tensor_shape.TensorShape(state_shape)
515      state_shape = nest.map_structure(_get_state_shape, state_size)
516      return generic_utils.to_list(output_shape) + nest.flatten(state_shape)
517    else:
518      return output_shape
519
520  def compute_mask(self, inputs, mask):
521    # Time step masks must be the same for each input.
522    # This is because the mask for an RNN is of size [batch, time_steps, 1],
523    # and specifies which time steps should be skipped, and a time step
524    # must be skipped for all inputs.
525    # TODO(scottzhu): Should we accept multiple different masks?
526    mask = nest.flatten(mask)[0]
527    output_mask = mask if self.return_sequences else None
528    if self.return_state:
529      state_mask = [None for _ in self.states]
530      return [output_mask] + state_mask
531    else:
532      return output_mask
533
534  def build(self, input_shape):
535    if isinstance(input_shape, list):
536      input_shape = input_shape[0]
      # The input_shape here could be a nested structure.

    # Convert TensorShapes to plain shapes here. The input could be a single
    # tensor or a nested structure of tensors.
541    def get_input_spec(shape):
542      """Convert input shape to InputSpec."""
543      if isinstance(shape, tensor_shape.TensorShape):
544        input_spec_shape = shape.as_list()
545      else:
546        input_spec_shape = list(shape)
547      batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
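      # Leave the time dimension unspecified so the layer accepts sequences of
      # any length; the batch dimension is only pinned when the layer is
      # stateful.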
548      if not self.stateful:
549        input_spec_shape[batch_index] = None
550      input_spec_shape[time_step_index] = None
551      return InputSpec(shape=tuple(input_spec_shape))
552
553    def get_step_input_shape(shape):
554      if isinstance(shape, tensor_shape.TensorShape):
555        shape = tuple(shape.as_list())
556      # remove the timestep from the input_shape
557      return shape[1:] if self.time_major else (shape[0],) + shape[2:]
558
559    # Check whether the input shape contains any nested shapes. It could be
560    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
561    # inputs.
562    try:
563      input_shape = tensor_shape.TensorShape(input_shape)
564    except (ValueError, TypeError):
565      # A nested tensor input
566      pass
567
568    if not nest.is_nested(input_shape):
      # This indicates that there is only one input.
570      if self.input_spec is not None:
571        self.input_spec[0] = get_input_spec(input_shape)
572      else:
573        self.input_spec = [get_input_spec(input_shape)]
574      step_input_shape = get_step_input_shape(input_shape)
575    else:
576      if self.input_spec is not None:
577        self.input_spec[0] = nest.map_structure(get_input_spec, input_shape)
578      else:
579        self.input_spec = generic_utils.to_list(
580            nest.map_structure(get_input_spec, input_shape))
581      step_input_shape = nest.map_structure(get_step_input_shape, input_shape)
582
583    # allow cell (if layer) to build before we set or validate state_spec.
584    if isinstance(self.cell, Layer) and not self.cell.built:
585      with backend.name_scope(self.cell.name):
586        self.cell.build(step_input_shape)
587        self.cell.built = True
588
589    # set or validate state_spec
590    if _is_multiple_state(self.cell.state_size):
591      state_size = list(self.cell.state_size)
592    else:
593      state_size = [self.cell.state_size]
594
595    if self.state_spec is not None:
596      # initial_state was passed in call, check compatibility
597      self._validate_state_spec(state_size, self.state_spec)
598    else:
599      self.state_spec = [
600          InputSpec(shape=[None] + tensor_shape.TensorShape(dim).as_list())
601          for dim in state_size
602      ]
603    if self.stateful:
604      self.reset_states()
605    self.built = True
606
607  @staticmethod
608  def _validate_state_spec(cell_state_sizes, init_state_specs):
609    """Validate the state spec between the initial_state and the state_size.
610
611    Args:
612      cell_state_sizes: list, the `state_size` attribute from the cell.
613      init_state_specs: list, the `state_spec` from the initial_state that is
614        passed in `call()`.
615
616    Raises:
617      ValueError: When initial state spec is not compatible with the state size.
618    """
619    validation_error = ValueError(
620        'An `initial_state` was passed that is not compatible with '
621        '`cell.state_size`. Received `state_spec`={}; '
622        'however `cell.state_size` is '
623        '{}'.format(init_state_specs, cell_state_sizes))
624    flat_cell_state_sizes = nest.flatten(cell_state_sizes)
625    flat_state_specs = nest.flatten(init_state_specs)
626
627    if len(flat_cell_state_sizes) != len(flat_state_specs):
628      raise validation_error
629    for cell_state_spec, cell_state_size in zip(flat_state_specs,
630                                                flat_cell_state_sizes):
631      if not tensor_shape.TensorShape(
632          # Ignore the first axis for init_state which is for batch
633          cell_state_spec.shape[1:]).is_compatible_with(
634              tensor_shape.TensorShape(cell_state_size)):
635        raise validation_error
636
637  @doc_controls.do_not_doc_inheritable
638  def get_initial_state(self, inputs):
639    get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
640
641    if nest.is_nested(inputs):
      # The inputs are nested sequences. Use the first element of the sequence
      # to get the batch size and dtype.
644      inputs = nest.flatten(inputs)[0]
645
646    input_shape = array_ops.shape(inputs)
647    batch_size = input_shape[1] if self.time_major else input_shape[0]
648    dtype = inputs.dtype
649    if get_initial_state_fn:
650      init_state = get_initial_state_fn(
651          inputs=None, batch_size=batch_size, dtype=dtype)
652    else:
653      init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
654                                               dtype)
    # Keras RNN expects the states in a list, even if it's a single state
    # tensor.
656    if not nest.is_nested(init_state):
657      init_state = [init_state]
658    # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple.
659    return list(init_state)
660
661  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
662    inputs, initial_state, constants = _standardize_args(inputs,
663                                                         initial_state,
664                                                         constants,
665                                                         self._num_constants)
666
667    if initial_state is None and constants is None:
668      return super(RNN, self).__call__(inputs, **kwargs)
669
670    # If any of `initial_state` or `constants` are specified and are Keras
671    # tensors, then add them to the inputs and temporarily modify the
672    # input_spec to include them.
673
674    additional_inputs = []
675    additional_specs = []
676    if initial_state is not None:
677      additional_inputs += initial_state
678      self.state_spec = nest.map_structure(
679          lambda s: InputSpec(shape=backend.int_shape(s)), initial_state)
680      additional_specs += self.state_spec
681    if constants is not None:
682      additional_inputs += constants
683      self.constants_spec = [
684          InputSpec(shape=backend.int_shape(constant)) for constant in constants
685      ]
686      self._num_constants = len(constants)
687      additional_specs += self.constants_spec
688    # additional_inputs can be empty if initial_state or constants are provided
689    # but empty (e.g. the cell is stateless).
690    flat_additional_inputs = nest.flatten(additional_inputs)
691    is_keras_tensor = backend.is_keras_tensor(
692        flat_additional_inputs[0]) if flat_additional_inputs else True
693    for tensor in flat_additional_inputs:
694      if backend.is_keras_tensor(tensor) != is_keras_tensor:
695        raise ValueError('The initial state or constants of an RNN'
696                         ' layer cannot be specified with a mix of'
697                         ' Keras tensors and non-Keras tensors'
698                         ' (a "Keras tensor" is a tensor that was'
699                         ' returned by a Keras layer, or by `Input`)')
700
701    if is_keras_tensor:
702      # Compute the full input spec, including state and constants
703      full_input = [inputs] + additional_inputs
704      if self.built:
705        # Keep the input_spec since it has been populated in build() method.
706        full_input_spec = self.input_spec + additional_specs
707      else:
708        # The original input_spec is None since there could be a nested tensor
709        # input. Update the input_spec to match the inputs.
710        full_input_spec = generic_utils.to_list(
711            nest.map_structure(lambda _: None, inputs)) + additional_specs
712      # Perform the call with temporarily replaced input_spec
713      self.input_spec = full_input_spec
714      output = super(RNN, self).__call__(full_input, **kwargs)
      # Remove the additional_specs from the input spec and keep the rest. It
      # is important to keep the rest, since the input spec was populated by
      # build() and will be reused when stateful=True.
718      self.input_spec = self.input_spec[:-len(additional_specs)]
719      return output
720    else:
721      if initial_state is not None:
722        kwargs['initial_state'] = initial_state
723      if constants is not None:
724        kwargs['constants'] = constants
725      return super(RNN, self).__call__(inputs, **kwargs)
726
727  def call(self,
728           inputs,
729           mask=None,
730           training=None,
731           initial_state=None,
732           constants=None):
733    # The input should be dense, padded with zeros. If a ragged input is fed
734    # into the layer, it is padded and the row lengths are used for masking.
735    inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
736    is_ragged_input = (row_lengths is not None)
737    self._validate_args_if_ragged(is_ragged_input, mask)
738
739    inputs, initial_state, constants = self._process_inputs(
740        inputs, initial_state, constants)
741
742    self._maybe_reset_cell_dropout_mask(self.cell)
743    if isinstance(self.cell, StackedRNNCells):
744      for cell in self.cell.cells:
745        self._maybe_reset_cell_dropout_mask(cell)
746
747    if mask is not None:
748      # Time step masks must be the same for each input.
749      # TODO(scottzhu): Should we accept multiple different masks?
750      mask = nest.flatten(mask)[0]
751
752    if nest.is_nested(inputs):
753      # In the case of nested input, use the first element for shape check.
754      input_shape = backend.int_shape(nest.flatten(inputs)[0])
755    else:
756      input_shape = backend.int_shape(inputs)
757    timesteps = input_shape[0] if self.time_major else input_shape[1]
758    if self.unroll and timesteps is None:
759      raise ValueError('Cannot unroll a RNN if the '
760                       'time dimension is undefined. \n'
761                       '- If using a Sequential model, '
762                       'specify the time dimension by passing '
763                       'an `input_shape` or `batch_input_shape` '
764                       'argument to your first layer. If your '
765                       'first layer is an Embedding, you can '
766                       'also use the `input_length` argument.\n'
767                       '- If using the functional API, specify '
768                       'the time dimension by passing a `shape` '
769                       'or `batch_shape` argument to your Input layer.')
770
771    kwargs = {}
772    if generic_utils.has_arg(self.cell.call, 'training'):
773      kwargs['training'] = training
774
    # TF RNN cells expect a single tensor as state instead of a list-wrapped
    # tensor.
776    is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
777    # Use the __call__ function for callable objects, eg layers, so that it
778    # will have the proper name scopes for the ops, etc.
779    cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call
780    if constants:
781      if not generic_utils.has_arg(self.cell.call, 'constants'):
782        raise ValueError('RNN cell does not support constants')
783
784      def step(inputs, states):
785        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
786        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
787
788        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
789        output, new_states = cell_call_fn(
790            inputs, states, constants=constants, **kwargs)
791        if not nest.is_nested(new_states):
792          new_states = [new_states]
793        return output, new_states
794    else:
795
796      def step(inputs, states):
797        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
798        output, new_states = cell_call_fn(inputs, states, **kwargs)
799        if not nest.is_nested(new_states):
800          new_states = [new_states]
801        return output, new_states
802    last_output, outputs, states = backend.rnn(
803        step,
804        inputs,
805        initial_state,
806        constants=constants,
807        go_backwards=self.go_backwards,
808        mask=mask,
809        unroll=self.unroll,
810        input_length=row_lengths if row_lengths is not None else timesteps,
811        time_major=self.time_major,
812        zero_output_for_mask=self.zero_output_for_mask)
813
814    if self.stateful:
815      updates = [
816          state_ops.assign(self_state, state) for self_state, state in zip(
817              nest.flatten(self.states), nest.flatten(states))
818      ]
819      self.add_update(updates)
820
821    if self.return_sequences:
822      output = backend.maybe_convert_to_ragged(
823          is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards)
824    else:
825      output = last_output
826
827    if self.return_state:
828      if not isinstance(states, (list, tuple)):
829        states = [states]
830      else:
831        states = list(states)
832      return generic_utils.to_list(output) + states
833    else:
834      return output
835
836  def _process_inputs(self, inputs, initial_state, constants):
837    # input shape: `(samples, time (padded with zeros), input_dim)`
838    # note that the .build() method of subclasses MUST define
839    # self.input_spec and self.state_spec with complete input shapes.
840    if (isinstance(inputs, collections.abc.Sequence)
841        and not isinstance(inputs, tuple)):
      # Get initial_state from the full input spec, as the inputs could be
      # copied to multiple GPUs.
844      if not self._num_constants:
845        initial_state = inputs[1:]
846      else:
847        initial_state = inputs[1:-self._num_constants]
848        constants = inputs[-self._num_constants:]
849      if len(initial_state) == 0:
850        initial_state = None
851      inputs = inputs[0]
852
853    if self.stateful:
854      if initial_state is not None:
        # When the layer is stateful and an initial_state is provided, check if
        # the recorded state is the same as the default value (zeros). Use the
        # recorded state if it is not the same as the default.
858        non_zero_count = math_ops.add_n([math_ops.count_nonzero_v2(s)
859                                         for s in nest.flatten(self.states)])
860        # Set strict = True to keep the original structure of the state.
861        initial_state = control_flow_ops.cond(non_zero_count > 0,
862                                              true_fn=lambda: self.states,
863                                              false_fn=lambda: initial_state,
864                                              strict=True)
865      else:
866        initial_state = self.states
867    elif initial_state is None:
868      initial_state = self.get_initial_state(inputs)
869
870    if len(initial_state) != len(self.states):
871      raise ValueError('Layer has ' + str(len(self.states)) +
872                       ' states but was passed ' + str(len(initial_state)) +
873                       ' initial states.')
874    return inputs, initial_state, constants
875
876  def _validate_args_if_ragged(self, is_ragged_input, mask):
877    if not is_ragged_input:
878      return
879
880    if mask is not None:
881      raise ValueError('The mask that was passed in was ' + str(mask) +
882                       ' and cannot be applied to RaggedTensor inputs. Please '
883                       'make sure that there is no mask passed in by upstream '
884                       'layers.')
885    if self.unroll:
886      raise ValueError('The input received contains RaggedTensors and does '
887                       'not support unrolling. Disable unrolling by passing '
888                       '`unroll=False` in the RNN Layer constructor.')
889
890  def _maybe_reset_cell_dropout_mask(self, cell):
891    if isinstance(cell, DropoutRNNCellMixin):
892      cell.reset_dropout_mask()
893      cell.reset_recurrent_dropout_mask()
894
895  def reset_states(self, states=None):
896    """Reset the recorded states for the stateful RNN layer.
897
898    Can only be used when RNN layer is constructed with `stateful` = `True`.
899    Args:
900      states: Numpy arrays that contains the value for the initial state, which
901        will be feed to cell at the first time step. When the value is None,
902        zero filled numpy array will be created based on the cell state size.
903
904    Raises:
905      AttributeError: When the RNN layer is not stateful.
906      ValueError: When the batch size of the RNN layer is unknown.
      ValueError: When the input numpy array is not compatible with the RNN
        layer state, either size-wise or dtype-wise.
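    Example (a minimal sketch, assuming `import tensorflow as tf` and
    `import numpy as np`):

    ```python
    layer = tf.keras.layers.SimpleRNN(
        4, stateful=True, batch_input_shape=(2, 3, 5))
    layer(tf.zeros((2, 3, 5)))                  # builds and records the state
    layer.reset_states()                        # reset the state to zeros
    layer.reset_states(states=np.ones((2, 4)))  # reset to a specific value
    ```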
909    """
910    if not self.stateful:
911      raise AttributeError('Layer must be stateful.')
912    spec_shape = None
913    if self.input_spec is not None:
914      spec_shape = nest.flatten(self.input_spec[0])[0].shape
915    if spec_shape is None:
916      # It is possible to have spec shape to be None, eg when construct a RNN
917      # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know
918      # it has 3 dim input, but not its full shape spec before build().
919      batch_size = None
920    else:
921      batch_size = spec_shape[1] if self.time_major else spec_shape[0]
922    if not batch_size:
923      raise ValueError('If a RNN is stateful, it needs to know '
924                       'its batch size. Specify the batch size '
925                       'of your input tensors: \n'
926                       '- If using a Sequential model, '
927                       'specify the batch size by passing '
928                       'a `batch_input_shape` '
929                       'argument to your first layer.\n'
930                       '- If using the functional API, specify '
931                       'the batch size by passing a '
932                       '`batch_shape` argument to your Input layer.')
933    # initialize state if None
934    if nest.flatten(self.states)[0] is None:
935      if getattr(self.cell, 'get_initial_state', None):
936        flat_init_state_values = nest.flatten(self.cell.get_initial_state(
937            inputs=None, batch_size=batch_size,
938            dtype=self.dtype or backend.floatx()))
939      else:
940        flat_init_state_values = nest.flatten(_generate_zero_filled_state(
941            batch_size, self.cell.state_size, self.dtype or backend.floatx()))
942      flat_states_variables = nest.map_structure(
943          backend.variable, flat_init_state_values)
944      self.states = nest.pack_sequence_as(self.cell.state_size,
945                                          flat_states_variables)
946      if not nest.is_nested(self.states):
947        self.states = [self.states]
948    elif states is None:
949      for state, size in zip(nest.flatten(self.states),
950                             nest.flatten(self.cell.state_size)):
951        backend.set_value(
952            state,
953            np.zeros([batch_size] + tensor_shape.TensorShape(size).as_list()))
954    else:
955      flat_states = nest.flatten(self.states)
956      flat_input_states = nest.flatten(states)
957      if len(flat_input_states) != len(flat_states):
958        raise ValueError('Layer ' + self.name + ' expects ' +
959                         str(len(flat_states)) + ' states, '
960                         'but it received ' + str(len(flat_input_states)) +
961                         ' state values. Input received: ' + str(states))
962      set_value_tuples = []
963      for i, (value, state) in enumerate(zip(flat_input_states,
964                                             flat_states)):
        if value.shape != state.shape:
          raise ValueError(
              'State ' + str(i) + ' is incompatible with layer ' +
              self.name + ': expected shape=' + str(state.shape) +
              ', found shape=' + str(value.shape))
970        set_value_tuples.append((state, value))
971      backend.batch_set_value(set_value_tuples)
972
973  def get_config(self):
974    config = {
975        'return_sequences': self.return_sequences,
976        'return_state': self.return_state,
977        'go_backwards': self.go_backwards,
978        'stateful': self.stateful,
979        'unroll': self.unroll,
980        'time_major': self.time_major
981    }
982    if self._num_constants:
983      config['num_constants'] = self._num_constants
984    if self.zero_output_for_mask:
985      config['zero_output_for_mask'] = self.zero_output_for_mask
986
987    config['cell'] = generic_utils.serialize_keras_object(self.cell)
988    base_config = super(RNN, self).get_config()
989    return dict(list(base_config.items()) + list(config.items()))
990
991  @classmethod
992  def from_config(cls, config, custom_objects=None):
993    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
994    cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
995    num_constants = config.pop('num_constants', 0)
996    layer = cls(cell, **config)
997    layer._num_constants = num_constants
998    return layer
999
1000  @property
1001  def _trackable_saved_model_saver(self):
1002    return layer_serialization.RNNSavedModelSaver(self)
1003
1004
1005@keras_export('keras.layers.AbstractRNNCell')
1006class AbstractRNNCell(Layer):
1007  """Abstract object representing an RNN cell.
1008
1009  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1010  for details about the usage of RNN API.
1011
1012  This is the base class for implementing RNN cells with custom behavior.
1013
1014  Every `RNNCell` must have the properties below and implement `call` with
1015  the signature `(output, next_state) = call(input, state)`.
1016
1017  Examples:
1018
1019  ```python
1020    class MinimalRNNCell(AbstractRNNCell):
1021
1022      def __init__(self, units, **kwargs):
1023        self.units = units
1024        super(MinimalRNNCell, self).__init__(**kwargs)
1025
1026      @property
1027      def state_size(self):
1028        return self.units
1029
1030      def build(self, input_shape):
1031        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
1032                                      initializer='uniform',
1033                                      name='kernel')
1034        self.recurrent_kernel = self.add_weight(
1035            shape=(self.units, self.units),
1036            initializer='uniform',
1037            name='recurrent_kernel')
1038        self.built = True
1039
1040      def call(self, inputs, states):
1041        prev_output = states[0]
1042        h = backend.dot(inputs, self.kernel)
1043        output = h + backend.dot(prev_output, self.recurrent_kernel)
1044        return output, output
1045  ```
1046
1047  This definition of cell differs from the definition used in the literature.
1048  In the literature, 'cell' refers to an object with a single scalar output.
1049  This definition refers to a horizontal array of such units.
1050
1051  An RNN cell, in the most abstract setting, is anything that has
1052  a state and performs some operation that takes a matrix of inputs.
1053  This operation results in an output matrix with `self.output_size` columns.
1054  If `self.state_size` is an integer, this operation also results in a new
1055  state matrix with `self.state_size` columns.  If `self.state_size` is a
1056  (possibly nested tuple of) TensorShape object(s), then it should return a
1057  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
1059  """
1060
1061  def call(self, inputs, states):
1062    """The function that contains the logic for one RNN step calculation.
1063
1064    Args:
      inputs: the input tensor, which is a slice of the overall RNN input along
        the time dimension (usually the second dimension).
      states: the state tensor from the previous step, with the shape
        `(batch, state_size)`. In the case of timestep 0, it will be the
        initial state the user specified, or a zero-filled tensor otherwise.
1070
1071    Returns:
1072      A tuple of two tensors:
1073        1. output tensor for the current timestep, with size `output_size`.
1074        2. state tensor for next step, which has the shape of `state_size`.
1075    """
1076    raise NotImplementedError('Abstract method')
1077
1078  @property
1079  def state_size(self):
1080    """size(s) of state(s) used by this cell.
1081
1082    It can be represented by an Integer, a TensorShape or a tuple of Integers
1083    or TensorShapes.
1084    """
1085    raise NotImplementedError('Abstract method')
1086
1087  @property
1088  def output_size(self):
1089    """Integer or TensorShape: size of outputs produced by this cell."""
1090    raise NotImplementedError('Abstract method')
1091
1092  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1093    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1094
1095
1096@doc_controls.do_not_generate_docs
1097class DropoutRNNCellMixin(object):
1098  """Object that hold dropout related fields for RNN Cell.
1099
1100  This class is not a standalone RNN cell. It suppose to be used with a RNN cell
1101  by multiple inheritance. Any cell that mix with class should have following
1102  fields:
1103    dropout: a float number within range [0, 1). The ratio that the input
1104      tensor need to dropout.
1105    recurrent_dropout: a float number within range [0, 1). The ratio that the
1106      recurrent state weights need to dropout.
1107  This object will create and cache created dropout masks, and reuse them for
1108  the incoming data, so that the same mask is used for every batch input.
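  A minimal sketch of a custom cell that mixes in this class (assuming
  `import tensorflow as tf`):

  ```python
  class DropoutCell(DropoutRNNCellMixin, tf.keras.layers.Layer):

    def __init__(self, units, dropout=0.5, **kwargs):
      super(DropoutCell, self).__init__(**kwargs)
      self.units = units
      self.state_size = units
      self.dropout = dropout
      self.recurrent_dropout = 0.

    def build(self, input_shape):
      self.kernel = self.add_weight(
          shape=(input_shape[-1], self.units), name='kernel')
      self.built = True

    def call(self, inputs, states, training=None):
      dp_mask = self.get_dropout_mask_for_cell(inputs, training)
      if dp_mask is not None:
        inputs = inputs * dp_mask[0]  # same mask for every step in a batch
      output = tf.matmul(inputs, self.kernel) + states[0]
      return output, [output]
  ```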
1109  """
1110
1111  def __init__(self, *args, **kwargs):
1112    self._create_non_trackable_mask_cache()
1113    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
1114
1115  @trackable.no_automatic_dependency_tracking
1116  def _create_non_trackable_mask_cache(self):
1117    """Create the cache for dropout and recurrent dropout mask.
1118
1119    Note that the following two masks will be used in "graph function" mode,
1120    e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
1121    tensors will be generated differently than in the "graph function" case,
1122    and they will be cached.
1123
1124    Also note that in graph mode, we still cache those masks only because the
1125    RNN could be created with `unroll=True`. In that case, the `cell.call()`
1126    function will be invoked multiple times, and we want to ensure same mask
1127    is used every time.
1128
    Also, the caches are created without tracking. Since they are not picklable
    by python when deepcopied, we don't want `layer._obj_reference_counts_dict`
    to track them by default.
1132    """
1133    self._dropout_mask_cache = backend.ContextValueCache(
1134        self._create_dropout_mask)
1135    self._recurrent_dropout_mask_cache = backend.ContextValueCache(
1136        self._create_recurrent_dropout_mask)
1137
1138  def reset_dropout_mask(self):
1139    """Reset the cached dropout masks if any.
1140
    It is important for the RNN layer to invoke this in its `call()` method so
    that the cached mask is cleared before calling `cell.call()`. The mask
    should be cached across timesteps within the same batch, but shouldn't
    be cached between batches. Otherwise it will introduce an unreasonable bias
    against certain indices of data within the batch.
1146    """
1147    self._dropout_mask_cache.clear()
1148
1149  def reset_recurrent_dropout_mask(self):
1150    """Reset the cached recurrent dropout masks if any.
1151
    It is important for the RNN layer to invoke this in its `call()` method so
    that the cached mask is cleared before calling `cell.call()`. The mask
    should be cached across timesteps within the same batch, but shouldn't
    be cached between batches. Otherwise it will introduce an unreasonable bias
    against certain indices of data within the batch.
1157    """
1158    self._recurrent_dropout_mask_cache.clear()
1159
1160  def _create_dropout_mask(self, inputs, training, count=1):
1161    return _generate_dropout_mask(
1162        array_ops.ones_like(inputs),
1163        self.dropout,
1164        training=training,
1165        count=count)
1166
1167  def _create_recurrent_dropout_mask(self, inputs, training, count=1):
1168    return _generate_dropout_mask(
1169        array_ops.ones_like(inputs),
1170        self.recurrent_dropout,
1171        training=training,
1172        count=count)
1173
1174  def get_dropout_mask_for_cell(self, inputs, training, count=1):
1175    """Get the dropout mask for RNN cell's input.
1176
    It will create a mask based on context if there isn't any existing cached
    mask. If a new mask is generated, it will update the cache in the cell.

    Args:
      inputs: The input tensor whose shape will be used to generate the dropout
        mask.
      training: Boolean tensor, whether it is in training mode; dropout will be
        ignored in non-training mode.
      count: Int, how many dropout masks will be generated. This is useful for
        cells that have internal weights fused together.
    Returns:
      List of mask tensors, either newly generated or cached based on context.
1189    """
1190    if self.dropout == 0:
1191      return None
1192    init_kwargs = dict(inputs=inputs, training=training, count=count)
1193    return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)
1194
1195  def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
1196    """Get the recurrent dropout mask for RNN cell.
1197
    It will create a mask based on context if there isn't any existing cached
    mask. If a new mask is generated, it will update the cache in the cell.

    Args:
      inputs: The input tensor whose shape will be used to generate the dropout
        mask.
      training: Boolean tensor, whether it is in training mode; dropout will be
        ignored in non-training mode.
      count: Int, how many dropout masks will be generated. This is useful for
        cells that have internal weights fused together.
    Returns:
      List of mask tensors, either newly generated or cached based on context.
1210    """
1211    if self.recurrent_dropout == 0:
1212      return None
1213    init_kwargs = dict(inputs=inputs, training=training, count=count)
1214    return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
1215
1216  def __getstate__(self):
1217    # Used for deepcopy. The caching can't be pickled by python, since it will
1218    # contain tensor and graph.
1219    state = super(DropoutRNNCellMixin, self).__getstate__()
1220    state.pop('_dropout_mask_cache', None)
1221    state.pop('_recurrent_dropout_mask_cache', None)
1222    return state
1223
1224  def __setstate__(self, state):
1225    state['_dropout_mask_cache'] = backend.ContextValueCache(
1226        self._create_dropout_mask)
1227    state['_recurrent_dropout_mask_cache'] = backend.ContextValueCache(
1228        self._create_recurrent_dropout_mask)
1229    super(DropoutRNNCellMixin, self).__setstate__(state)
1230
1231
1232@keras_export('keras.layers.SimpleRNNCell')
1233class SimpleRNNCell(DropoutRNNCellMixin, Layer):
1234  """Cell class for SimpleRNN.
1235
1236  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1237  for details about the usage of RNN API.
1238
1239  This class processes one step within the whole time sequence input, whereas
  `tf.keras.layers.SimpleRNN` processes the whole sequence.
1241
1242  Args:
1243    units: Positive integer, dimensionality of the output space.
1244    activation: Activation function to use.
1245      Default: hyperbolic tangent (`tanh`).
1246      If you pass `None`, no activation is applied
1247      (ie. "linear" activation: `a(x) = x`).
1248    use_bias: Boolean (default `True`), whether the layer uses a bias vector.
1249    kernel_initializer: Initializer for the `kernel` weights matrix,
1250      used for the linear transformation of the inputs. Default:
1251      `glorot_uniform`.
1252    recurrent_initializer: Initializer for the `recurrent_kernel`
1253      weights matrix, used for the linear transformation of the recurrent state.
1254      Default: `orthogonal`.
1255    bias_initializer: Initializer for the bias vector. Default: `zeros`.
1256    kernel_regularizer: Regularizer function applied to the `kernel` weights
1257      matrix. Default: `None`.
1258    recurrent_regularizer: Regularizer function applied to the
1259      `recurrent_kernel` weights matrix. Default: `None`.
1260    bias_regularizer: Regularizer function applied to the bias vector. Default:
1261      `None`.
1262    kernel_constraint: Constraint function applied to the `kernel` weights
1263      matrix. Default: `None`.
1264    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
1265      weights matrix. Default: `None`.
1266    bias_constraint: Constraint function applied to the bias vector. Default:
1267      `None`.
1268    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
1269      transformation of the inputs. Default: 0.
1270    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
1271      the linear transformation of the recurrent state. Default: 0.
1272
1273  Call arguments:
1274    inputs: A 2D tensor, with shape of `[batch, feature]`.
1275    states: A 2D tensor with shape of `[batch, units]`, which is the state from
1276      the previous time step. For timestep 0, the initial state provided by
1277      the user will be fed to the cell.
1278    training: Python boolean indicating whether the layer should behave in
1279      training mode or in inference mode. Only relevant when `dropout` or
1280      `recurrent_dropout` is used.
1281
1282  Examples:
1283
1284  ```python
1285  inputs = np.random.random([32, 10, 8]).astype(np.float32)
1286  rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4))
1287
1288  output = rnn(inputs)  # The output has shape `[32, 4]`.
1289
1290  rnn = tf.keras.layers.RNN(
1291      tf.keras.layers.SimpleRNNCell(4),
1292      return_sequences=True,
1293      return_state=True)
1294
1295  # whole_sequence_output has shape `[32, 10, 4]`.
1296  # final_state has shape `[32, 4]`.
1297  whole_sequence_output, final_state = rnn(inputs)
1298  ```
1299  """
1300
1301  def __init__(self,
1302               units,
1303               activation='tanh',
1304               use_bias=True,
1305               kernel_initializer='glorot_uniform',
1306               recurrent_initializer='orthogonal',
1307               bias_initializer='zeros',
1308               kernel_regularizer=None,
1309               recurrent_regularizer=None,
1310               bias_regularizer=None,
1311               kernel_constraint=None,
1312               recurrent_constraint=None,
1313               bias_constraint=None,
1314               dropout=0.,
1315               recurrent_dropout=0.,
1316               **kwargs):
1317    if units <= 0:
1318      raise ValueError(f'Received an invalid value for units, expected '
1319                       f'a positive integer, got {units}.')
1320    # By default use cached variable under v2 mode, see b/143699808.
1321    if ops.executing_eagerly_outside_functions():
1322      self._enable_caching_device = kwargs.pop('enable_caching_device', True)
1323    else:
1324      self._enable_caching_device = kwargs.pop('enable_caching_device', False)
1325    super(SimpleRNNCell, self).__init__(**kwargs)
1326    self.units = units
1327    self.activation = activations.get(activation)
1328    self.use_bias = use_bias
1329
1330    self.kernel_initializer = initializers.get(kernel_initializer)
1331    self.recurrent_initializer = initializers.get(recurrent_initializer)
1332    self.bias_initializer = initializers.get(bias_initializer)
1333
1334    self.kernel_regularizer = regularizers.get(kernel_regularizer)
1335    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
1336    self.bias_regularizer = regularizers.get(bias_regularizer)
1337
1338    self.kernel_constraint = constraints.get(kernel_constraint)
1339    self.recurrent_constraint = constraints.get(recurrent_constraint)
1340    self.bias_constraint = constraints.get(bias_constraint)
1341
1342    self.dropout = min(1., max(0., dropout))
1343    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
1344    self.state_size = self.units
1345    self.output_size = self.units
1346
1347  @tf_utils.shape_type_conversion
1348  def build(self, input_shape):
1349    default_caching_device = _caching_device(self)
1350    self.kernel = self.add_weight(
1351        shape=(input_shape[-1], self.units),
1352        name='kernel',
1353        initializer=self.kernel_initializer,
1354        regularizer=self.kernel_regularizer,
1355        constraint=self.kernel_constraint,
1356        caching_device=default_caching_device)
1357    self.recurrent_kernel = self.add_weight(
1358        shape=(self.units, self.units),
1359        name='recurrent_kernel',
1360        initializer=self.recurrent_initializer,
1361        regularizer=self.recurrent_regularizer,
1362        constraint=self.recurrent_constraint,
1363        caching_device=default_caching_device)
1364    if self.use_bias:
1365      self.bias = self.add_weight(
1366          shape=(self.units,),
1367          name='bias',
1368          initializer=self.bias_initializer,
1369          regularizer=self.bias_regularizer,
1370          constraint=self.bias_constraint,
1371          caching_device=default_caching_device)
1372    else:
1373      self.bias = None
1374    self.built = True
1375
1376  def call(self, inputs, states, training=None):
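    # One step of the recurrence:
    #   output = activation(dot(inputs, kernel) + bias
    #                       + dot(prev_output, recurrent_kernel)),
    # with the cached dropout masks (if any) applied to `inputs` and
    # `prev_output` beforehand.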
1377    prev_output = states[0] if nest.is_nested(states) else states
1378    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
1379    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
1380        prev_output, training)
1381
1382    if dp_mask is not None:
1383      h = backend.dot(inputs * dp_mask, self.kernel)
1384    else:
1385      h = backend.dot(inputs, self.kernel)
1386    if self.bias is not None:
1387      h = backend.bias_add(h, self.bias)
1388
1389    if rec_dp_mask is not None:
1390      prev_output = prev_output * rec_dp_mask
1391    output = h + backend.dot(prev_output, self.recurrent_kernel)
1392    if self.activation is not None:
1393      output = self.activation(output)
1394
1395    new_state = [output] if nest.is_nested(states) else output
1396    return output, new_state
1397
1398  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1399    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1400
1401  def get_config(self):
1402    config = {
1403        'units':
1404            self.units,
1405        'activation':
1406            activations.serialize(self.activation),
1407        'use_bias':
1408            self.use_bias,
1409        'kernel_initializer':
1410            initializers.serialize(self.kernel_initializer),
1411        'recurrent_initializer':
1412            initializers.serialize(self.recurrent_initializer),
1413        'bias_initializer':
1414            initializers.serialize(self.bias_initializer),
1415        'kernel_regularizer':
1416            regularizers.serialize(self.kernel_regularizer),
1417        'recurrent_regularizer':
1418            regularizers.serialize(self.recurrent_regularizer),
1419        'bias_regularizer':
1420            regularizers.serialize(self.bias_regularizer),
1421        'kernel_constraint':
1422            constraints.serialize(self.kernel_constraint),
1423        'recurrent_constraint':
1424            constraints.serialize(self.recurrent_constraint),
1425        'bias_constraint':
1426            constraints.serialize(self.bias_constraint),
1427        'dropout':
1428            self.dropout,
1429        'recurrent_dropout':
1430            self.recurrent_dropout
1431    }
1432    config.update(_config_for_enable_caching_device(self))
1433    base_config = super(SimpleRNNCell, self).get_config()
1434    return dict(list(base_config.items()) + list(config.items()))
1435
1436
1437@keras_export('keras.layers.SimpleRNN')
1438class SimpleRNN(RNN):
1439  """Fully-connected RNN where the output is to be fed back to input.
1440
1441  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1442  for details about the usage of RNN API.
1443
1444  Args:
1445    units: Positive integer, dimensionality of the output space.
1446    activation: Activation function to use.
1447      Default: hyperbolic tangent (`tanh`).
1448      If you pass None, no activation is applied
1449      (ie. "linear" activation: `a(x) = x`).
1450    use_bias: Boolean (default `True`), whether the layer uses a bias vector.
1451    kernel_initializer: Initializer for the `kernel` weights matrix,
1452      used for the linear transformation of the inputs. Default:
1453      `glorot_uniform`.
1454    recurrent_initializer: Initializer for the `recurrent_kernel`
1455      weights matrix, used for the linear transformation of the recurrent state.
1456      Default: `orthogonal`.
1457    bias_initializer: Initializer for the bias vector. Default: `zeros`.
1458    kernel_regularizer: Regularizer function applied to the `kernel` weights
1459      matrix. Default: `None`.
1460    recurrent_regularizer: Regularizer function applied to the
1461      `recurrent_kernel` weights matrix. Default: `None`.
1462    bias_regularizer: Regularizer function applied to the bias vector. Default:
1463      `None`.
1464    activity_regularizer: Regularizer function applied to the output of the
1465      layer (its "activation"). Default: `None`.
1466    kernel_constraint: Constraint function applied to the `kernel` weights
1467      matrix. Default: `None`.
1468    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
1469      weights matrix.  Default: `None`.
1470    bias_constraint: Constraint function applied to the bias vector. Default:
1471      `None`.
1472    dropout: Float between 0 and 1.
1473      Fraction of the units to drop for the linear transformation of the inputs.
1474      Default: 0.
1475    recurrent_dropout: Float between 0 and 1.
1476      Fraction of the units to drop for the linear transformation of the
1477      recurrent state. Default: 0.
1478    return_sequences: Boolean. Whether to return the last output
1479      in the output sequence, or the full sequence. Default: `False`.
1480    return_state: Boolean. Whether to return the last state
1481      in addition to the output. Default: `False`.
1482    go_backwards: Boolean (default False).
1483      If True, process the input sequence backwards and return the
1484      reversed sequence.
1485    stateful: Boolean (default False). If True, the last state
1486      for each sample at index i in a batch will be used as initial
1487      state for the sample of index i in the following batch.
1488    unroll: Boolean (default False).
1489      If True, the network will be unrolled,
1490      else a symbolic loop will be used.
1491      Unrolling can speed up an RNN,
1492      although it tends to be more memory-intensive.
1493      Unrolling is only suitable for short sequences.
1494
1495  Call arguments:
1496    inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
1497    mask: Binary tensor of shape `[batch, timesteps]` indicating whether
1498      a given timestep should be masked. An individual `True` entry indicates
1499      that the corresponding timestep should be utilized, while a `False` entry
1500      indicates that the corresponding timestep should be ignored.
1501    training: Python boolean indicating whether the layer should behave in
1502      training mode or in inference mode. This argument is passed to the cell
1503      when calling it. This is only relevant if `dropout` or
1504      `recurrent_dropout` is used.
1505    initial_state: List of initial state tensors to be passed to the first
1506      call of the cell.
1507
1508  Examples:
1509
1510  ```python
1511  inputs = np.random.random([32, 10, 8]).astype(np.float32)
1512  simple_rnn = tf.keras.layers.SimpleRNN(4)
1513
1514  output = simple_rnn(inputs)  # The output has shape `[32, 4]`.
1515
1516  simple_rnn = tf.keras.layers.SimpleRNN(
1517      4, return_sequences=True, return_state=True)
1518
1519  # whole_sequence_output has shape `[32, 10, 4]`.
1520  # final_state has shape `[32, 4]`.
1521  whole_sequence_output, final_state = simple_rnn(inputs)
1522  ```
1523  """
1524
1525  def __init__(self,
1526               units,
1527               activation='tanh',
1528               use_bias=True,
1529               kernel_initializer='glorot_uniform',
1530               recurrent_initializer='orthogonal',
1531               bias_initializer='zeros',
1532               kernel_regularizer=None,
1533               recurrent_regularizer=None,
1534               bias_regularizer=None,
1535               activity_regularizer=None,
1536               kernel_constraint=None,
1537               recurrent_constraint=None,
1538               bias_constraint=None,
1539               dropout=0.,
1540               recurrent_dropout=0.,
1541               return_sequences=False,
1542               return_state=False,
1543               go_backwards=False,
1544               stateful=False,
1545               unroll=False,
1546               **kwargs):
1547    if 'implementation' in kwargs:
1548      kwargs.pop('implementation')
1549      logging.warning('The `implementation` argument '
1550                      'in `SimpleRNN` has been deprecated. '
1551                      'Please remove it from your layer call.')
1552    if 'enable_caching_device' in kwargs:
1553      cell_kwargs = {'enable_caching_device':
1554                     kwargs.pop('enable_caching_device')}
1555    else:
1556      cell_kwargs = {}
1557    cell = SimpleRNNCell(
1558        units,
1559        activation=activation,
1560        use_bias=use_bias,
1561        kernel_initializer=kernel_initializer,
1562        recurrent_initializer=recurrent_initializer,
1563        bias_initializer=bias_initializer,
1564        kernel_regularizer=kernel_regularizer,
1565        recurrent_regularizer=recurrent_regularizer,
1566        bias_regularizer=bias_regularizer,
1567        kernel_constraint=kernel_constraint,
1568        recurrent_constraint=recurrent_constraint,
1569        bias_constraint=bias_constraint,
1570        dropout=dropout,
1571        recurrent_dropout=recurrent_dropout,
1572        dtype=kwargs.get('dtype'),
1573        trainable=kwargs.get('trainable', True),
1574        **cell_kwargs)
1575    super(SimpleRNN, self).__init__(
1576        cell,
1577        return_sequences=return_sequences,
1578        return_state=return_state,
1579        go_backwards=go_backwards,
1580        stateful=stateful,
1581        unroll=unroll,
1582        **kwargs)
1583    self.activity_regularizer = regularizers.get(activity_regularizer)
1584    self.input_spec = [InputSpec(ndim=3)]
1585
1586  def call(self, inputs, mask=None, training=None, initial_state=None):
1587    return super(SimpleRNN, self).call(
1588        inputs, mask=mask, training=training, initial_state=initial_state)
1589
1590  @property
1591  def units(self):
1592    return self.cell.units
1593
1594  @property
1595  def activation(self):
1596    return self.cell.activation
1597
1598  @property
1599  def use_bias(self):
1600    return self.cell.use_bias
1601
1602  @property
1603  def kernel_initializer(self):
1604    return self.cell.kernel_initializer
1605
1606  @property
1607  def recurrent_initializer(self):
1608    return self.cell.recurrent_initializer
1609
1610  @property
1611  def bias_initializer(self):
1612    return self.cell.bias_initializer
1613
1614  @property
1615  def kernel_regularizer(self):
1616    return self.cell.kernel_regularizer
1617
1618  @property
1619  def recurrent_regularizer(self):
1620    return self.cell.recurrent_regularizer
1621
1622  @property
1623  def bias_regularizer(self):
1624    return self.cell.bias_regularizer
1625
1626  @property
1627  def kernel_constraint(self):
1628    return self.cell.kernel_constraint
1629
1630  @property
1631  def recurrent_constraint(self):
1632    return self.cell.recurrent_constraint
1633
1634  @property
1635  def bias_constraint(self):
1636    return self.cell.bias_constraint
1637
1638  @property
1639  def dropout(self):
1640    return self.cell.dropout
1641
1642  @property
1643  def recurrent_dropout(self):
1644    return self.cell.recurrent_dropout
1645
1646  def get_config(self):
1647    config = {
1648        'units':
1649            self.units,
1650        'activation':
1651            activations.serialize(self.activation),
1652        'use_bias':
1653            self.use_bias,
1654        'kernel_initializer':
1655            initializers.serialize(self.kernel_initializer),
1656        'recurrent_initializer':
1657            initializers.serialize(self.recurrent_initializer),
1658        'bias_initializer':
1659            initializers.serialize(self.bias_initializer),
1660        'kernel_regularizer':
1661            regularizers.serialize(self.kernel_regularizer),
1662        'recurrent_regularizer':
1663            regularizers.serialize(self.recurrent_regularizer),
1664        'bias_regularizer':
1665            regularizers.serialize(self.bias_regularizer),
1666        'activity_regularizer':
1667            regularizers.serialize(self.activity_regularizer),
1668        'kernel_constraint':
1669            constraints.serialize(self.kernel_constraint),
1670        'recurrent_constraint':
1671            constraints.serialize(self.recurrent_constraint),
1672        'bias_constraint':
1673            constraints.serialize(self.bias_constraint),
1674        'dropout':
1675            self.dropout,
1676        'recurrent_dropout':
1677            self.recurrent_dropout
1678    }
1679    base_config = super(SimpleRNN, self).get_config()
1680    config.update(_config_for_enable_caching_device(self.cell))
1681    del base_config['cell']
1682    return dict(list(base_config.items()) + list(config.items()))
1683
1684  @classmethod
1685  def from_config(cls, config):
1686    if 'implementation' in config:
1687      config.pop('implementation')
1688    return cls(**config)
1689
1690
1691@keras_export(v1=['keras.layers.GRUCell'])
1692class GRUCell(DropoutRNNCellMixin, Layer):
1693  """Cell class for the GRU layer.
1694
1695  Args:
1696    units: Positive integer, dimensionality of the output space.
1697    activation: Activation function to use.
1698      Default: hyperbolic tangent (`tanh`).
1699      If you pass None, no activation is applied
1700      (ie. "linear" activation: `a(x) = x`).
1701    recurrent_activation: Activation function to use
1702      for the recurrent step.
1703      Default: hard sigmoid (`hard_sigmoid`).
1704      If you pass `None`, no activation is applied
1705      (ie. "linear" activation: `a(x) = x`).
1706    use_bias: Boolean, whether the layer uses a bias vector.
1707    kernel_initializer: Initializer for the `kernel` weights matrix,
1708      used for the linear transformation of the inputs.
1709    recurrent_initializer: Initializer for the `recurrent_kernel`
1710      weights matrix,
1711      used for the linear transformation of the recurrent state.
1712    bias_initializer: Initializer for the bias vector.
1713    kernel_regularizer: Regularizer function applied to
1714      the `kernel` weights matrix.
1715    recurrent_regularizer: Regularizer function applied to
1716      the `recurrent_kernel` weights matrix.
1717    bias_regularizer: Regularizer function applied to the bias vector.
1718    kernel_constraint: Constraint function applied to
1719      the `kernel` weights matrix.
1720    recurrent_constraint: Constraint function applied to
1721      the `recurrent_kernel` weights matrix.
1722    bias_constraint: Constraint function applied to the bias vector.
1723    dropout: Float between 0 and 1.
1724      Fraction of the units to drop for the linear transformation of the inputs.
1725    recurrent_dropout: Float between 0 and 1.
1726      Fraction of the units to drop for
1727      the linear transformation of the recurrent state.
1728    reset_after: GRU convention (whether to apply reset gate after or
1729      before matrix multiplication). False = "before" (default),
1730      True = "after" (CuDNN compatible).
1731
1732  Call arguments:
1733    inputs: A 2D tensor.
1734    states: List of state tensors corresponding to the previous timestep.
1735    training: Python boolean indicating whether the layer should behave in
1736      training mode or in inference mode. Only relevant when `dropout` or
1737      `recurrent_dropout` is used.
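
  Example (an illustrative sketch, analogous to the `SimpleRNNCell` example
  above; the shapes are arbitrary):

  ```python
  inputs = np.random.random([32, 10, 8]).astype(np.float32)
  rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4))

  output = rnn(inputs)  # The output has shape `[32, 4]`.
  ```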
1738  """
1739
1740  def __init__(self,
1741               units,
1742               activation='tanh',
1743               recurrent_activation='hard_sigmoid',
1744               use_bias=True,
1745               kernel_initializer='glorot_uniform',
1746               recurrent_initializer='orthogonal',
1747               bias_initializer='zeros',
1748               kernel_regularizer=None,
1749               recurrent_regularizer=None,
1750               bias_regularizer=None,
1751               kernel_constraint=None,
1752               recurrent_constraint=None,
1753               bias_constraint=None,
1754               dropout=0.,
1755               recurrent_dropout=0.,
1756               reset_after=False,
1757               **kwargs):
1758    if units <= 0:
1759      raise ValueError(f'Received an invalid value for units, expected '
1760                       f'a positive integer, got {units}.')
1761    # By default use cached variable under v2 mode, see b/143699808.
1762    if ops.executing_eagerly_outside_functions():
1763      self._enable_caching_device = kwargs.pop('enable_caching_device', True)
1764    else:
1765      self._enable_caching_device = kwargs.pop('enable_caching_device', False)
1766    super(GRUCell, self).__init__(**kwargs)
1767    self.units = units
1768    self.activation = activations.get(activation)
1769    self.recurrent_activation = activations.get(recurrent_activation)
1770    self.use_bias = use_bias
1771
1772    self.kernel_initializer = initializers.get(kernel_initializer)
1773    self.recurrent_initializer = initializers.get(recurrent_initializer)
1774    self.bias_initializer = initializers.get(bias_initializer)
1775
1776    self.kernel_regularizer = regularizers.get(kernel_regularizer)
1777    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
1778    self.bias_regularizer = regularizers.get(bias_regularizer)
1779
1780    self.kernel_constraint = constraints.get(kernel_constraint)
1781    self.recurrent_constraint = constraints.get(recurrent_constraint)
1782    self.bias_constraint = constraints.get(bias_constraint)
1783
1784    self.dropout = min(1., max(0., dropout))
1785    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
1786
1787    implementation = kwargs.pop('implementation', 1)
1788    if self.recurrent_dropout != 0 and implementation != 1:
1789      logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
1790      self.implementation = 1
1791    else:
1792      self.implementation = implementation
1793    self.reset_after = reset_after
1794    self.state_size = self.units
1795    self.output_size = self.units
1796
1797  @tf_utils.shape_type_conversion
1798  def build(self, input_shape):
1799    input_dim = input_shape[-1]
1800    default_caching_device = _caching_device(self)
1801    self.kernel = self.add_weight(
1802        shape=(input_dim, self.units * 3),
1803        name='kernel',
1804        initializer=self.kernel_initializer,
1805        regularizer=self.kernel_regularizer,
1806        constraint=self.kernel_constraint,
1807        caching_device=default_caching_device)
1808    self.recurrent_kernel = self.add_weight(
1809        shape=(self.units, self.units * 3),
1810        name='recurrent_kernel',
1811        initializer=self.recurrent_initializer,
1812        regularizer=self.recurrent_regularizer,
1813        constraint=self.recurrent_constraint,
1814        caching_device=default_caching_device)
1815
1816    if self.use_bias:
1817      if not self.reset_after:
1818        bias_shape = (3 * self.units,)
1819      else:
1820        # separate biases for input and recurrent kernels
1821        # Note: the shape is intentionally different from CuDNNGRU biases
1822        # `(2 * 3 * self.units,)`, so that we can distinguish the classes
1823        # when loading and converting saved weights.
1824        bias_shape = (2, 3 * self.units)
1825      self.bias = self.add_weight(shape=bias_shape,
1826                                  name='bias',
1827                                  initializer=self.bias_initializer,
1828                                  regularizer=self.bias_regularizer,
1829                                  constraint=self.bias_constraint,
1830                                  caching_device=default_caching_device)
1831    else:
1832      self.bias = None
1833    self.built = True
1834
1835  def call(self, inputs, states, training=None):
1836    h_tm1 = states[0] if nest.is_nested(states) else states  # previous memory
1837
1838    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
1839    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
1840        h_tm1, training, count=3)
1841
1842    if self.use_bias:
1843      if not self.reset_after:
1844        input_bias, recurrent_bias = self.bias, None
1845      else:
1846        input_bias, recurrent_bias = array_ops.unstack(self.bias)
1847
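    # `implementation == 1` computes each gate with separate, smaller matmuls
    # (with per-gate dropout masks); otherwise the input projection is done as
    # a single fused matmul that is split into per-gate pieces below.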
1848    if self.implementation == 1:
1849      if 0. < self.dropout < 1.:
1850        inputs_z = inputs * dp_mask[0]
1851        inputs_r = inputs * dp_mask[1]
1852        inputs_h = inputs * dp_mask[2]
1853      else:
1854        inputs_z = inputs
1855        inputs_r = inputs
1856        inputs_h = inputs
1857
1858      x_z = backend.dot(inputs_z, self.kernel[:, :self.units])
1859      x_r = backend.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
1860      x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2:])
1861
1862      if self.use_bias:
1863        x_z = backend.bias_add(x_z, input_bias[:self.units])
1864        x_r = backend.bias_add(x_r, input_bias[self.units: self.units * 2])
1865        x_h = backend.bias_add(x_h, input_bias[self.units * 2:])
1866
1867      if 0. < self.recurrent_dropout < 1.:
1868        h_tm1_z = h_tm1 * rec_dp_mask[0]
1869        h_tm1_r = h_tm1 * rec_dp_mask[1]
1870        h_tm1_h = h_tm1 * rec_dp_mask[2]
1871      else:
1872        h_tm1_z = h_tm1
1873        h_tm1_r = h_tm1
1874        h_tm1_h = h_tm1
1875
1876      recurrent_z = backend.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
1877      recurrent_r = backend.dot(
1878          h_tm1_r, self.recurrent_kernel[:, self.units:self.units * 2])
1879      if self.reset_after and self.use_bias:
1880        recurrent_z = backend.bias_add(recurrent_z, recurrent_bias[:self.units])
1881        recurrent_r = backend.bias_add(
1882            recurrent_r, recurrent_bias[self.units:self.units * 2])
1883
1884      z = self.recurrent_activation(x_z + recurrent_z)
1885      r = self.recurrent_activation(x_r + recurrent_r)
1886
1887      # reset gate applied after/before matrix multiplication
1888      if self.reset_after:
1889        recurrent_h = backend.dot(
1890            h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
1891        if self.use_bias:
1892          recurrent_h = backend.bias_add(
1893              recurrent_h, recurrent_bias[self.units * 2:])
1894        recurrent_h = r * recurrent_h
1895      else:
1896        recurrent_h = backend.dot(
1897            r * h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
1898
1899      hh = self.activation(x_h + recurrent_h)
1900    else:
1901      if 0. < self.dropout < 1.:
1902        inputs = inputs * dp_mask[0]
1903
1904      # inputs projected by all gate matrices at once
1905      matrix_x = backend.dot(inputs, self.kernel)
1906      if self.use_bias:
1907        # biases: bias_z_i, bias_r_i, bias_h_i
1908        matrix_x = backend.bias_add(matrix_x, input_bias)
1909
1910      x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1)
1911
1912      if self.reset_after:
1913        # hidden state projected by all gate matrices at once
1914        matrix_inner = backend.dot(h_tm1, self.recurrent_kernel)
1915        if self.use_bias:
1916          matrix_inner = backend.bias_add(matrix_inner, recurrent_bias)
1917      else:
1918        # hidden state projected separately for update/reset and the candidate
1919        matrix_inner = backend.dot(
1920            h_tm1, self.recurrent_kernel[:, :2 * self.units])
1921
1922      recurrent_z, recurrent_r, recurrent_h = array_ops.split(
1923          matrix_inner, [self.units, self.units, -1], axis=-1)
1924
1925      z = self.recurrent_activation(x_z + recurrent_z)
1926      r = self.recurrent_activation(x_r + recurrent_r)
1927
1928      if self.reset_after:
1929        recurrent_h = r * recurrent_h
1930      else:
1931        recurrent_h = backend.dot(
1932            r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
1933
1934      hh = self.activation(x_h + recurrent_h)
1935    # previous and candidate state mixed by update gate
1936    h = z * h_tm1 + (1 - z) * hh
1937    new_state = [h] if nest.is_nested(states) else h
1938    return h, new_state
1939
1940  def get_config(self):
1941    config = {
1942        'units': self.units,
1943        'activation': activations.serialize(self.activation),
1944        'recurrent_activation':
1945            activations.serialize(self.recurrent_activation),
1946        'use_bias': self.use_bias,
1947        'kernel_initializer': initializers.serialize(self.kernel_initializer),
1948        'recurrent_initializer':
1949            initializers.serialize(self.recurrent_initializer),
1950        'bias_initializer': initializers.serialize(self.bias_initializer),
1951        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
1952        'recurrent_regularizer':
1953            regularizers.serialize(self.recurrent_regularizer),
1954        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
1955        'kernel_constraint': constraints.serialize(self.kernel_constraint),
1956        'recurrent_constraint':
1957            constraints.serialize(self.recurrent_constraint),
1958        'bias_constraint': constraints.serialize(self.bias_constraint),
1959        'dropout': self.dropout,
1960        'recurrent_dropout': self.recurrent_dropout,
1961        'implementation': self.implementation,
1962        'reset_after': self.reset_after
1963    }
1964    config.update(_config_for_enable_caching_device(self))
1965    base_config = super(GRUCell, self).get_config()
1966    return dict(list(base_config.items()) + list(config.items()))
1967
1968  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1969    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1970
1971
1972@keras_export(v1=['keras.layers.GRU'])
1973class GRU(RNN):
1974  """Gated Recurrent Unit - Cho et al. 2014.
1975
1976  There are two variants. The default one is based on 1406.1078v3 and has the
1977  reset gate applied to the hidden state before matrix multiplication. The
1978  other one is based on the original 1406.1078v1 and has the order reversed.
1979
1980  The second variant is compatible with CuDNNGRU (GPU-only) and allows
1981  inference on CPU. Thus it has separate biases for `kernel` and
1982  `recurrent_kernel`. Use `reset_after=True` and
1983  `recurrent_activation='sigmoid'`.
1984
1985  Args:
1986    units: Positive integer, dimensionality of the output space.
1987    activation: Activation function to use.
1988      Default: hyperbolic tangent (`tanh`).
1989      If you pass `None`, no activation is applied
1990      (ie. "linear" activation: `a(x) = x`).
1991    recurrent_activation: Activation function to use
1992      for the recurrent step.
1993      Default: hard sigmoid (`hard_sigmoid`).
1994      If you pass `None`, no activation is applied
1995      (ie. "linear" activation: `a(x) = x`).
1996    use_bias: Boolean, whether the layer uses a bias vector.
1997    kernel_initializer: Initializer for the `kernel` weights matrix,
1998      used for the linear transformation of the inputs.
1999    recurrent_initializer: Initializer for the `recurrent_kernel`
2000      weights matrix, used for the linear transformation of the recurrent state.
2001    bias_initializer: Initializer for the bias vector.
2002    kernel_regularizer: Regularizer function applied to
2003      the `kernel` weights matrix.
2004    recurrent_regularizer: Regularizer function applied to
2005      the `recurrent_kernel` weights matrix.
2006    bias_regularizer: Regularizer function applied to the bias vector.
2007    activity_regularizer: Regularizer function applied to
2008      the output of the layer (its "activation").
2009    kernel_constraint: Constraint function applied to
2010      the `kernel` weights matrix.
2011    recurrent_constraint: Constraint function applied to
2012      the `recurrent_kernel` weights matrix.
2013    bias_constraint: Constraint function applied to the bias vector.
2014    dropout: Float between 0 and 1.
2015      Fraction of the units to drop for
2016      the linear transformation of the inputs.
2017    recurrent_dropout: Float between 0 and 1.
2018      Fraction of the units to drop for
2019      the linear transformation of the recurrent state.
2020    return_sequences: Boolean. Whether to return the last output
2021      in the output sequence, or the full sequence.
2022    return_state: Boolean. Whether to return the last state
2023      in addition to the output.
2024    go_backwards: Boolean (default False).
2025      If True, process the input sequence backwards and return the
2026      reversed sequence.
2027    stateful: Boolean (default False). If True, the last state
2028      for each sample at index i in a batch will be used as initial
2029      state for the sample of index i in the following batch.
2030    unroll: Boolean (default False).
2031      If True, the network will be unrolled,
2032      else a symbolic loop will be used.
2033      Unrolling can speed up an RNN,
2034      although it tends to be more memory-intensive.
2035      Unrolling is only suitable for short sequences.
2036    time_major: The shape format of the `inputs` and `outputs` tensors.
2037      If True, the inputs and outputs will be in shape
2038      `(timesteps, batch, ...)`, whereas in the False case, it will be
2039      `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
2040      efficient because it avoids transposes at the beginning and end of the
2041      RNN calculation. However, most TensorFlow data is batch-major, so by
2042      default this function accepts input and emits output in batch-major
2043      form.
2044    reset_after: GRU convention (whether to apply reset gate after or
2045      before matrix multiplication). False = "before" (default),
2046      True = "after" (CuDNN compatible).
2047
2048  Call arguments:
2049    inputs: A 3D tensor.
2050    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
2051      a given timestep should be masked. An individual `True` entry indicates
2052      that the corresponding timestep should be utilized, while a `False`
2053      entry indicates that the corresponding timestep should be ignored.
2054    training: Python boolean indicating whether the layer should behave in
2055      training mode or in inference mode. This argument is passed to the cell
2056      when calling it. This is only relevant if `dropout` or
2057      `recurrent_dropout` is used.
2058    initial_state: List of initial state tensors to be passed to the first
2059      call of the cell.
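
  Example (an illustrative usage sketch; the shapes are arbitrary):

  ```python
  inputs = np.random.random([32, 10, 8]).astype(np.float32)
  gru = tf.keras.layers.GRU(4)

  output = gru(inputs)  # The output has shape `[32, 4]`.

  gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True)

  # whole_sequence_output has shape `[32, 10, 4]`.
  # final_state has shape `[32, 4]`.
  whole_sequence_output, final_state = gru(inputs)
  ```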
2060  """
2061
2062  def __init__(self,
2063               units,
2064               activation='tanh',
2065               recurrent_activation='hard_sigmoid',
2066               use_bias=True,
2067               kernel_initializer='glorot_uniform',
2068               recurrent_initializer='orthogonal',
2069               bias_initializer='zeros',
2070               kernel_regularizer=None,
2071               recurrent_regularizer=None,
2072               bias_regularizer=None,
2073               activity_regularizer=None,
2074               kernel_constraint=None,
2075               recurrent_constraint=None,
2076               bias_constraint=None,
2077               dropout=0.,
2078               recurrent_dropout=0.,
2079               return_sequences=False,
2080               return_state=False,
2081               go_backwards=False,
2082               stateful=False,
2083               unroll=False,
2084               reset_after=False,
2085               **kwargs):
2086    implementation = kwargs.pop('implementation', 1)
2087    if implementation == 0:
2088      logging.warning('`implementation=0` has been deprecated, '
2089                      'and now defaults to `implementation=1`.'
2090                      'Please update your layer call.')
2091    if 'enable_caching_device' in kwargs:
2092      cell_kwargs = {'enable_caching_device':
2093                     kwargs.pop('enable_caching_device')}
2094    else:
2095      cell_kwargs = {}
2096    cell = GRUCell(
2097        units,
2098        activation=activation,
2099        recurrent_activation=recurrent_activation,
2100        use_bias=use_bias,
2101        kernel_initializer=kernel_initializer,
2102        recurrent_initializer=recurrent_initializer,
2103        bias_initializer=bias_initializer,
2104        kernel_regularizer=kernel_regularizer,
2105        recurrent_regularizer=recurrent_regularizer,
2106        bias_regularizer=bias_regularizer,
2107        kernel_constraint=kernel_constraint,
2108        recurrent_constraint=recurrent_constraint,
2109        bias_constraint=bias_constraint,
2110        dropout=dropout,
2111        recurrent_dropout=recurrent_dropout,
2112        implementation=implementation,
2113        reset_after=reset_after,
2114        dtype=kwargs.get('dtype'),
2115        trainable=kwargs.get('trainable', True),
2116        **cell_kwargs)
2117    super(GRU, self).__init__(
2118        cell,
2119        return_sequences=return_sequences,
2120        return_state=return_state,
2121        go_backwards=go_backwards,
2122        stateful=stateful,
2123        unroll=unroll,
2124        **kwargs)
2125    self.activity_regularizer = regularizers.get(activity_regularizer)
2126    self.input_spec = [InputSpec(ndim=3)]
2127
2128  def call(self, inputs, mask=None, training=None, initial_state=None):
2129    return super(GRU, self).call(
2130        inputs, mask=mask, training=training, initial_state=initial_state)
2131
2132  @property
2133  def units(self):
2134    return self.cell.units
2135
2136  @property
2137  def activation(self):
2138    return self.cell.activation
2139
2140  @property
2141  def recurrent_activation(self):
2142    return self.cell.recurrent_activation
2143
2144  @property
2145  def use_bias(self):
2146    return self.cell.use_bias
2147
2148  @property
2149  def kernel_initializer(self):
2150    return self.cell.kernel_initializer
2151
2152  @property
2153  def recurrent_initializer(self):
2154    return self.cell.recurrent_initializer
2155
2156  @property
2157  def bias_initializer(self):
2158    return self.cell.bias_initializer
2159
2160  @property
2161  def kernel_regularizer(self):
2162    return self.cell.kernel_regularizer
2163
2164  @property
2165  def recurrent_regularizer(self):
2166    return self.cell.recurrent_regularizer
2167
2168  @property
2169  def bias_regularizer(self):
2170    return self.cell.bias_regularizer
2171
2172  @property
2173  def kernel_constraint(self):
2174    return self.cell.kernel_constraint
2175
2176  @property
2177  def recurrent_constraint(self):
2178    return self.cell.recurrent_constraint
2179
2180  @property
2181  def bias_constraint(self):
2182    return self.cell.bias_constraint
2183
2184  @property
2185  def dropout(self):
2186    return self.cell.dropout
2187
2188  @property
2189  def recurrent_dropout(self):
2190    return self.cell.recurrent_dropout
2191
2192  @property
2193  def implementation(self):
2194    return self.cell.implementation
2195
2196  @property
2197  def reset_after(self):
2198    return self.cell.reset_after
2199
2200  def get_config(self):
2201    config = {
2202        'units':
2203            self.units,
2204        'activation':
2205            activations.serialize(self.activation),
2206        'recurrent_activation':
2207            activations.serialize(self.recurrent_activation),
2208        'use_bias':
2209            self.use_bias,
2210        'kernel_initializer':
2211            initializers.serialize(self.kernel_initializer),
2212        'recurrent_initializer':
2213            initializers.serialize(self.recurrent_initializer),
2214        'bias_initializer':
2215            initializers.serialize(self.bias_initializer),
2216        'kernel_regularizer':
2217            regularizers.serialize(self.kernel_regularizer),
2218        'recurrent_regularizer':
2219            regularizers.serialize(self.recurrent_regularizer),
2220        'bias_regularizer':
2221            regularizers.serialize(self.bias_regularizer),
2222        'activity_regularizer':
2223            regularizers.serialize(self.activity_regularizer),
2224        'kernel_constraint':
2225            constraints.serialize(self.kernel_constraint),
2226        'recurrent_constraint':
2227            constraints.serialize(self.recurrent_constraint),
2228        'bias_constraint':
2229            constraints.serialize(self.bias_constraint),
2230        'dropout':
2231            self.dropout,
2232        'recurrent_dropout':
2233            self.recurrent_dropout,
2234        'implementation':
2235            self.implementation,
2236        'reset_after':
2237            self.reset_after
2238    }
2239    config.update(_config_for_enable_caching_device(self.cell))
2240    base_config = super(GRU, self).get_config()
2241    del base_config['cell']
2242    return dict(list(base_config.items()) + list(config.items()))
2243
2244  @classmethod
2245  def from_config(cls, config):
2246    if 'implementation' in config and config['implementation'] == 0:
2247      config['implementation'] = 1
2248    return cls(**config)
2249
2250
2251@keras_export(v1=['keras.layers.LSTMCell'])
2252class LSTMCell(DropoutRNNCellMixin, Layer):
2253  """Cell class for the LSTM layer.
2254
2255  Args:
2256    units: Positive integer, dimensionality of the output space.
2257    activation: Activation function to use.
2258      Default: hyperbolic tangent (`tanh`).
2259      If you pass `None`, no activation is applied
2260      (ie. "linear" activation: `a(x) = x`).
2261    recurrent_activation: Activation function to use
2262      for the recurrent step.
2263      Default: hard sigmoid (`hard_sigmoid`).
2264      If you pass `None`, no activation is applied
2265      (ie. "linear" activation: `a(x) = x`).
2266    use_bias: Boolean, whether the layer uses a bias vector.
2267    kernel_initializer: Initializer for the `kernel` weights matrix,
2268      used for the linear transformation of the inputs.
2269    recurrent_initializer: Initializer for the `recurrent_kernel`
2270      weights matrix,
2271      used for the linear transformation of the recurrent state.
2272    bias_initializer: Initializer for the bias vector.
2273    unit_forget_bias: Boolean.
2274      If True, add 1 to the bias of the forget gate at initialization.
2275      Setting it to `True` will also force `bias_initializer="zeros"`.
2276      This is recommended in [Jozefowicz et al., 2015](
2277        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
2278    kernel_regularizer: Regularizer function applied to
2279      the `kernel` weights matrix.
2280    recurrent_regularizer: Regularizer function applied to
2281      the `recurrent_kernel` weights matrix.
2282    bias_regularizer: Regularizer function applied to the bias vector.
2283    kernel_constraint: Constraint function applied to
2284      the `kernel` weights matrix.
2285    recurrent_constraint: Constraint function applied to
2286      the `recurrent_kernel` weights matrix.
2287    bias_constraint: Constraint function applied to the bias vector.
2288    dropout: Float between 0 and 1.
2289      Fraction of the units to drop for
2290      the linear transformation of the inputs.
2291    recurrent_dropout: Float between 0 and 1.
2292      Fraction of the units to drop for
2293      the linear transformation of the recurrent state.
2294
2295  Call arguments:
2296    inputs: A 2D tensor.
2297    states: List of state tensors corresponding to the previous timestep.
2298    training: Python boolean indicating whether the layer should behave in
2299      training mode or in inference mode. Only relevant when `dropout` or
2300      `recurrent_dropout` is used.
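
  Example (an illustrative sketch, analogous to the `SimpleRNNCell` example
  above; the shapes are arbitrary):

  ```python
  inputs = np.random.random([32, 10, 8]).astype(np.float32)
  rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4))

  output = rnn(inputs)  # The output has shape `[32, 4]`.

  rnn = tf.keras.layers.RNN(
      tf.keras.layers.LSTMCell(4), return_sequences=True, return_state=True)

  # whole_sequence_output has shape `[32, 10, 4]`.
  # final_memory_state and final_carry_state both have shape `[32, 4]`.
  whole_sequence_output, final_memory_state, final_carry_state = rnn(inputs)
  ```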
2301  """
2302
2303  def __init__(self,
2304               units,
2305               activation='tanh',
2306               recurrent_activation='hard_sigmoid',
2307               use_bias=True,
2308               kernel_initializer='glorot_uniform',
2309               recurrent_initializer='orthogonal',
2310               bias_initializer='zeros',
2311               unit_forget_bias=True,
2312               kernel_regularizer=None,
2313               recurrent_regularizer=None,
2314               bias_regularizer=None,
2315               kernel_constraint=None,
2316               recurrent_constraint=None,
2317               bias_constraint=None,
2318               dropout=0.,
2319               recurrent_dropout=0.,
2320               **kwargs):
2321    if units <= 0:
2322      raise ValueError(f'Received an invalid value for units, expected '
2323                       f'a positive integer, got {units}.')
2324    # By default use cached variable under v2 mode, see b/143699808.
2325    if ops.executing_eagerly_outside_functions():
2326      self._enable_caching_device = kwargs.pop('enable_caching_device', True)
2327    else:
2328      self._enable_caching_device = kwargs.pop('enable_caching_device', False)
2329    super(LSTMCell, self).__init__(**kwargs)
2330    self.units = units
2331    self.activation = activations.get(activation)
2332    self.recurrent_activation = activations.get(recurrent_activation)
2333    self.use_bias = use_bias
2334
2335    self.kernel_initializer = initializers.get(kernel_initializer)
2336    self.recurrent_initializer = initializers.get(recurrent_initializer)
2337    self.bias_initializer = initializers.get(bias_initializer)
2338    self.unit_forget_bias = unit_forget_bias
2339
2340    self.kernel_regularizer = regularizers.get(kernel_regularizer)
2341    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
2342    self.bias_regularizer = regularizers.get(bias_regularizer)
2343
2344    self.kernel_constraint = constraints.get(kernel_constraint)
2345    self.recurrent_constraint = constraints.get(recurrent_constraint)
2346    self.bias_constraint = constraints.get(bias_constraint)
2347
2348    self.dropout = min(1., max(0., dropout))
2349    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
2350    implementation = kwargs.pop('implementation', 1)
2351    if self.recurrent_dropout != 0 and implementation != 1:
2352      logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
2353      self.implementation = 1
2354    else:
2355      self.implementation = implementation
2356    self.state_size = [self.units, self.units]
2357    self.output_size = self.units
2358
2359  @tf_utils.shape_type_conversion
2360  def build(self, input_shape):
2361    default_caching_device = _caching_device(self)
2362    input_dim = input_shape[-1]
2363    self.kernel = self.add_weight(
2364        shape=(input_dim, self.units * 4),
2365        name='kernel',
2366        initializer=self.kernel_initializer,
2367        regularizer=self.kernel_regularizer,
2368        constraint=self.kernel_constraint,
2369        caching_device=default_caching_device)
2370    self.recurrent_kernel = self.add_weight(
2371        shape=(self.units, self.units * 4),
2372        name='recurrent_kernel',
2373        initializer=self.recurrent_initializer,
2374        regularizer=self.recurrent_regularizer,
2375        constraint=self.recurrent_constraint,
2376        caching_device=default_caching_device)
2377
2378    if self.use_bias:
2379      if self.unit_forget_bias:
2380
2381        def bias_initializer(_, *args, **kwargs):
2382          return backend.concatenate([
2383              self.bias_initializer((self.units,), *args, **kwargs),
2384              initializers.get('ones')((self.units,), *args, **kwargs),
2385              self.bias_initializer((self.units * 2,), *args, **kwargs),
2386          ])
2387      else:
2388        bias_initializer = self.bias_initializer
2389      self.bias = self.add_weight(
2390          shape=(self.units * 4,),
2391          name='bias',
2392          initializer=bias_initializer,
2393          regularizer=self.bias_regularizer,
2394          constraint=self.bias_constraint,
2395          caching_device=default_caching_device)
2396    else:
2397      self.bias = None
2398    self.built = True
2399
2400  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
2401    """Computes carry and output using split kernels."""
2402    x_i, x_f, x_c, x_o = x
2403    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
2404    i = self.recurrent_activation(
2405        x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
2406    f = self.recurrent_activation(x_f + backend.dot(
2407        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))
2408    c = f * c_tm1 + i * self.activation(x_c + backend.dot(
2409        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
2410    o = self.recurrent_activation(
2411        x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
2412    return c, o
2413
2414  def _compute_carry_and_output_fused(self, z, c_tm1):
2415    """Computes carry and output using fused kernels."""
2416    z0, z1, z2, z3 = z
2417    i = self.recurrent_activation(z0)
2418    f = self.recurrent_activation(z1)
2419    c = f * c_tm1 + i * self.activation(z2)
2420    o = self.recurrent_activation(z3)
2421    return c, o
2422
2423  def call(self, inputs, states, training=None):
2424    h_tm1 = states[0]  # previous memory state
2425    c_tm1 = states[1]  # previous carry state
2426
2427    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
2428    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
2429        h_tm1, training, count=4)
2430
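    # `implementation == 1` splits the kernel and bias into four per-gate
    # pieces and applies per-gate dropout masks; otherwise the inputs and the
    # hidden state are each projected by one fused matmul and the result is
    # split into the four gates afterwards.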
2431    if self.implementation == 1:
2432      if 0 < self.dropout < 1.:
2433        inputs_i = inputs * dp_mask[0]
2434        inputs_f = inputs * dp_mask[1]
2435        inputs_c = inputs * dp_mask[2]
2436        inputs_o = inputs * dp_mask[3]
2437      else:
2438        inputs_i = inputs
2439        inputs_f = inputs
2440        inputs_c = inputs
2441        inputs_o = inputs
2442      k_i, k_f, k_c, k_o = array_ops.split(
2443          self.kernel, num_or_size_splits=4, axis=1)
2444      x_i = backend.dot(inputs_i, k_i)
2445      x_f = backend.dot(inputs_f, k_f)
2446      x_c = backend.dot(inputs_c, k_c)
2447      x_o = backend.dot(inputs_o, k_o)
2448      if self.use_bias:
2449        b_i, b_f, b_c, b_o = array_ops.split(
2450            self.bias, num_or_size_splits=4, axis=0)
2451        x_i = backend.bias_add(x_i, b_i)
2452        x_f = backend.bias_add(x_f, b_f)
2453        x_c = backend.bias_add(x_c, b_c)
2454        x_o = backend.bias_add(x_o, b_o)
2455
2456      if 0 < self.recurrent_dropout < 1.:
2457        h_tm1_i = h_tm1 * rec_dp_mask[0]
2458        h_tm1_f = h_tm1 * rec_dp_mask[1]
2459        h_tm1_c = h_tm1 * rec_dp_mask[2]
2460        h_tm1_o = h_tm1 * rec_dp_mask[3]
2461      else:
2462        h_tm1_i = h_tm1
2463        h_tm1_f = h_tm1
2464        h_tm1_c = h_tm1
2465        h_tm1_o = h_tm1
2466      x = (x_i, x_f, x_c, x_o)
2467      h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
2468      c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
2469    else:
2470      if 0. < self.dropout < 1.:
2471        inputs = inputs * dp_mask[0]
2472      z = backend.dot(inputs, self.kernel)
2473      z += backend.dot(h_tm1, self.recurrent_kernel)
2474      if self.use_bias:
2475        z = backend.bias_add(z, self.bias)
2476
2477      z = array_ops.split(z, num_or_size_splits=4, axis=1)
2478      c, o = self._compute_carry_and_output_fused(z, c_tm1)
2479
2480    h = o * self.activation(c)
2481    return h, [h, c]
2482
2483  def get_config(self):
2484    config = {
2485        'units':
2486            self.units,
2487        'activation':
2488            activations.serialize(self.activation),
2489        'recurrent_activation':
2490            activations.serialize(self.recurrent_activation),
2491        'use_bias':
2492            self.use_bias,
2493        'kernel_initializer':
2494            initializers.serialize(self.kernel_initializer),
2495        'recurrent_initializer':
2496            initializers.serialize(self.recurrent_initializer),
2497        'bias_initializer':
2498            initializers.serialize(self.bias_initializer),
2499        'unit_forget_bias':
2500            self.unit_forget_bias,
2501        'kernel_regularizer':
2502            regularizers.serialize(self.kernel_regularizer),
2503        'recurrent_regularizer':
2504            regularizers.serialize(self.recurrent_regularizer),
2505        'bias_regularizer':
2506            regularizers.serialize(self.bias_regularizer),
2507        'kernel_constraint':
2508            constraints.serialize(self.kernel_constraint),
2509        'recurrent_constraint':
2510            constraints.serialize(self.recurrent_constraint),
2511        'bias_constraint':
2512            constraints.serialize(self.bias_constraint),
2513        'dropout':
2514            self.dropout,
2515        'recurrent_dropout':
2516            self.recurrent_dropout,
2517        'implementation':
2518            self.implementation
2519    }
2520    config.update(_config_for_enable_caching_device(self))
2521    base_config = super(LSTMCell, self).get_config()
2522    return dict(list(base_config.items()) + list(config.items()))
2523
2524  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
2525    return list(_generate_zero_filled_state_for_cell(
2526        self, inputs, batch_size, dtype))
2527
2528
2529@keras_export('keras.experimental.PeepholeLSTMCell')
2530class PeepholeLSTMCell(LSTMCell):
2531  """Equivalent to LSTMCell class but adds peephole connections.
2532
2533  Peephole connections allow the gates to utilize the previous internal state
2534  as well as the previous hidden state (which LSTMCell is limited to). This
2535  allows PeepholeLSTMCell to learn precise timings more easily than LSTMCell.
2536
2537  From [Gers et al., 2002](
2538    http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):
2539
2540  "We find that LSTM augmented by 'peephole connections' from its internal
2541  cells to its multiplicative gates can learn the fine distinction between
2542  sequences of spikes spaced either 50 or 49 time steps apart without the help
2543  of any short training exemplars."
2544
2545  The peephole implementation is based on:
2546
2547  [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf)
2548
2549  Example:
2550
2551  ```python
2552  # Create 2 PeepholeLSTMCells
2553  peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]]
2554  # Create a layer composed sequentially of the peephole LSTM cells.
2555  layer = tf.keras.layers.RNN(peephole_lstm_cells)
2556  inputs = tf.keras.Input((timesteps, input_dim))
2557  outputs = layer(inputs)
2558  ```
2559  """
2560
2561  def __init__(self,
2562               units,
2563               activation='tanh',
2564               recurrent_activation='hard_sigmoid',
2565               use_bias=True,
2566               kernel_initializer='glorot_uniform',
2567               recurrent_initializer='orthogonal',
2568               bias_initializer='zeros',
2569               unit_forget_bias=True,
2570               kernel_regularizer=None,
2571               recurrent_regularizer=None,
2572               bias_regularizer=None,
2573               kernel_constraint=None,
2574               recurrent_constraint=None,
2575               bias_constraint=None,
2576               dropout=0.,
2577               recurrent_dropout=0.,
2578               **kwargs):
2579    warnings.warn('`tf.keras.experimental.PeepholeLSTMCell` is deprecated '
2580                  'and will be removed in a future version. '
2581                  'Please use tensorflow_addons.rnn.PeepholeLSTMCell '
2582                  'instead.')
2583    super(PeepholeLSTMCell, self).__init__(
2584        units=units,
2585        activation=activation,
2586        recurrent_activation=recurrent_activation,
2587        use_bias=use_bias,
2588        kernel_initializer=kernel_initializer,
2589        recurrent_initializer=recurrent_initializer,
2590        bias_initializer=bias_initializer,
2591        unit_forget_bias=unit_forget_bias,
2592        kernel_regularizer=kernel_regularizer,
2593        recurrent_regularizer=recurrent_regularizer,
2594        bias_regularizer=bias_regularizer,
2595        kernel_constraint=kernel_constraint,
2596        recurrent_constraint=recurrent_constraint,
2597        bias_constraint=bias_constraint,
2598        dropout=dropout,
2599        recurrent_dropout=recurrent_dropout,
2600        implementation=kwargs.pop('implementation', 1),
2601        **kwargs)
2602
2603  def build(self, input_shape):
2604    super(PeepholeLSTMCell, self).build(input_shape)
2605    # The following are the weight matrices for the peephole connections. These
2606    # are multiplied with the previous internal state during the computation of
2607    # carry and output.
2608    self.input_gate_peephole_weights = self.add_weight(
2609        shape=(self.units,),
2610        name='input_gate_peephole_weights',
2611        initializer=self.kernel_initializer)
2612    self.forget_gate_peephole_weights = self.add_weight(
2613        shape=(self.units,),
2614        name='forget_gate_peephole_weights',
2615        initializer=self.kernel_initializer)
2616    self.output_gate_peephole_weights = self.add_weight(
2617        shape=(self.units,),
2618        name='output_gate_peephole_weights',
2619        initializer=self.kernel_initializer)
2620
2621  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
2622    x_i, x_f, x_c, x_o = x
2623    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
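    # x_* are the per-gate input projections and h_tm1_* the (possibly
    # dropout-masked) copies of the previous hidden state. The previous cell
    # state c_tm1 feeds the input and forget gates through the peephole
    # weights, while the updated cell state feeds the output gate.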
2624    i = self.recurrent_activation(
2625        x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
2626        self.input_gate_peephole_weights * c_tm1)
2627    f = self.recurrent_activation(x_f + backend.dot(
2628        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) +
2629                                  self.forget_gate_peephole_weights * c_tm1)
2630    c = f * c_tm1 + i * self.activation(x_c + backend.dot(
2631        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
2632    o = self.recurrent_activation(
2633        x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
2634        self.output_gate_peephole_weights * c)
2635    return c, o
2636
2637  def _compute_carry_and_output_fused(self, z, c_tm1):
2638    z0, z1, z2, z3 = z
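    # z0..z3 are the fused pre-activation gate inputs (input, forget, candidate
    # cell, output); the kernel, recurrent kernel and bias contributions are
    # already included.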
2639    i = self.recurrent_activation(z0 +
2640                                  self.input_gate_peephole_weights * c_tm1)
2641    f = self.recurrent_activation(z1 +
2642                                  self.forget_gate_peephole_weights * c_tm1)
2643    c = f * c_tm1 + i * self.activation(z2)
2644    o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c)
2645    return c, o
2646
2647
2648@keras_export(v1=['keras.layers.LSTM'])
2649class LSTM(RNN):
2650  """Long Short-Term Memory layer - Hochreiter 1997.
2651
2652  Note that this layer is not optimized for performance on GPU. Please use
2653  `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU.
2654
2655  Args:
2656    units: Positive integer, dimensionality of the output space.
2657    activation: Activation function to use.
2658      Default: hyperbolic tangent (`tanh`).
2659      If you pass `None`, no activation is applied
2660      (ie. "linear" activation: `a(x) = x`).
2661    recurrent_activation: Activation function to use
2662      for the recurrent step.
2663      Default: hard sigmoid (`hard_sigmoid`).
2664      If you pass `None`, no activation is applied
2665      (ie. "linear" activation: `a(x) = x`).
2666    use_bias: Boolean, whether the layer uses a bias vector.
2667    kernel_initializer: Initializer for the `kernel` weights matrix,
2668      used for the linear transformation of the inputs.
2669    recurrent_initializer: Initializer for the `recurrent_kernel`
2670      weights matrix,
2671      used for the linear transformation of the recurrent state.
2672    bias_initializer: Initializer for the bias vector.
2673    unit_forget_bias: Boolean.
2674      If True, add 1 to the bias of the forget gate at initialization.
2675      Setting it to true will also force `bias_initializer="zeros"`.
2676      This is recommended in [Jozefowicz et al., 2015](
2677        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
2678    kernel_regularizer: Regularizer function applied to
2679      the `kernel` weights matrix.
2680    recurrent_regularizer: Regularizer function applied to
2681      the `recurrent_kernel` weights matrix.
2682    bias_regularizer: Regularizer function applied to the bias vector.
2683    activity_regularizer: Regularizer function applied to
2684      the output of the layer (its "activation").
2685    kernel_constraint: Constraint function applied to
2686      the `kernel` weights matrix.
2687    recurrent_constraint: Constraint function applied to
2688      the `recurrent_kernel` weights matrix.
2689    bias_constraint: Constraint function applied to the bias vector.
2690    dropout: Float between 0 and 1.
2691      Fraction of the units to drop for
2692      the linear transformation of the inputs.
2693    recurrent_dropout: Float between 0 and 1.
2694      Fraction of the units to drop for
2695      the linear transformation of the recurrent state.
2696    return_sequences: Boolean. Whether to return the last output
2697      in the output sequence, or the full sequence.
2698    return_state: Boolean. Whether to return the last state
2699      in addition to the output.
2700    go_backwards: Boolean (default False).
2701      If True, process the input sequence backwards and return the
2702      reversed sequence.
2703    stateful: Boolean (default False). If True, the last state
2704      for each sample at index i in a batch will be used as initial
2705      state for the sample of index i in the following batch.
2706    unroll: Boolean (default False).
2707      If True, the network will be unrolled,
2708      else a symbolic loop will be used.
2709      Unrolling can speed-up a RNN,
2710      although it tends to be more memory-intensive.
2711      Unrolling is only suitable for short sequences.
2712    time_major: The shape format of the `inputs` and `outputs` tensors.
2713      If True, the inputs and outputs will be in shape
2714      `(timesteps, batch, ...)`, whereas in the False case, it will be
2715      `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
2716      efficient because it avoids transposes at the beginning and end of the
2717      RNN calculation. However, most TensorFlow data is batch-major, so by
2718      default this function accepts input and emits output in batch-major
2719      form.
2720
2721  Call arguments:
2722    inputs: A 3D tensor.
2723    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
2724      a given timestep should be masked. An individual `True` entry indicates
2725      that the corresponding timestep should be utilized, while a `False`
2726      entry indicates that the corresponding timestep should be ignored.
2727    training: Python boolean indicating whether the layer should behave in
2728      training mode or in inference mode. This argument is passed to the cell
2729      when calling it. This is only relevant if `dropout` or
2730      `recurrent_dropout` is used.
2731    initial_state: List of initial state tensors to be passed to the first
2732      call of the cell.
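
  Example (a minimal usage sketch; the batch size, sequence length and feature
  sizes below are illustrative assumptions, not requirements):

  ```python
  inputs = tf.random.normal([32, 10, 8])  # (batch, timesteps, features)
  lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
  whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
  # whole_seq_output has shape (32, 10, 4); each final state has shape (32, 4).
  ```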
2733  """
2734
2735  def __init__(self,
2736               units,
2737               activation='tanh',
2738               recurrent_activation='hard_sigmoid',
2739               use_bias=True,
2740               kernel_initializer='glorot_uniform',
2741               recurrent_initializer='orthogonal',
2742               bias_initializer='zeros',
2743               unit_forget_bias=True,
2744               kernel_regularizer=None,
2745               recurrent_regularizer=None,
2746               bias_regularizer=None,
2747               activity_regularizer=None,
2748               kernel_constraint=None,
2749               recurrent_constraint=None,
2750               bias_constraint=None,
2751               dropout=0.,
2752               recurrent_dropout=0.,
2753               return_sequences=False,
2754               return_state=False,
2755               go_backwards=False,
2756               stateful=False,
2757               unroll=False,
2758               **kwargs):
2759    implementation = kwargs.pop('implementation', 1)
2760    if implementation == 0:
2761      logging.warning('`implementation=0` has been deprecated, '
2762                      'and now defaults to `implementation=1`. '
2763                      'Please update your layer call.')
2764    if 'enable_caching_device' in kwargs:
2765      cell_kwargs = {'enable_caching_device':
2766                     kwargs.pop('enable_caching_device')}
2767    else:
2768      cell_kwargs = {}
2769    cell = LSTMCell(
2770        units,
2771        activation=activation,
2772        recurrent_activation=recurrent_activation,
2773        use_bias=use_bias,
2774        kernel_initializer=kernel_initializer,
2775        recurrent_initializer=recurrent_initializer,
2776        unit_forget_bias=unit_forget_bias,
2777        bias_initializer=bias_initializer,
2778        kernel_regularizer=kernel_regularizer,
2779        recurrent_regularizer=recurrent_regularizer,
2780        bias_regularizer=bias_regularizer,
2781        kernel_constraint=kernel_constraint,
2782        recurrent_constraint=recurrent_constraint,
2783        bias_constraint=bias_constraint,
2784        dropout=dropout,
2785        recurrent_dropout=recurrent_dropout,
2786        implementation=implementation,
2787        dtype=kwargs.get('dtype'),
2788        trainable=kwargs.get('trainable', True),
2789        **cell_kwargs)
2790    super(LSTM, self).__init__(
2791        cell,
2792        return_sequences=return_sequences,
2793        return_state=return_state,
2794        go_backwards=go_backwards,
2795        stateful=stateful,
2796        unroll=unroll,
2797        **kwargs)
2798    self.activity_regularizer = regularizers.get(activity_regularizer)
2799    self.input_spec = [InputSpec(ndim=3)]
2800
2801  def call(self, inputs, mask=None, training=None, initial_state=None):
2802    return super(LSTM, self).call(
2803        inputs, mask=mask, training=training, initial_state=initial_state)
2804
2805  @property
2806  def units(self):
2807    return self.cell.units
2808
2809  @property
2810  def activation(self):
2811    return self.cell.activation
2812
2813  @property
2814  def recurrent_activation(self):
2815    return self.cell.recurrent_activation
2816
2817  @property
2818  def use_bias(self):
2819    return self.cell.use_bias
2820
2821  @property
2822  def kernel_initializer(self):
2823    return self.cell.kernel_initializer
2824
2825  @property
2826  def recurrent_initializer(self):
2827    return self.cell.recurrent_initializer
2828
2829  @property
2830  def bias_initializer(self):
2831    return self.cell.bias_initializer
2832
2833  @property
2834  def unit_forget_bias(self):
2835    return self.cell.unit_forget_bias
2836
2837  @property
2838  def kernel_regularizer(self):
2839    return self.cell.kernel_regularizer
2840
2841  @property
2842  def recurrent_regularizer(self):
2843    return self.cell.recurrent_regularizer
2844
2845  @property
2846  def bias_regularizer(self):
2847    return self.cell.bias_regularizer
2848
2849  @property
2850  def kernel_constraint(self):
2851    return self.cell.kernel_constraint
2852
2853  @property
2854  def recurrent_constraint(self):
2855    return self.cell.recurrent_constraint
2856
2857  @property
2858  def bias_constraint(self):
2859    return self.cell.bias_constraint
2860
2861  @property
2862  def dropout(self):
2863    return self.cell.dropout
2864
2865  @property
2866  def recurrent_dropout(self):
2867    return self.cell.recurrent_dropout
2868
2869  @property
2870  def implementation(self):
2871    return self.cell.implementation
2872
2873  def get_config(self):
2874    config = {
2875        'units':
2876            self.units,
2877        'activation':
2878            activations.serialize(self.activation),
2879        'recurrent_activation':
2880            activations.serialize(self.recurrent_activation),
2881        'use_bias':
2882            self.use_bias,
2883        'kernel_initializer':
2884            initializers.serialize(self.kernel_initializer),
2885        'recurrent_initializer':
2886            initializers.serialize(self.recurrent_initializer),
2887        'bias_initializer':
2888            initializers.serialize(self.bias_initializer),
2889        'unit_forget_bias':
2890            self.unit_forget_bias,
2891        'kernel_regularizer':
2892            regularizers.serialize(self.kernel_regularizer),
2893        'recurrent_regularizer':
2894            regularizers.serialize(self.recurrent_regularizer),
2895        'bias_regularizer':
2896            regularizers.serialize(self.bias_regularizer),
2897        'activity_regularizer':
2898            regularizers.serialize(self.activity_regularizer),
2899        'kernel_constraint':
2900            constraints.serialize(self.kernel_constraint),
2901        'recurrent_constraint':
2902            constraints.serialize(self.recurrent_constraint),
2903        'bias_constraint':
2904            constraints.serialize(self.bias_constraint),
2905        'dropout':
2906            self.dropout,
2907        'recurrent_dropout':
2908            self.recurrent_dropout,
2909        'implementation':
2910            self.implementation
2911    }
2912    config.update(_config_for_enable_caching_device(self.cell))
2913    base_config = super(LSTM, self).get_config()
2914    del base_config['cell']
2915    return dict(list(base_config.items()) + list(config.items()))
2916
2917  @classmethod
2918  def from_config(cls, config):
2919    if 'implementation' in config and config['implementation'] == 0:
2920      config['implementation'] = 1
2921    return cls(**config)
2922
2923
2924def _generate_dropout_mask(ones, rate, training=None, count=1):
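  """Generates dropout mask(s) shaped like `ones`, honoring the training phase.

  In the training phase each mask is sampled with `backend.dropout`; otherwise
  the all-ones tensor is returned unchanged. When `count > 1`, a list of
  independently sampled masks is returned (one per weight transformation).
  """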
2925  def dropped_inputs():
2926    return backend.dropout(ones, rate)
2927
2928  if count > 1:
2929    return [
2930        backend.in_train_phase(dropped_inputs, ones, training=training)
2931        for _ in range(count)
2932    ]
2933  return backend.in_train_phase(dropped_inputs, ones, training=training)
2934
2935
2936def _standardize_args(inputs, initial_state, constants, num_constants):
2937  """Standardizes `__call__` to a single list of tensor inputs.
2938
2939  When running a model loaded from a file, the input tensors
2940  `initial_state` and `constants` can be passed to `RNN.__call__()` as part
2941  of `inputs` instead of by the dedicated keyword arguments. This method
2942  makes sure the arguments are separated and that `initial_state` and
2943  `constants` are lists of tensors (or None).
2944
2945  Args:
2946    inputs: Tensor or list/tuple of tensors, which may include constants
2947      and initial states. In that case `num_constants` must be specified.
2948    initial_state: Tensor or list of tensors or None, initial states.
2949    constants: Tensor or list of tensors or None, constant tensors.
2950    num_constants: Expected number of constants (if constants are passed as
2951      part of the `inputs` list).
2952
2953  Returns:
2954    inputs: Single tensor or tuple of tensors.
2955    initial_state: List of tensors or None.
2956    constants: List of tensors or None.
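
  Example (an illustrative sketch; the tensors are placeholders):

  ```python
  x = tf.zeros([2, 3, 4])
  state = tf.zeros([2, 5])
  const = tf.zeros([2, 6])
  # initial_state and constants arrive packed inside `inputs`:
  out, states, consts = _standardize_args([x, state, const], None, None, 1)
  # out is x, states == [state], consts == [const]
  ```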
2957  """
2958  if isinstance(inputs, list):
2959    # There are several situations here:
2960    # In graph mode, __call__ will only be called once. The initial_state
2961    # and constants could be in inputs (from file loading).
2962    # In eager mode, __call__ will be called twice: once during
2963    # rnn_layer(inputs=input_t, constants=c_t, ...), and a second time inside
2964    # model.fit/train_on_batch/predict with real np data. In the second case,
2965    # the inputs will contain initial_state and constants as eager tensors.
2966    #
2967    # In either case, the real input is the first item in the list, which could
2968    # itself be a nested structure. It is followed by the initial states, which
2969    # could be a list of items (or a list of lists if the initial state is a
2970    # nested structure), and finally by the constants, which is a flat list.
2971    assert initial_state is None and constants is None
2972    if num_constants:
2973      constants = inputs[-num_constants:]
2974      inputs = inputs[:-num_constants]
2975    if len(inputs) > 1:
2976      initial_state = inputs[1:]
2977      inputs = inputs[:1]
2978
2979    if len(inputs) > 1:
2980      inputs = tuple(inputs)
2981    else:
2982      inputs = inputs[0]
2983
2984  def to_list_or_none(x):
2985    if x is None or isinstance(x, list):
2986      return x
2987    if isinstance(x, tuple):
2988      return list(x)
2989    return [x]
2990
2991  initial_state = to_list_or_none(initial_state)
2992  constants = to_list_or_none(constants)
2993
2994  return inputs, initial_state, constants
2995
2996
2997def _is_multiple_state(state_size):
2998  """Check whether the state_size contains multiple states."""
2999  return (hasattr(state_size, '__len__') and
3000          not isinstance(state_size, tensor_shape.TensorShape))
3001
3002
3003def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
3004  if inputs is not None:
3005    batch_size = array_ops.shape(inputs)[0]
3006    dtype = inputs.dtype
3007  return _generate_zero_filled_state(batch_size, cell.state_size, dtype)
3008
3009
3010def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
3011  """Generate a zero filled tensor with shape [batch_size, state_size]."""
3012  if batch_size_tensor is None or dtype is None:
3013    raise ValueError(
3014        'batch_size and dtype cannot be None while constructing initial state: '
3015        'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))
3016
3017  def create_zeros(unnested_state_size):
3018    flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list()
3019    init_state_size = [batch_size_tensor] + flat_dims
3020    return array_ops.zeros(init_state_size, dtype=dtype)
3021
3022  if nest.is_nested(state_size):
3023    return nest.map_structure(create_zeros, state_size)
3024  else:
3025    return create_zeros(state_size)
3026
3027
3028def _caching_device(rnn_cell):
3029  """Returns the caching device for the RNN variable.
3030
3031  This is useful for distributed training, when the variable is not located on
3032  the same device as the training worker. Enabling the device cache allows the
3033  worker to read the variable once and cache it locally, rather than reading it
3034  from the remote device at every time step.
3035
3036  Note that this assumes the variables the cell needs at each time step keep
3037  the same value throughout the forward pass and are only updated during
3038  backprop. This is true for all the default cells (SimpleRNN, GRU, LSTM). If
3039  the cell body relies on any variable that gets updated every time step, the
3040  caching device will cause it to read a stale value.
3041
3042  Args:
3043    rnn_cell: the rnn cell instance.
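
  Returns:
    A callable mapping an op to the device on which to cache the read value
    (the op's own device), or None if caching should not be used.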
3044  """
3045  if context.executing_eagerly():
3046    # caching_device is not supported in eager mode.
3047    return None
3048  if not getattr(rnn_cell, '_enable_caching_device', False):
3049    return None
3050  # Don't set a caching device when running in a loop, since it is possible that
3051  # train steps could be wrapped in a tf.while_loop. In that scenario caching
3052  # prevents forward computations in loop iterations from re-reading the
3053  # updated weights.
3054  if control_flow_util.IsInWhileLoop(ops.get_default_graph()):
3055    logging.warning(
3056        'Variable read device caching has been disabled because the '
3057        'RNN is in a tf.while_loop context, which would cause '
3058        'stale values to be read in the forward path. This could slow down '
3059        'the training due to duplicated variable reads. Please '
3060        'consider updating your code to remove tf.while_loop if possible.')
3061    return None
3062  if (rnn_cell._dtype_policy.compute_dtype !=
3063      rnn_cell._dtype_policy.variable_dtype):
3064    logging.warning(
3065        'Variable read device caching has been disabled since it '
3066        'doesn\'t work with the mixed precision API. This is '
3067        'likely to cause a slowdown for RNN training due to '
3068        'duplicated reads of the variable at each timestep, which '
3069        'will be significant in a multi remote worker setting. '
3070        'Please consider disabling the mixed precision API if '
3071        'the performance has been affected.')
3072    return None
3073  # Cache the value on the device that accesses the variable.
3074  return lambda op: op.device
3075
3076
3077def _config_for_enable_caching_device(rnn_cell):
3078  """Return the dict config for RNN cell wrt to enable_caching_device field.
3079
3080  Since enable_caching_device is a internal implementation detail for speed up
3081  the RNN variable read when running on the multi remote worker setting, we
3082  don't want this config to be serialized constantly in the JSON. We will only
3083  serialize this field when a none default value is used to create the cell.
3084  Args:
3085    rnn_cell: the RNN cell for serialize.
3086
3087  Returns:
3088    A dict which contains the JSON config for enable_caching_device value or
3089    empty dict if the enable_caching_device value is same as the default value.
3090  """
3091  default_enable_caching_device = ops.executing_eagerly_outside_functions()
3092  if rnn_cell._enable_caching_device != default_enable_caching_device:
3093    return {'enable_caching_device': rnn_cell._enable_caching_device}
3094  return {}
3095