xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/layers/recurrent_v2.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-classes-have-attributes
16"""Recurrent layers for TF 2."""
17
18import uuid
19
20from tensorflow.python.eager import context
21from tensorflow.python.eager import function
22from tensorflow.python.eager.context import get_device_name
23from tensorflow.python.framework import config
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import device
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.keras import activations
29from tensorflow.python.keras import backend
30from tensorflow.python.keras.engine.input_spec import InputSpec
31from tensorflow.python.keras.layers import recurrent
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import control_flow_ops
34from tensorflow.python.ops import gen_cudnn_rnn_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import nn
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import sysconfig
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.util.tf_export import keras_export
42
43
44# The following string constants are used by Defun approach for unified backend
45# of LSTM and GRU.
46_FUNCTION_API_NAME_ATTRIBUTE = 'api_implements'
47_FUNCTION_DEVICE_ATTRIBUTE = 'api_preferred_device'
48_CPU_DEVICE_NAME = 'CPU'
49_GPU_DEVICE_NAME = 'GPU'
50
51# The following number constants are used to represent the runtime of the defun
52# backend function. Since the CPU/GPU implementation are mathematically same, we
53# need some signal for the function to indicate which function is executed. This
54# is for testing purpose to verify the correctness of swapping backend function.
55_RUNTIME_UNKNOWN = 0
56_RUNTIME_CPU = 1
57_RUNTIME_GPU = 2
58
59_CUDNN_AVAILABLE_MSG = 'Layer %s will use cuDNN kernels when running on GPU.'
60_CUDNN_NOT_AVAILABLE_MSG = ('Layer %s will not use cuDNN kernels since it '
61                            'doesn\'t meet the criteria. It will '
62                            'use a generic GPU kernel as fallback when running '
63                            'on GPU.')
64
65
66def _use_new_code():
67  return False
68
69
70# TODO(b/169707691): The wrapper can be removed if TFLite doesn't need to rely
71# on supportive attributes from LSTM/GRU.
72class _DefunWrapper(object):
73  """A wrapper with no deep copy of the Defun in LSTM/GRU layer."""
74
75  def __init__(self, time_major, go_backwards, layer_name):
76    self.time_major = time_major
77    self.go_backwards = go_backwards
78    self.layer_name = layer_name
79    if self.layer_name not in ['lstm', 'gru']:
80      raise ValueError('Defun wrapper only applies to LSTM and GRU layer, '
81                       'but given {}'.format(self.layer_name))
82    # The first two attributes are added to support TFLite use case.
83    supportive_attributes = {
84        'time_major': self.time_major,
85        'go_backwards': self.go_backwards,
86        _FUNCTION_API_NAME_ATTRIBUTE: self.layer_name + '_' + str(uuid.uuid4())
87    }
88    if self.layer_name == 'lstm':
89      layer_func = lstm_with_backend_selection
90    else:
91      layer_func = gru_with_backend_selection
92
93    self.defun_layer = function.defun_with_attributes(
94        layer_func,
95        attributes=supportive_attributes,
96        autograph=False)
97
98  def __deepcopy__(self, memo):
99    new_wrapper = type(self)(
100        self.time_major, self.go_backwards, self.layer_name)
101    memo[id(self)] = new_wrapper
102    return new_wrapper
103
104
105@keras_export('keras.layers.GRUCell', v1=[])
106class GRUCell(recurrent.GRUCell):
107  """Cell class for the GRU layer.
108
109  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
110  for details about the usage of RNN API.
111
112  This class processes one step within the whole time sequence input, whereas
113  `tf.keras.layer.GRU` processes the whole sequence.
114
115  For example:
116
117  >>> inputs = tf.random.normal([32, 10, 8])
118  >>> rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4))
119  >>> output = rnn(inputs)
120  >>> print(output.shape)
121  (32, 4)
122  >>> rnn = tf.keras.layers.RNN(
123  ...    tf.keras.layers.GRUCell(4),
124  ...    return_sequences=True,
125  ...    return_state=True)
126  >>> whole_sequence_output, final_state = rnn(inputs)
127  >>> print(whole_sequence_output.shape)
128  (32, 10, 4)
129  >>> print(final_state.shape)
130  (32, 4)
131
132  Args:
133    units: Positive integer, dimensionality of the output space.
134    activation: Activation function to use. Default: hyperbolic tangent
135      (`tanh`). If you pass None, no activation is applied
136      (ie. "linear" activation: `a(x) = x`).
137    recurrent_activation: Activation function to use for the recurrent step.
138      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
139      applied (ie. "linear" activation: `a(x) = x`).
140    use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
141    kernel_initializer: Initializer for the `kernel` weights matrix,
142      used for the linear transformation of the inputs. Default:
143      `glorot_uniform`.
144    recurrent_initializer: Initializer for the `recurrent_kernel`
145      weights matrix, used for the linear transformation of the recurrent state.
146      Default: `orthogonal`.
147    bias_initializer: Initializer for the bias vector. Default: `zeros`.
148    kernel_regularizer: Regularizer function applied to the `kernel` weights
149      matrix. Default: `None`.
150    recurrent_regularizer: Regularizer function applied to the
151      `recurrent_kernel` weights matrix. Default: `None`.
152    bias_regularizer: Regularizer function applied to the bias vector. Default:
153      `None`.
154    kernel_constraint: Constraint function applied to the `kernel` weights
155      matrix. Default: `None`.
156    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
157      weights matrix. Default: `None`.
158    bias_constraint: Constraint function applied to the bias vector. Default:
159      `None`.
160    dropout: Float between 0 and 1. Fraction of the units to drop for the
161      linear transformation of the inputs. Default: 0.
162    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
163      the linear transformation of the recurrent state. Default: 0.
164    reset_after: GRU convention (whether to apply reset gate after or
165      before matrix multiplication). False = "before",
166      True = "after" (default and CuDNN compatible).
167
168  Call arguments:
169    inputs: A 2D tensor, with shape of `[batch, feature]`.
170    states: A 2D tensor with shape of `[batch, units]`, which is the state from
171      the previous time step. For timestep 0, the initial state provided by user
172      will be feed to cell.
173    training: Python boolean indicating whether the layer should behave in
174      training mode or in inference mode. Only relevant when `dropout` or
175      `recurrent_dropout` is used.
176  """
177
178  def __init__(self,
179               units,
180               activation='tanh',
181               recurrent_activation='sigmoid',
182               use_bias=True,
183               kernel_initializer='glorot_uniform',
184               recurrent_initializer='orthogonal',
185               bias_initializer='zeros',
186               kernel_regularizer=None,
187               recurrent_regularizer=None,
188               bias_regularizer=None,
189               kernel_constraint=None,
190               recurrent_constraint=None,
191               bias_constraint=None,
192               dropout=0.,
193               recurrent_dropout=0.,
194               reset_after=True,
195               **kwargs):
196    super(GRUCell, self).__init__(
197        units,
198        activation=activation,
199        recurrent_activation=recurrent_activation,
200        use_bias=use_bias,
201        kernel_initializer=kernel_initializer,
202        recurrent_initializer=recurrent_initializer,
203        bias_initializer=bias_initializer,
204        kernel_regularizer=kernel_regularizer,
205        recurrent_regularizer=recurrent_regularizer,
206        bias_regularizer=bias_regularizer,
207        kernel_constraint=kernel_constraint,
208        recurrent_constraint=recurrent_constraint,
209        bias_constraint=bias_constraint,
210        dropout=dropout,
211        recurrent_dropout=recurrent_dropout,
212        implementation=kwargs.pop('implementation', 2),
213        reset_after=reset_after,
214        **kwargs)
215
216
217@keras_export('keras.layers.GRU', v1=[])
218class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
219  """Gated Recurrent Unit - Cho et al. 2014.
220
221  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
222  for details about the usage of RNN API.
223
224  Based on available runtime hardware and constraints, this layer
225  will choose different implementations (cuDNN-based or pure-TensorFlow)
226  to maximize the performance. If a GPU is available and all
227  the arguments to the layer meet the requirement of the CuDNN kernel
228  (see below for details), the layer will use a fast cuDNN implementation.
229
230  The requirements to use the cuDNN implementation are:
231
232  1. `activation` == `tanh`
233  2. `recurrent_activation` == `sigmoid`
234  3. `recurrent_dropout` == 0
235  4. `unroll` is `False`
236  5. `use_bias` is `True`
237  6. `reset_after` is `True`
238  7. Inputs, if use masking, are strictly right-padded.
239  8. Eager execution is enabled in the outermost context.
240
241  There are two variants of the GRU implementation. The default one is based on
242  [v3](https://arxiv.org/abs/1406.1078v3) and has reset gate applied to hidden
243  state before matrix multiplication. The other one is based on
244  [original](https://arxiv.org/abs/1406.1078v1) and has the order reversed.
245
246  The second variant is compatible with CuDNNGRU (GPU-only) and allows
247  inference on CPU. Thus it has separate biases for `kernel` and
248  `recurrent_kernel`. To use this variant, set `'reset_after'=True` and
249  `recurrent_activation='sigmoid'`.
250
251  For example:
252
253  >>> inputs = tf.random.normal([32, 10, 8])
254  >>> gru = tf.keras.layers.GRU(4)
255  >>> output = gru(inputs)
256  >>> print(output.shape)
257  (32, 4)
258  >>> gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True)
259  >>> whole_sequence_output, final_state = gru(inputs)
260  >>> print(whole_sequence_output.shape)
261  (32, 10, 4)
262  >>> print(final_state.shape)
263  (32, 4)
264
265  Args:
266    units: Positive integer, dimensionality of the output space.
267    activation: Activation function to use.
268      Default: hyperbolic tangent (`tanh`).
269      If you pass `None`, no activation is applied
270      (ie. "linear" activation: `a(x) = x`).
271    recurrent_activation: Activation function to use
272      for the recurrent step.
273      Default: sigmoid (`sigmoid`).
274      If you pass `None`, no activation is applied
275      (ie. "linear" activation: `a(x) = x`).
276    use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
277    kernel_initializer: Initializer for the `kernel` weights matrix,
278      used for the linear transformation of the inputs. Default:
279      `glorot_uniform`.
280    recurrent_initializer: Initializer for the `recurrent_kernel`
281       weights matrix, used for the linear transformation of the recurrent
282       state. Default: `orthogonal`.
283    bias_initializer: Initializer for the bias vector. Default: `zeros`.
284    kernel_regularizer: Regularizer function applied to the `kernel` weights
285      matrix. Default: `None`.
286    recurrent_regularizer: Regularizer function applied to the
287      `recurrent_kernel` weights matrix. Default: `None`.
288    bias_regularizer: Regularizer function applied to the bias vector. Default:
289      `None`.
290    activity_regularizer: Regularizer function applied to the output of the
291      layer (its "activation"). Default: `None`.
292    kernel_constraint: Constraint function applied to the `kernel` weights
293      matrix. Default: `None`.
294    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
295      weights matrix. Default: `None`.
296    bias_constraint: Constraint function applied to the bias vector. Default:
297      `None`.
298    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
299      transformation of the inputs. Default: 0.
300    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
301      the linear transformation of the recurrent state. Default: 0.
302    return_sequences: Boolean. Whether to return the last output
303      in the output sequence, or the full sequence. Default: `False`.
304    return_state: Boolean. Whether to return the last state in addition to the
305      output. Default: `False`.
306    go_backwards: Boolean (default `False`).
307      If True, process the input sequence backwards and return the
308      reversed sequence.
309    stateful: Boolean (default False). If True, the last state
310      for each sample at index i in a batch will be used as initial
311      state for the sample of index i in the following batch.
312    unroll: Boolean (default False).
313      If True, the network will be unrolled,
314      else a symbolic loop will be used.
315      Unrolling can speed-up a RNN,
316      although it tends to be more memory-intensive.
317      Unrolling is only suitable for short sequences.
318    time_major: The shape format of the `inputs` and `outputs` tensors.
319      If True, the inputs and outputs will be in shape
320      `[timesteps, batch, feature]`, whereas in the False case, it will be
321      `[batch, timesteps, feature]`. Using `time_major = True` is a bit more
322      efficient because it avoids transposes at the beginning and end of the
323      RNN calculation. However, most TensorFlow data is batch-major, so by
324      default this function accepts input and emits output in batch-major
325      form.
326    reset_after: GRU convention (whether to apply reset gate after or
327      before matrix multiplication). False = "before",
328      True = "after" (default and CuDNN compatible).
329
330  Call arguments:
331    inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
332    mask: Binary tensor of shape `[samples, timesteps]` indicating whether
333      a given timestep should be masked  (optional, defaults to `None`).
334      An individual `True` entry indicates that the corresponding timestep
335      should be utilized, while a `False` entry indicates that the
336      corresponding timestep should be ignored.
337    training: Python boolean indicating whether the layer should behave in
338      training mode or in inference mode. This argument is passed to the cell
339      when calling it. This is only relevant if `dropout` or
340      `recurrent_dropout` is used  (optional, defaults to `None`).
341    initial_state: List of initial state tensors to be passed to the first
342      call of the cell  (optional, defaults to `None` which causes creation
343      of zero-filled initial state tensors).
344  """
345
346  def __init__(self,
347               units,
348               activation='tanh',
349               recurrent_activation='sigmoid',
350               use_bias=True,
351               kernel_initializer='glorot_uniform',
352               recurrent_initializer='orthogonal',
353               bias_initializer='zeros',
354               kernel_regularizer=None,
355               recurrent_regularizer=None,
356               bias_regularizer=None,
357               activity_regularizer=None,
358               kernel_constraint=None,
359               recurrent_constraint=None,
360               bias_constraint=None,
361               dropout=0.,
362               recurrent_dropout=0.,
363               return_sequences=False,
364               return_state=False,
365               go_backwards=False,
366               stateful=False,
367               unroll=False,
368               time_major=False,
369               reset_after=True,
370               **kwargs):
371    # return_runtime is a flag for testing, which shows the real backend
372    # implementation chosen by grappler in graph mode.
373    self._return_runtime = kwargs.pop('return_runtime', False)
374
375    super(GRU, self).__init__(
376        units,
377        activation=activation,
378        recurrent_activation=recurrent_activation,
379        use_bias=use_bias,
380        kernel_initializer=kernel_initializer,
381        recurrent_initializer=recurrent_initializer,
382        bias_initializer=bias_initializer,
383        kernel_regularizer=kernel_regularizer,
384        recurrent_regularizer=recurrent_regularizer,
385        bias_regularizer=bias_regularizer,
386        activity_regularizer=activity_regularizer,
387        kernel_constraint=kernel_constraint,
388        recurrent_constraint=recurrent_constraint,
389        bias_constraint=bias_constraint,
390        dropout=dropout,
391        recurrent_dropout=recurrent_dropout,
392        implementation=kwargs.pop('implementation', 2),
393        return_sequences=return_sequences,
394        return_state=return_state,
395        go_backwards=go_backwards,
396        stateful=stateful,
397        unroll=unroll,
398        time_major=time_major,
399        reset_after=reset_after,
400        **kwargs)
401    # GPU kernel uses following setting by default and not configurable.
402    self._could_use_gpu_kernel = (
403        self.activation in (activations.tanh, nn.tanh) and
404        self.recurrent_activation in (activations.sigmoid, nn.sigmoid) and
405        recurrent_dropout == 0 and not unroll and use_bias and
406        reset_after and ops.executing_eagerly_outside_functions())
407    if config.list_logical_devices('GPU'):
408      # Only show the message when there is GPU available, user will not care
409      # about the cuDNN if there isn't any GPU.
410      if self._could_use_gpu_kernel:
411        logging.debug(_CUDNN_AVAILABLE_MSG % self.name)
412      else:
413        logging.warning(_CUDNN_NOT_AVAILABLE_MSG % self.name)
414
415    if _use_new_code():
416      self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'gru')
417
418  def call(self, inputs, mask=None, training=None, initial_state=None):
419    # The input should be dense, padded with zeros. If a ragged input is fed
420    # into the layer, it is padded and the row lengths are used for masking.
421    inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
422    is_ragged_input = (row_lengths is not None)
423    self._validate_args_if_ragged(is_ragged_input, mask)
424
425    # GRU does not support constants. Ignore it during process.
426    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)
427
428    if isinstance(mask, list):
429      mask = mask[0]
430
431    input_shape = backend.int_shape(inputs)
432    timesteps = input_shape[0] if self.time_major else input_shape[1]
433
434    # TODO(b/156447398) Investigate why the cuDNN kernel fails with ragged
435    # inputs.
436    if is_ragged_input or not self._could_use_gpu_kernel:
437      kwargs = {'training': training}
438      self._maybe_reset_cell_dropout_mask(self.cell)
439
440      def step(cell_inputs, cell_states):
441        return self.cell(cell_inputs, cell_states, **kwargs)
442
443      last_output, outputs, states = backend.rnn(
444          step,
445          inputs,
446          initial_state,
447          constants=None,
448          go_backwards=self.go_backwards,
449          mask=mask,
450          unroll=self.unroll,
451          input_length=row_lengths if row_lengths is not None else timesteps,
452          time_major=self.time_major,
453          zero_output_for_mask=self.zero_output_for_mask)
454      # This is a dummy tensor for testing purpose.
455      runtime = _runtime(_RUNTIME_UNKNOWN)
456    else:
457      last_output, outputs, runtime, states = self._defun_gru_call(
458          inputs, initial_state, training, mask, row_lengths)
459
460    if self.stateful:
461      updates = [state_ops.assign(self.states[0], states[0])]
462      self.add_update(updates)
463
464    if self.return_sequences:
465      output = backend.maybe_convert_to_ragged(
466          is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards)
467    else:
468      output = last_output
469
470    if self.return_state:
471      return [output] + list(states)
472    elif self._return_runtime:
473      return output, runtime
474    else:
475      return output
476
477  def _defun_gru_call(self, inputs, initial_state, training, mask,
478                      sequence_lengths):
479    # Use the new defun approach for backend implementation swap.
480    # Note that different implementations need to have same function
481    # signature, eg, the tensor parameters need to have same shape and dtypes.
482
483    self.reset_dropout_mask()
484    dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
485    if dropout_mask is not None:
486      inputs = inputs * dropout_mask[0]
487
488    if _use_new_code():
489      gru_kwargs = {
490          'inputs': inputs,
491          'init_h': _read_variable_value(initial_state[0]),
492          'kernel': _read_variable_value(self.cell.kernel),
493          'recurrent_kernel': _read_variable_value(self.cell.recurrent_kernel),
494          'bias': _read_variable_value(self.cell.bias),
495          'mask': mask,
496          'time_major': self.time_major,
497          'go_backwards': self.go_backwards,
498          'sequence_lengths': sequence_lengths,
499          'zero_output_for_mask': self.zero_output_for_mask
500      }
501      (last_output, outputs, new_h,
502       runtime) = self._defun_wrapper.defun_layer(**gru_kwargs)
503    else:
504      gpu_gru_kwargs = {
505          'inputs': inputs,
506          'init_h': _read_variable_value(initial_state[0]),
507          'kernel': _read_variable_value(self.cell.kernel),
508          'recurrent_kernel': _read_variable_value(self.cell.recurrent_kernel),
509          'bias': _read_variable_value(self.cell.bias),
510          'mask': mask,
511          'time_major': self.time_major,
512          'go_backwards': self.go_backwards,
513          'sequence_lengths': sequence_lengths
514      }
515      normal_gru_kwargs = gpu_gru_kwargs.copy()
516      normal_gru_kwargs.update({
517          'zero_output_for_mask': self.zero_output_for_mask,
518      })
519
520      if context.executing_eagerly():
521        device_type = _get_context_device_type()
522        can_use_gpu = (
523            # Either user specified GPU or unspecified but GPU is available.
524            (device_type == _GPU_DEVICE_NAME or
525             (device_type is None and config.list_logical_devices('GPU'))) and
526            (mask is None or is_cudnn_supported_inputs(mask, self.time_major)))
527        # Under eager context, check the device placement and prefer the
528        if can_use_gpu:
529          last_output, outputs, new_h, runtime = gpu_gru(**gpu_gru_kwargs)
530        else:
531          last_output, outputs, new_h, runtime = standard_gru(
532              **normal_gru_kwargs)
533      else:
534        last_output, outputs, new_h, runtime = gru_with_backend_selection(
535            **normal_gru_kwargs)
536
537    states = [new_h]
538    return last_output, outputs, runtime, states
539
540
541def standard_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask,
542                 time_major, go_backwards, sequence_lengths,
543                 zero_output_for_mask):
544  """GRU with standard kernel implementation.
545
546  This implementation can be run on all types of hardware.
547
548  This implementation lifts out all the layer weights and make them function
549  parameters. It has same number of tensor input params as the CuDNN
550  counterpart. The RNN step logic has been simplified, eg dropout and mask is
551  removed since CuDNN implementation does not support that.
552
553  Args:
554    inputs: Input tensor of GRU layer.
555    init_h: Initial state tensor for the cell output.
556    kernel: Weights for cell kernel.
557    recurrent_kernel: Weights for cell recurrent kernel.
558    bias: Weights for cell kernel bias and recurrent bias. The bias contains the
559      combined input_bias and recurrent_bias.
560    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
561      a given timestep should be masked. An individual `True` entry indicates
562      that the corresponding timestep should be utilized, while a `False` entry
563      indicates that the corresponding timestep should be ignored.
564    time_major: Boolean, whether the inputs are in the format of
565      [time, batch, feature] or [batch, time, feature].
566    go_backwards: Boolean (default False). If True, process the input sequence
567      backwards and return the reversed sequence.
568    sequence_lengths: The lengths of all sequences coming from a variable length
569      input, such as ragged tensors. If the input has a fixed timestep size,
570      this should be None.
571    zero_output_for_mask: Boolean, whether to output zero for masked timestep.
572
573  Returns:
574    last_output: output tensor for the last timestep, which has shape
575      [batch, units].
576    outputs: output tensor for all timesteps, which has shape
577      [batch, time, units].
578    state_0: the cell output, which has same shape as init_h.
579    runtime: constant string tensor which indicate real runtime hardware. This
580      value is for testing purpose and should be used by user.
581  """
582  input_shape = backend.int_shape(inputs)
583  timesteps = input_shape[0] if time_major else input_shape[1]
584
585  input_bias, recurrent_bias = array_ops.unstack(bias)
586
587  def step(cell_inputs, cell_states):
588    """Step function that will be used by Keras RNN backend."""
589    h_tm1 = cell_states[0]
590
591    # inputs projected by all gate matrices at once
592    matrix_x = backend.dot(cell_inputs, kernel)
593    matrix_x = backend.bias_add(matrix_x, input_bias)
594
595    x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=1)
596
597    # hidden state projected by all gate matrices at once
598    matrix_inner = backend.dot(h_tm1, recurrent_kernel)
599    matrix_inner = backend.bias_add(matrix_inner, recurrent_bias)
600
601    recurrent_z, recurrent_r, recurrent_h = array_ops.split(matrix_inner, 3,
602                                                            axis=1)
603    z = nn.sigmoid(x_z + recurrent_z)
604    r = nn.sigmoid(x_r + recurrent_r)
605    hh = nn.tanh(x_h + r * recurrent_h)
606
607    # previous and candidate state mixed by update gate
608    h = z * h_tm1 + (1 - z) * hh
609    return h, [h]
610
611  last_output, outputs, new_states = backend.rnn(
612      step,
613      inputs, [init_h],
614      constants=None,
615      unroll=False,
616      time_major=time_major,
617      mask=mask,
618      go_backwards=go_backwards,
619      input_length=sequence_lengths
620      if sequence_lengths is not None else timesteps,
621      zero_output_for_mask=zero_output_for_mask)
622  return last_output, outputs, new_states[0], _runtime(_RUNTIME_CPU)
623
624
625def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
626            go_backwards, sequence_lengths):
627  """GRU with CuDNN implementation which is only available for GPU."""
628  if not time_major and mask is None:
629    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
630    seq_axis, batch_axis = (0, 1)
631  else:
632    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
633  # For init_h, cuDNN expects one more dim of num_layers before or after batch
634  # dim for time major or batch major inputs respectively
635  init_h = array_ops.expand_dims(init_h, axis=seq_axis)
636
637  weights = array_ops.split(kernel, 3, axis=1)
638  weights += array_ops.split(recurrent_kernel, 3, axis=1)
639  # Note that the bias was initialized as shape (2, 3 * units), flat it into
640  # (6 * units)
641  bias = array_ops.split(backend.flatten(bias), 6)
642
643  if sysconfig.get_build_info()['is_cuda_build']:
644    # Note that the gate order for CuDNN is different from the canonical format.
645    # canonical format is [z, r, h], whereas CuDNN is [r, z, h]. The swap need
646    # to be done for kernel, recurrent_kernel, input_bias, recurrent_bias.
647    # z is update gate weights.
648    # r is reset gate weights.
649    # h is output gate weights.
650    weights[0], weights[1] = weights[1], weights[0]
651    weights[3], weights[4] = weights[4], weights[3]
652    bias[0], bias[1] = bias[1], bias[0]
653    bias[3], bias[4] = bias[4], bias[3]
654
655  params = _canonical_to_params(
656      weights=weights,
657      biases=bias,
658      shape=constant_op.constant([-1]),
659      transpose_weights=True)
660
661  if mask is not None:
662    sequence_lengths = calculate_sequence_by_mask(mask, time_major)
663
664  if sequence_lengths is not None:
665    if go_backwards:
666      # Three reversals are required. E.g.,
667      # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked
668      # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
669      # output_from_cudnn = [6, 5, 4, 0, 0]
670      # expected_output = [0, 0, 6, 5 ,4]
671      inputs = array_ops.reverse_sequence_v2(
672          inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
673    outputs, h, _, _, _ = gen_cudnn_rnn_ops.CudnnRNNV3(
674        input=inputs,
675        input_h=init_h,
676        input_c=0,
677        params=params,
678        is_training=True,
679        rnn_mode='gru',
680        sequence_lengths=sequence_lengths,
681        time_major=time_major)
682    if go_backwards:
683      outputs = array_ops.reverse_sequence_v2(
684          outputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
685      outputs = array_ops.reverse(outputs, axis=[seq_axis])
686  else:
687    if go_backwards:
688      # Reverse axis 0 since the input is already convert to time major.
689      inputs = array_ops.reverse(inputs, axis=[0])
690    outputs, h, _, _ = gen_cudnn_rnn_ops.CudnnRNN(
691        input=inputs, input_h=init_h, input_c=0, params=params,
692        is_training=True, rnn_mode='gru')
693
694  last_output = outputs[-1]
695  if not time_major and mask is None:
696    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
697  h = array_ops.squeeze(h, axis=seq_axis)
698
699  # In the case of variable length input, the cudnn kernel will fill zeros for
700  # the output, whereas the default keras behavior is to bring over the previous
701  # output for t-1, so that in the return_sequence=False case, user can quickly
702  # get the final effect output instead just 0s at the last timestep.
703  # In order to mimic the default keras behavior, we copy the final h state as
704  # the last_output, since it is numerically same as the output.
705  if mask is not None:
706    last_output = h
707
708  return last_output, outputs, h, _runtime(_RUNTIME_GPU)
709
710
711def gru_with_backend_selection(inputs, init_h, kernel, recurrent_kernel, bias,
712                               mask, time_major, go_backwards, sequence_lengths,
713                               zero_output_for_mask):
714  """Call the GRU with optimized backend kernel selection.
715
716  Under the hood, this function will create two TF function, one with the most
717  generic kernel and can run on all device condition, and the second one with
718  CuDNN specific kernel, which can only run on GPU.
719
720  The first function will be called with normal_lstm_params, while the second
721  function is not called, but only registered in the graph. The Grappler will
722  do the proper graph rewrite and swap the optimized TF function based on the
723  device placement.
724
725  Args:
726    inputs: Input tensor of GRU layer.
727    init_h: Initial state tensor for the cell output.
728    kernel: Weights for cell kernel.
729    recurrent_kernel: Weights for cell recurrent kernel.
730    bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
731      is used in this case.
732    mask: Boolean tensor for mask out the steps within sequence.
733      An individual `True` entry indicates that the corresponding timestep
734      should be utilized, while a `False` entry indicates that the corresponding
735      timestep should be ignored.
736    time_major: Boolean, whether the inputs are in the format of
737      [time, batch, feature] or [batch, time, feature].
738    go_backwards: Boolean (default False). If True, process the input sequence
739      backwards and return the reversed sequence.
740    sequence_lengths: The lengths of all sequences coming from a variable length
741      input, such as ragged tensors. If the input has a fixed timestep size,
742      this should be None.
743    zero_output_for_mask: Boolean, whether to output zero for masked timestep.
744
745  Returns:
746    List of output tensors, same as standard_gru.
747  """
748  params = {
749      'inputs': inputs,
750      'init_h': init_h,
751      'kernel': kernel,
752      'recurrent_kernel': recurrent_kernel,
753      'bias': bias,
754      'mask': mask,
755      'time_major': time_major,
756      'go_backwards': go_backwards,
757      'sequence_lengths': sequence_lengths,
758      'zero_output_for_mask': zero_output_for_mask,
759  }
760
761  def gpu_gru_with_fallback(inputs, init_h, kernel, recurrent_kernel, bias,
762                            mask, time_major, go_backwards, sequence_lengths,
763                            zero_output_for_mask):
764    """Use CuDNN kernel when mask is none or strictly right padded."""
765    if mask is None:
766      return gpu_gru(
767          inputs=inputs,
768          init_h=init_h,
769          kernel=kernel,
770          recurrent_kernel=recurrent_kernel,
771          bias=bias,
772          mask=mask,
773          time_major=time_major,
774          go_backwards=go_backwards,
775          sequence_lengths=sequence_lengths)
776
777    def cudnn_gru_fn():
778      return gpu_gru(
779          inputs=inputs,
780          init_h=init_h,
781          kernel=kernel,
782          recurrent_kernel=recurrent_kernel,
783          bias=bias,
784          mask=mask,
785          time_major=time_major,
786          go_backwards=go_backwards,
787          sequence_lengths=sequence_lengths)
788
789    def standard_gru_fn():
790      return standard_gru(
791          inputs=inputs,
792          init_h=init_h,
793          kernel=kernel,
794          recurrent_kernel=recurrent_kernel,
795          bias=bias,
796          mask=mask,
797          time_major=time_major,
798          go_backwards=go_backwards,
799          sequence_lengths=sequence_lengths,
800          zero_output_for_mask=zero_output_for_mask)
801
802    return control_flow_ops.cond(
803        is_cudnn_supported_inputs(mask, time_major),
804        true_fn=cudnn_gru_fn,
805        false_fn=standard_gru_fn)
806
807  if _use_new_code():
808    # Chooses the implementation dynamically based on the running device.
809    (last_output, outputs, new_h,
810     runtime) = control_flow_ops.execute_fn_for_device(
811         {
812             _CPU_DEVICE_NAME: lambda: standard_gru(**params),
813             _GPU_DEVICE_NAME: lambda: gpu_gru_with_fallback(**params)
814         }, lambda: standard_gru(**params))
815  else:
816    # Each time a `tf.function` is called, we will give it a unique
817    # identifiable API name, so that Grappler won't get confused when it
818    # sees multiple GRU layers added into same graph, and it will be able
819    # to pair up the different implementations across them.
820    api_name = 'gru_' + str(uuid.uuid4())
821    supportive_attribute = {
822        'time_major': time_major,
823        'go_backwards': go_backwards,
824    }
825    defun_standard_gru = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
826                                                 standard_gru,
827                                                 supportive_attribute)
828    defun_gpu_gru = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
829                                            gpu_gru_with_fallback,
830                                            supportive_attribute)
831
832    # Call the normal GRU impl and register the CuDNN impl function. The
833    # grappler will kick in during session execution to optimize the graph.
834    last_output, outputs, new_h, runtime = defun_standard_gru(**params)
835    _function_register(defun_gpu_gru, **params)
836
837  return last_output, outputs, new_h, runtime
838
839
840@keras_export('keras.layers.LSTMCell', v1=[])
841class LSTMCell(recurrent.LSTMCell):
842  """Cell class for the LSTM layer.
843
844  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
845  for details about the usage of RNN API.
846
847  This class processes one step within the whole time sequence input, whereas
848  `tf.keras.layer.LSTM` processes the whole sequence.
849
850  For example:
851
852  >>> inputs = tf.random.normal([32, 10, 8])
853  >>> rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4))
854  >>> output = rnn(inputs)
855  >>> print(output.shape)
856  (32, 4)
857  >>> rnn = tf.keras.layers.RNN(
858  ...    tf.keras.layers.LSTMCell(4),
859  ...    return_sequences=True,
860  ...    return_state=True)
861  >>> whole_seq_output, final_memory_state, final_carry_state = rnn(inputs)
862  >>> print(whole_seq_output.shape)
863  (32, 10, 4)
864  >>> print(final_memory_state.shape)
865  (32, 4)
866  >>> print(final_carry_state.shape)
867  (32, 4)
868
869  Args:
870    units: Positive integer, dimensionality of the output space.
871    activation: Activation function to use. Default: hyperbolic tangent
872      (`tanh`). If you pass `None`, no activation is applied (ie. "linear"
873      activation: `a(x) = x`).
874    recurrent_activation: Activation function to use for the recurrent step.
875      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is applied
876      (ie. "linear" activation: `a(x) = x`).
877    use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
878    kernel_initializer: Initializer for the `kernel` weights matrix, used for
879      the linear transformation of the inputs. Default: `glorot_uniform`.
880    recurrent_initializer: Initializer for the `recurrent_kernel` weights
881      matrix, used for the linear transformation of the recurrent state.
882      Default: `orthogonal`.
883    bias_initializer: Initializer for the bias vector. Default: `zeros`.
884    unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of
885      the forget gate at initialization. Setting it to true will also force
886      `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
887        al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
888    kernel_regularizer: Regularizer function applied to the `kernel` weights
889      matrix. Default: `None`.
890    recurrent_regularizer: Regularizer function applied to
891      the `recurrent_kernel` weights matrix. Default: `None`.
892    bias_regularizer: Regularizer function applied to the bias vector. Default:
893      `None`.
894    kernel_constraint: Constraint function applied to the `kernel` weights
895      matrix. Default: `None`.
896    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
897      weights matrix. Default: `None`.
898    bias_constraint: Constraint function applied to the bias vector. Default:
899      `None`.
900    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
901      transformation of the inputs. Default: 0.
902    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
903      the linear transformation of the recurrent state. Default: 0.
904
905  Call arguments:
906    inputs: A 2D tensor, with shape of `[batch, feature]`.
907    states: List of 2 tensors that corresponding to the cell's units. Both of
908      them have shape `[batch, units]`, the first tensor is the memory state
909      from previous time step, the second tensor is the carry state from
910      previous time step. For timestep 0, the initial state provided by user
911      will be feed to cell.
912    training: Python boolean indicating whether the layer should behave in
913      training mode or in inference mode. Only relevant when `dropout` or
914      `recurrent_dropout` is used.
915  """
916
917  def __init__(self,
918               units,
919               activation='tanh',
920               recurrent_activation='sigmoid',
921               use_bias=True,
922               kernel_initializer='glorot_uniform',
923               recurrent_initializer='orthogonal',
924               bias_initializer='zeros',
925               unit_forget_bias=True,
926               kernel_regularizer=None,
927               recurrent_regularizer=None,
928               bias_regularizer=None,
929               kernel_constraint=None,
930               recurrent_constraint=None,
931               bias_constraint=None,
932               dropout=0.,
933               recurrent_dropout=0.,
934               **kwargs):
935    super(LSTMCell, self).__init__(
936        units,
937        activation=activation,
938        recurrent_activation=recurrent_activation,
939        use_bias=use_bias,
940        kernel_initializer=kernel_initializer,
941        recurrent_initializer=recurrent_initializer,
942        bias_initializer=bias_initializer,
943        unit_forget_bias=unit_forget_bias,
944        kernel_regularizer=kernel_regularizer,
945        recurrent_regularizer=recurrent_regularizer,
946        bias_regularizer=bias_regularizer,
947        kernel_constraint=kernel_constraint,
948        recurrent_constraint=recurrent_constraint,
949        bias_constraint=bias_constraint,
950        dropout=dropout,
951        recurrent_dropout=recurrent_dropout,
952        implementation=kwargs.pop('implementation', 2),
953        **kwargs)
954
955
956@keras_export('keras.layers.LSTM', v1=[])
957class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
958  """Long Short-Term Memory layer - Hochreiter 1997.
959
960  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
961  for details about the usage of RNN API.
962
963  Based on available runtime hardware and constraints, this layer
964  will choose different implementations (cuDNN-based or pure-TensorFlow)
965  to maximize the performance. If a GPU is available and all
966  the arguments to the layer meet the requirement of the CuDNN kernel
967  (see below for details), the layer will use a fast cuDNN implementation.
968
969  The requirements to use the cuDNN implementation are:
970
971  1. `activation` == `tanh`
972  2. `recurrent_activation` == `sigmoid`
973  3. `recurrent_dropout` == 0
974  4. `unroll` is `False`
975  5. `use_bias` is `True`
976  6. Inputs, if use masking, are strictly right-padded.
977  7. Eager execution is enabled in the outermost context.
978
979  For example:
980
981  >>> inputs = tf.random.normal([32, 10, 8])
982  >>> lstm = tf.keras.layers.LSTM(4)
983  >>> output = lstm(inputs)
984  >>> print(output.shape)
985  (32, 4)
986  >>> lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
987  >>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
988  >>> print(whole_seq_output.shape)
989  (32, 10, 4)
990  >>> print(final_memory_state.shape)
991  (32, 4)
992  >>> print(final_carry_state.shape)
993  (32, 4)
994
995  Args:
996    units: Positive integer, dimensionality of the output space.
997    activation: Activation function to use.
998      Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation
999      is applied (ie. "linear" activation: `a(x) = x`).
1000    recurrent_activation: Activation function to use for the recurrent step.
1001      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
1002      applied (ie. "linear" activation: `a(x) = x`).
1003    use_bias: Boolean (default `True`), whether the layer uses a bias vector.
1004    kernel_initializer: Initializer for the `kernel` weights matrix, used for
1005      the linear transformation of the inputs. Default: `glorot_uniform`.
1006    recurrent_initializer: Initializer for the `recurrent_kernel` weights
1007      matrix, used for the linear transformation of the recurrent state.
1008      Default: `orthogonal`.
1009    bias_initializer: Initializer for the bias vector. Default: `zeros`.
1010    unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of
1011      the forget gate at initialization. Setting it to true will also force
1012      `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
1013          al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
1014    kernel_regularizer: Regularizer function applied to the `kernel` weights
1015      matrix. Default: `None`.
1016    recurrent_regularizer: Regularizer function applied to the
1017      `recurrent_kernel` weights matrix. Default: `None`.
1018    bias_regularizer: Regularizer function applied to the bias vector. Default:
1019      `None`.
1020    activity_regularizer: Regularizer function applied to the output of the
1021      layer (its "activation"). Default: `None`.
1022    kernel_constraint: Constraint function applied to the `kernel` weights
1023      matrix. Default: `None`.
1024    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
1025      weights matrix. Default: `None`.
1026    bias_constraint: Constraint function applied to the bias vector. Default:
1027      `None`.
1028    dropout: Float between 0 and 1. Fraction of the units to drop for the linear
1029      transformation of the inputs. Default: 0.
1030    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
1031      the linear transformation of the recurrent state. Default: 0.
1032    return_sequences: Boolean. Whether to return the last output. in the output
1033      sequence, or the full sequence. Default: `False`.
1034    return_state: Boolean. Whether to return the last state in addition to the
1035      output. Default: `False`.
1036    go_backwards: Boolean (default `False`). If True, process the input sequence
1037      backwards and return the reversed sequence.
1038    stateful: Boolean (default `False`). If True, the last state for each sample
1039      at index i in a batch will be used as initial state for the sample of
1040      index i in the following batch.
1041    time_major: The shape format of the `inputs` and `outputs` tensors.
1042      If True, the inputs and outputs will be in shape
1043      `[timesteps, batch, feature]`, whereas in the False case, it will be
1044      `[batch, timesteps, feature]`. Using `time_major = True` is a bit more
1045      efficient because it avoids transposes at the beginning and end of the
1046      RNN calculation. However, most TensorFlow data is batch-major, so by
1047      default this function accepts input and emits output in batch-major
1048      form.
1049    unroll: Boolean (default `False`). If True, the network will be unrolled,
1050      else a symbolic loop will be used. Unrolling can speed-up a RNN, although
1051      it tends to be more memory-intensive. Unrolling is only suitable for short
1052      sequences.
1053
1054  Call arguments:
1055    inputs: A 3D tensor with shape `[batch, timesteps, feature]`.
1056    mask: Binary tensor of shape `[batch, timesteps]` indicating whether
1057      a given timestep should be masked (optional, defaults to `None`).
1058      An individual `True` entry indicates that the corresponding timestep
1059      should be utilized, while a `False` entry indicates that the corresponding
1060      timestep should be ignored.
1061    training: Python boolean indicating whether the layer should behave in
1062      training mode or in inference mode. This argument is passed to the cell
1063      when calling it. This is only relevant if `dropout` or
1064      `recurrent_dropout` is used (optional, defaults to `None`).
1065    initial_state: List of initial state tensors to be passed to the first
1066      call of the cell (optional, defaults to `None` which causes creation
1067      of zero-filled initial state tensors).
1068  """
1069
1070  def __init__(self,
1071               units,
1072               activation='tanh',
1073               recurrent_activation='sigmoid',
1074               use_bias=True,
1075               kernel_initializer='glorot_uniform',
1076               recurrent_initializer='orthogonal',
1077               bias_initializer='zeros',
1078               unit_forget_bias=True,
1079               kernel_regularizer=None,
1080               recurrent_regularizer=None,
1081               bias_regularizer=None,
1082               activity_regularizer=None,
1083               kernel_constraint=None,
1084               recurrent_constraint=None,
1085               bias_constraint=None,
1086               dropout=0.,
1087               recurrent_dropout=0.,
1088               return_sequences=False,
1089               return_state=False,
1090               go_backwards=False,
1091               stateful=False,
1092               time_major=False,
1093               unroll=False,
1094               **kwargs):
1095    # return_runtime is a flag for testing, which shows the real backend
1096    # implementation chosen by grappler in graph mode.
1097    self.return_runtime = kwargs.pop('return_runtime', False)
1098
1099    super(LSTM, self).__init__(
1100        units,
1101        activation=activation,
1102        recurrent_activation=recurrent_activation,
1103        use_bias=use_bias,
1104        kernel_initializer=kernel_initializer,
1105        recurrent_initializer=recurrent_initializer,
1106        bias_initializer=bias_initializer,
1107        unit_forget_bias=unit_forget_bias,
1108        kernel_regularizer=kernel_regularizer,
1109        recurrent_regularizer=recurrent_regularizer,
1110        bias_regularizer=bias_regularizer,
1111        activity_regularizer=activity_regularizer,
1112        kernel_constraint=kernel_constraint,
1113        recurrent_constraint=recurrent_constraint,
1114        bias_constraint=bias_constraint,
1115        dropout=dropout,
1116        recurrent_dropout=recurrent_dropout,
1117        implementation=kwargs.pop('implementation', 2),
1118        return_sequences=return_sequences,
1119        return_state=return_state,
1120        go_backwards=go_backwards,
1121        stateful=stateful,
1122        time_major=time_major,
1123        unroll=unroll,
1124        **kwargs)
1125
1126    self.state_spec = [
1127        InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
1128    ]
1129    self._could_use_gpu_kernel = (
1130        self.activation in (activations.tanh, nn.tanh) and
1131        self.recurrent_activation in (activations.sigmoid, nn.sigmoid) and
1132        recurrent_dropout == 0 and not unroll and use_bias and
1133        ops.executing_eagerly_outside_functions())
1134    if config.list_logical_devices('GPU'):
1135      # Only show the message when there is GPU available, user will not care
1136      # about the cuDNN if there isn't any GPU.
1137      if self._could_use_gpu_kernel:
1138        logging.debug(_CUDNN_AVAILABLE_MSG % self.name)
1139      else:
1140        logging.warning(_CUDNN_NOT_AVAILABLE_MSG % self.name)
1141
1142    if _use_new_code():
1143      self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'lstm')
1144
1145  def call(self, inputs, mask=None, training=None, initial_state=None):
1146    # The input should be dense, padded with zeros. If a ragged input is fed
1147    # into the layer, it is padded and the row lengths are used for masking.
1148    inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
1149    is_ragged_input = (row_lengths is not None)
1150    self._validate_args_if_ragged(is_ragged_input, mask)
1151
1152    # LSTM does not support constants. Ignore it during process.
1153    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)
1154
1155    if isinstance(mask, list):
1156      mask = mask[0]
1157
1158    input_shape = backend.int_shape(inputs)
1159    timesteps = input_shape[0] if self.time_major else input_shape[1]
1160
1161    # TODO(b/156447398) Investigate why the cuDNN kernel fails with ragged
1162    # inputs.
1163    if is_ragged_input or not self._could_use_gpu_kernel:
1164      # Fall back to use the normal LSTM.
1165      kwargs = {'training': training}
1166      self._maybe_reset_cell_dropout_mask(self.cell)
1167
1168      def step(inputs, states):
1169        return self.cell(inputs, states, **kwargs)
1170
1171      last_output, outputs, states = backend.rnn(
1172          step,
1173          inputs,
1174          initial_state,
1175          constants=None,
1176          go_backwards=self.go_backwards,
1177          mask=mask,
1178          unroll=self.unroll,
1179          input_length=row_lengths if row_lengths is not None else timesteps,
1180          time_major=self.time_major,
1181          zero_output_for_mask=self.zero_output_for_mask)
1182      runtime = _runtime(_RUNTIME_UNKNOWN)
1183    else:
1184      # Use the new defun approach for backend implementation swap.
1185      # Note that different implementations need to have same function
1186      # signature, eg, the tensor parameters need to have same shape and dtypes.
1187      # Since the CuDNN has an extra set of bias, those bias will be passed to
1188      # both normal and CuDNN implementations.
1189      self.reset_dropout_mask()
1190      dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
1191      if dropout_mask is not None:
1192        inputs = inputs * dropout_mask[0]
1193      if _use_new_code():
1194        lstm_kwargs = {
1195            'inputs':
1196                inputs,
1197            'init_h':
1198                _read_variable_value(initial_state[0]),
1199            'init_c':
1200                _read_variable_value(initial_state[1]),
1201            'kernel':
1202                _read_variable_value(self.cell.kernel),
1203            'recurrent_kernel':
1204                _read_variable_value(self.cell.recurrent_kernel),
1205            'bias':
1206                _read_variable_value(self.cell.bias),
1207            'mask':
1208                mask,
1209            'time_major':
1210                self.time_major,
1211            'go_backwards':
1212                self.go_backwards,
1213            'sequence_lengths':
1214                row_lengths,
1215            'zero_output_for_mask':
1216                self.zero_output_for_mask,
1217        }
1218        (last_output, outputs, new_h, new_c,
1219         runtime) = self._defun_wrapper.defun_layer(**lstm_kwargs)
1220      else:
1221        gpu_lstm_kwargs = {
1222            'inputs':
1223                inputs,
1224            'init_h':
1225                _read_variable_value(initial_state[0]),
1226            'init_c':
1227                _read_variable_value(initial_state[1]),
1228            'kernel':
1229                _read_variable_value(self.cell.kernel),
1230            'recurrent_kernel':
1231                _read_variable_value(self.cell.recurrent_kernel),
1232            'bias':
1233                _read_variable_value(self.cell.bias),
1234            'mask':
1235                mask,
1236            'time_major':
1237                self.time_major,
1238            'go_backwards':
1239                self.go_backwards,
1240            'sequence_lengths':
1241                row_lengths
1242        }
1243        normal_lstm_kwargs = gpu_lstm_kwargs.copy()
1244        normal_lstm_kwargs.update({
1245            'zero_output_for_mask': self.zero_output_for_mask,
1246        })
1247
1248        if context.executing_eagerly():
1249          device_type = _get_context_device_type()
1250          can_use_gpu = (
1251              # Either user specified GPU or unspecified but GPU is available.
1252              (device_type == _GPU_DEVICE_NAME or
1253               (device_type is None and config.list_logical_devices('GPU'))) and
1254              (mask is None or
1255               is_cudnn_supported_inputs(mask, self.time_major)))
1256          # Under eager context, check the device placement and prefer the
1257          # GPU implementation when GPU is available.
1258          if can_use_gpu:
1259            last_output, outputs, new_h, new_c, runtime = gpu_lstm(
1260                **gpu_lstm_kwargs)
1261          else:
1262            last_output, outputs, new_h, new_c, runtime = standard_lstm(
1263                **normal_lstm_kwargs)
1264        else:
1265          (last_output, outputs, new_h, new_c,
1266           runtime) = lstm_with_backend_selection(**normal_lstm_kwargs)
1267
1268      states = [new_h, new_c]
1269
1270    if self.stateful:
1271      updates = [
1272          state_ops.assign(self_state, state)
1273          for self_state, state in zip(self.states, states)
1274      ]
1275      self.add_update(updates)
1276
1277    if self.return_sequences:
1278      output = backend.maybe_convert_to_ragged(
1279          is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards)
1280    else:
1281      output = last_output
1282
1283    if self.return_state:
1284      return [output] + list(states)
1285    elif self.return_runtime:
1286      return output, runtime
1287    else:
1288      return output
1289
1290
1291def _canonical_to_params(weights, biases, shape, transpose_weights=False):
1292  """Utility function convert variable to CuDNN compatible parameter.
1293
1294  Note that Keras weights for kernels are different from the CuDNN format, e.g.:
1295
1296  ```
1297    Keras                 CuDNN
1298    [[0, 1, 2],  <--->  [[0, 2, 4],
1299     [3, 4, 5]]          [1, 3, 5]]
1300  ```
1301
1302  If the input weights need to be in a unified format, then set
1303  `transpose_weights=True` to convert the weights.
1304
1305  Args:
1306    weights: list of weights for the individual kernels and recurrent kernels.
1307    biases: list of biases for the individual gates.
1308    shape: the shape for the converted variables that will be fed to CuDNN.
1309    transpose_weights: boolean, whether to transpose the weights.
1310
1311  Returns:
1312    The converted weights that can be fed to CuDNN ops as params.
1313  """
1314  def convert(w):
1315    return array_ops.transpose(w) if transpose_weights else w
1316
1317  weights = [array_ops.reshape(convert(x), shape) for x in weights]
1318  biases = [array_ops.reshape(x, shape) for x in biases]
1319  return array_ops.concat(weights + biases, axis=0)
1320
1321
1322def standard_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias,
1323                  mask, time_major, go_backwards, sequence_lengths,
1324                  zero_output_for_mask):
1325  """LSTM with standard kernel implementation.
1326
1327  This implementation can be run on all types for hardware.
1328
1329  This implementation lifts out all the layer weights and make them function
1330  parameters. It has same number of tensor input params as the CuDNN
1331  counterpart. The RNN step logic has been simplified, eg dropout and mask is
1332  removed since CuDNN implementation does not support that.
1333
1334  Note that the first half of the bias tensor should be ignored by this impl.
1335  The CuDNN impl need an extra set of input gate bias. In order to make the both
1336  function take same shape of parameter, that extra set of bias is also feed
1337  here.
1338
1339  Args:
1340    inputs: input tensor of LSTM layer.
1341    init_h: initial state tensor for the cell output.
1342    init_c: initial state tensor for the cell hidden state.
1343    kernel: weights for cell kernel.
1344    recurrent_kernel: weights for cell recurrent kernel.
1345    bias: weights for cell kernel bias and recurrent bias. Only recurrent bias
1346      is used in this case.
1347    mask: Boolean tensor for masking out the steps within the sequence.
1348      An individual `True` entry indicates that the corresponding timestep
1349      should be utilized, while a `False` entry indicates that the corresponding
1350      timestep should be ignored.
1351    time_major: boolean, whether the inputs are in the format of
1352      [time, batch, feature] or [batch, time, feature].
1353    go_backwards: Boolean (default False). If True, process the input sequence
1354      backwards and return the reversed sequence.
1355    sequence_lengths: The lengths of all sequences coming from a variable length
1356      input, such as ragged tensors. If the input has a fixed timestep size,
1357      this should be None.
1358    zero_output_for_mask: Boolean, whether to output zero for masked timestep.
1359
1360  Returns:
1361    last_output: output tensor for the last timestep, which has shape
1362      [batch, units].
1363    outputs: output tensor for all timesteps, which has shape
1364      [batch, time, units].
1365    state_0: the cell output, which has the same shape as init_h.
1366    state_1: the cell hidden state, which has the same shape as init_c.
1367    runtime: constant float tensor which indicates the real runtime hardware.
1368      This value is for testing purposes and should not be used by users.
1369  """
1370  input_shape = backend.int_shape(inputs)
1371  timesteps = input_shape[0] if time_major else input_shape[1]
1372
1373  def step(cell_inputs, cell_states):
1374    """Step function that will be used by Keras RNN backend."""
1375    h_tm1 = cell_states[0]  # previous memory state
1376    c_tm1 = cell_states[1]  # previous carry state
1377
1378    z = backend.dot(cell_inputs, kernel)
1379    z += backend.dot(h_tm1, recurrent_kernel)
1380    z = backend.bias_add(z, bias)
1381
1382    z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)
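    # Gate ordering matches the Keras LSTM kernel layout:
    # z0 -> input gate, z1 -> forget gate, z2 -> cell candidate, z3 -> output gate.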
1383
1384    i = nn.sigmoid(z0)
1385    f = nn.sigmoid(z1)
1386    c = f * c_tm1 + i * nn.tanh(z2)
1387    o = nn.sigmoid(z3)
1388
1389    h = o * nn.tanh(c)
1390    return h, [h, c]
1391
1392  last_output, outputs, new_states = backend.rnn(
1393      step,
1394      inputs, [init_h, init_c],
1395      constants=None,
1396      unroll=False,
1397      time_major=time_major,
1398      mask=mask,
1399      go_backwards=go_backwards,
1400      input_length=(sequence_lengths
1401                    if sequence_lengths is not None else timesteps),
1402      zero_output_for_mask=zero_output_for_mask)
1403  return (last_output, outputs, new_states[0], new_states[1],
1404          _runtime(_RUNTIME_CPU))
1405
1406
1407def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
1408             time_major, go_backwards, sequence_lengths):
1409  """LSTM with either CuDNN or ROCm implementation which is only available for GPU.
1410
1411  Note that currently only right-padded data is supported; otherwise the result
1412  will be polluted by the values that should have been masked out.
1413
1414  Args:
1415    inputs: Input tensor of LSTM layer.
1416    init_h: Initial state tensor for the cell output.
1417    init_c: Initial state tensor for the cell hidden state.
1418    kernel: Weights for cell kernel.
1419    recurrent_kernel: Weights for cell recurrent kernel.
1420    bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
1421      is used in this case.
1422    mask: Boolean tensor for masking out the steps within the sequence.
1423      An individual `True` entry indicates that the corresponding timestep
1424      should be utilized, while a `False` entry indicates that the corresponding
1425      timestep should be ignored.
1426    time_major: Boolean, whether the inputs are in the format of [time, batch,
1427      feature] or [batch, time, feature].
1428    go_backwards: Boolean (default False). If True, process the input sequence
1429      backwards and return the reversed sequence.
1430    sequence_lengths: The lengths of all sequences coming from a variable length
1431      input, such as ragged tensors. If the input has a fixed timestep size,
1432      this should be None.
1433
1434  Returns:
1435    last_output: Output tensor for the last timestep, which has shape
1436      [batch, units].
1437    outputs: Output tensor for all timesteps, which has shape
1438      [batch, time, units].
1439    state_0: The cell output, which has the same shape as init_h.
1440    state_1: The cell hidden state, which has the same shape as init_c.
1441    runtime: Constant float tensor which indicates the real runtime hardware.
1442      This value is for testing purposes and should not be used by users.
1443  """
1444  if not time_major and mask is None:
1445    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
1446    seq_axis, batch_axis = (0, 1)
1447  else:
1448    seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
1449  # For init_h and init_c, cuDNN expects one more num_layers dim, before or
1450  # after the batch dim for time-major or batch-major inputs respectively.
1451  init_h = array_ops.expand_dims(init_h, axis=seq_axis)
1452  init_c = array_ops.expand_dims(init_c, axis=seq_axis)
1453
1454  weights = array_ops.split(kernel, 4, axis=1)
1455  weights += array_ops.split(recurrent_kernel, 4, axis=1)
1456  # CuDNN has an extra set of biases for inputs; we disable them (set to 0)
1457  # so that mathematically it is the same as the canonical LSTM implementation.
1458  full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0)
1459
1460  if sysconfig.get_build_info()['is_rocm_build']:
1461    # ROCm MIOpen's weight sequence for LSTM is different from both the
1462    # canonical and the Cudnn formats:
1463    # MIOpen: [i, f, o, c]  Cudnn/Canonical: [i, f, c, o]
1464    # i is input gate weights.
1465    # f is forget gate weights.
1466    # o is output gate weights.
1467    # c is cell gate weights.
1468    weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
1469    # full_bias is a tensor of shape (8*n,)
1470    full_bias = array_ops.split(full_bias, 8, axis=0)
1471    full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
1472
1473  params = _canonical_to_params(
1474      weights=weights,
1475      biases=array_ops.split(full_bias, 8),
1476      shape=constant_op.constant([-1]),
1477      transpose_weights=True)
1478
1479  if mask is not None:
1480    sequence_lengths = calculate_sequence_by_mask(mask, time_major)
1481
1482  if sequence_lengths is not None:
1483    if go_backwards:
1484      # Three reversals are required. E.g.,
1485      # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked
1486      # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
1487      # output_from_cudnn = [6, 5, 4, 0, 0]
1488      # expected_output = [0, 0, 6, 5 ,4]
1489      inputs = array_ops.reverse_sequence_v2(
1490          inputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
1491    outputs, h, c, _, _ = gen_cudnn_rnn_ops.CudnnRNNV3(
1492        input=inputs,
1493        input_h=init_h,
1494        input_c=init_c,
1495        params=params,
1496        is_training=True,
1497        rnn_mode='lstm',
1498        sequence_lengths=sequence_lengths,
1499        time_major=time_major)
1500    if go_backwards:
1501      outputs = array_ops.reverse_sequence_v2(
1502          outputs, sequence_lengths, seq_axis=seq_axis, batch_axis=batch_axis)
1503      outputs = array_ops.reverse(outputs, axis=[seq_axis])
1504  else:
1505    # # Fill the array with shape [batch] with value of max timesteps.
1506    # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
1507    #                                  array_ops.shape(inputs)[0])
1508    if go_backwards:
1509      # Reverse axis 0 since the input is already converted to time major.
1510      inputs = array_ops.reverse(inputs, axis=[0])
1511    outputs, h, c, _ = gen_cudnn_rnn_ops.CudnnRNN(
1512        input=inputs, input_h=init_h, input_c=init_c, params=params,
1513        is_training=True, rnn_mode='lstm')
1514
1515  last_output = outputs[-1]
1516  if not time_major and mask is None:
1517    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
1518  h = array_ops.squeeze(h, axis=seq_axis)
1519  c = array_ops.squeeze(c, axis=seq_axis)
1520
1521  # In the case of variable-length input, the cudnn kernel fills the output with
1522  # zeros, whereas the default keras behavior is to carry over the output from
1523  # t-1, so that in the return_sequences=False case the user gets the final
1524  # effective output instead of just 0s at the last timestep.
1525  # In order to mimic the default keras behavior, we copy the final h state as
1526  # the last_output, since it is numerically the same as the output.
1527  if mask is not None:
1528    last_output = h
1529  return last_output, outputs, h, c, _runtime(_RUNTIME_GPU)
1530
1531
1532def lstm_with_backend_selection(inputs, init_h, init_c, kernel,
1533                                recurrent_kernel, bias, mask, time_major,
1534                                go_backwards, sequence_lengths,
1535                                zero_output_for_mask):
1536  """Call the LSTM with optimized backend kernel selection.
1537
1538  Under the hood, this function will create two TF functions, one with the most
1539  generic kernel, which can run on all devices, and a second one with the
1540  CuDNN-specific kernel, which can only run on GPU.
1541
1542  The first function will be called with normal_lstm_params, while the second
1543  function is not called, but only registered in the graph. Grappler will
1544  do the proper graph rewrite and swap in the optimized TF function based on
1545  device placement.
1546
1547  Args:
1548    inputs: Input tensor of LSTM layer.
1549    init_h: Initial state tensor for the cell output.
1550    init_c: Initial state tensor for the cell hidden state.
1551    kernel: Weights for cell kernel.
1552    recurrent_kernel: Weights for cell recurrent kernel.
1553    bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
1554      is used in this case.
1555    mask: Boolean tensor for masking out the steps within the sequence.
1556      An individual `True` entry indicates that the corresponding timestep
1557      should be utilized, while a `False` entry indicates that the corresponding
1558      timestep should be ignored.
1559    time_major: Boolean, whether the inputs are in the format of
1560      [time, batch, feature] or [batch, time, feature].
1561    go_backwards: Boolean (default False). If True, process the input sequence
1562      backwards and return the reversed sequence.
1563    sequence_lengths: The lengths of all sequences coming from a variable length
1564      input, such as ragged tensors. If the input has a fixed timestep size,
1565      this should be None.
1566    zero_output_for_mask: Boolean, whether to output zero for masked timestep.
1567
1568  Returns:
1569    List of output tensors, same as standard_lstm.
1570  """
1571  params = {
1572      'inputs': inputs,
1573      'init_h': init_h,
1574      'init_c': init_c,
1575      'kernel': kernel,
1576      'recurrent_kernel': recurrent_kernel,
1577      'bias': bias,
1578      'mask': mask,
1579      'time_major': time_major,
1580      'go_backwards': go_backwards,
1581      'sequence_lengths': sequence_lengths,
1582      'zero_output_for_mask': zero_output_for_mask,
1583  }
1584
1585  def gpu_lstm_with_fallback(inputs, init_h, init_c, kernel, recurrent_kernel,
1586                             bias, mask, time_major, go_backwards,
1587                             sequence_lengths, zero_output_for_mask):
1588    """Use CuDNN kernel when mask is none or strictly right padded."""
1589    if mask is None:
1590      return gpu_lstm(
1591          inputs=inputs,
1592          init_h=init_h,
1593          init_c=init_c,
1594          kernel=kernel,
1595          recurrent_kernel=recurrent_kernel,
1596          bias=bias,
1597          mask=mask,
1598          time_major=time_major,
1599          go_backwards=go_backwards,
1600          sequence_lengths=sequence_lengths)
1601
1602    def cudnn_lstm_fn():
1603      return gpu_lstm(
1604          inputs=inputs,
1605          init_h=init_h,
1606          init_c=init_c,
1607          kernel=kernel,
1608          recurrent_kernel=recurrent_kernel,
1609          bias=bias,
1610          mask=mask,
1611          time_major=time_major,
1612          go_backwards=go_backwards,
1613          sequence_lengths=sequence_lengths)
1614
1615    def standard_lstm_fn():
1616      return standard_lstm(
1617          inputs=inputs,
1618          init_h=init_h,
1619          init_c=init_c,
1620          kernel=kernel,
1621          recurrent_kernel=recurrent_kernel,
1622          bias=bias,
1623          mask=mask,
1624          time_major=time_major,
1625          go_backwards=go_backwards,
1626          sequence_lengths=sequence_lengths,
1627          zero_output_for_mask=zero_output_for_mask)
1628
1629    return control_flow_ops.cond(
1630        is_cudnn_supported_inputs(mask, time_major),
1631        true_fn=cudnn_lstm_fn,
1632        false_fn=standard_lstm_fn)
1633
1634  if _use_new_code():
1635    # Chooses the implementation dynamically based on the running device.
1636    (last_output, outputs, new_h, new_c,
1637     runtime) = control_flow_ops.execute_fn_for_device(
1638         {
1639             _CPU_DEVICE_NAME: lambda: standard_lstm(**params),
1640             _GPU_DEVICE_NAME: lambda: gpu_lstm_with_fallback(**params)
1641         }, lambda: standard_lstm(**params))
1642  else:
1643    # Each time a `tf.function` is called, we will give it a unique
1644    # identifiable API name, so that Grappler won't get confused when it
1645    # sees multiple LSTM layers added into the same graph, and it will be able
1646    # to pair up the different implementations across them.
1647    api_name = 'lstm_' + str(uuid.uuid4())
1648    supportive_attribute = {
1649        'time_major': time_major,
1650        'go_backwards': go_backwards,
1651    }
1652    defun_standard_lstm = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
1653                                                  standard_lstm,
1654                                                  supportive_attribute)
1655    defun_gpu_lstm = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
1656                                             gpu_lstm_with_fallback,
1657                                             supportive_attribute)
1658
1659    # Call the normal LSTM impl and register the CuDNN impl function. Grappler
1660    # will kick in during session execution to optimize the graph.
1661    last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(**params)
1662    _function_register(defun_gpu_lstm, **params)
1663
1664  return last_output, outputs, new_h, new_c, runtime
1665
1666
1667def is_sequence_right_padded(mask):
1668  """Check the mask tensor and see if it right padded.
1669
1670  The CuDNN kernel uses the sequence length param to skip the trailing
1671  timesteps. If the data is left padded, or not strictly right padded (has
1672  masked values in the middle of the sequence), then the CuDNN kernel won't
1673  work properly in those cases.
1674
1675  Left padded data: [[False, False, True, True, True]].
1676  Right padded data: [[True, True, True, False, False]].
1677  Mixture of mask/unmasked data: [[True, False, True, False, False]].
1678
1679  Note that for the mixed data example above, the actual data the RNN should see
1680  are those 2 Trues (index 0 and 2); the index 1 False should be ignored and not
1681  pollute the internal states.
1682
1683  Args:
1684    mask: the Boolean tensor with shape [batch, timestep]
1685
1686  Returns:
1687    boolean scalar tensor, whether the mask is strictly right padded.
1688  """
1689  max_seq_length = array_ops.shape(mask)[1]
1690  count_of_true = math_ops.reduce_sum(math_ops.cast(mask, dtypes.int32), axis=1)
1691  right_padded_mask = array_ops.sequence_mask(
1692      count_of_true, maxlen=max_seq_length)
1693  return math_ops.reduce_all(math_ops.equal(mask, right_padded_mask))
1694
1695
1696def has_fully_masked_sequence(mask):
1697  # See https://github.com/tensorflow/tensorflow/issues/33148 for more details.
1698  # The Cudnn kernel will error out if the input sequence contains any fully
1699  # masked data. We work around this issue by rerouting the computation to the
1700  # standard kernel, until the issue on the cudnn side has been fixed.
1701  # A fully masked sequence contains all Falses. To make it easy to check, we
1702  # invert the boolean mask and check whether any sequence is all True.
1703  return math_ops.reduce_any(
1704      math_ops.reduce_all(
1705          math_ops.logical_not(mask),
1706          axis=1))
1707
1708
1709def is_cudnn_supported_inputs(mask, time_major):
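  """Check whether the mask allows the cuDNN kernel to be used.

  The cuDNN kernel requires the mask (in batch-major form) to be strictly right
  padded and to contain no fully masked sequences.
  """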
1710  if time_major:
1711    mask = array_ops.transpose(mask)
1712
1713  return math_ops.logical_and(
1714      is_sequence_right_padded(mask),
1715      math_ops.logical_not(has_fully_masked_sequence(mask)))
1716
1717
1718def calculate_sequence_by_mask(mask, time_major):
1719  """Calculate the sequence length tensor (1-D) based on the masking tensor.
1720
1721  The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For
1722  any timestep that should be masked, the corresponding field will be False.
1723  Consider the following example:
1724    a = [[True, True, False, False],
1725         [True, True, True, False]]
1726  It is a (2, 4) tensor, and the corresponding sequence length result should be
1727  a 1D tensor with value [2, 3]. Note that the masking tensor must be right
1728  padded, which can be checked by, e.g., `is_sequence_right_padded()`.
1729
1730  Args:
1731    mask: Boolean tensor with shape [batch, timestep] or [timestep, batch] if
1732      time_major=True.
1733    time_major: Boolean, which indicates whether the mask is time major or batch
1734      major.
1735  Returns:
1736    sequence_length: 1D int32 tensor.
1737  """
1738  timestep_index = 0 if time_major else 1
1739  return math_ops.reduce_sum(math_ops.cast(mask, dtypes.int32),
1740                             axis=timestep_index)
1741
1742
1743def _generate_defun_backend(unique_api_name, preferred_device, func,
1744                            supportive_attributes):
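  """Build a defun for `func` tagged with the API name and preferred device.

  Grappler uses these attributes to pair up the CPU and GPU variants of the
  same layer and pick one based on device placement.
  """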
1745  function_attributes = {
1746      _FUNCTION_API_NAME_ATTRIBUTE: unique_api_name,
1747      _FUNCTION_DEVICE_ATTRIBUTE: preferred_device,
1748  }
1749  function_attributes.update(supportive_attributes)
1750  return function.defun_with_attributes(func=func,
1751                                        attributes=function_attributes,
1752                                        autograph=False)
1753
1754
1755def _get_context_device_type():
1756  """Parse the current context and return the device type, eg CPU/GPU."""
1757  current_device = get_device_name()
1758  if current_device is None:
1759    return None
1760  return device.DeviceSpec.from_string(current_device).device_type
1761
1762
1763def _runtime(runtime_name):
1764  with ops.device('/cpu:0'):
1765    return constant_op.constant(
1766        runtime_name, dtype=dtypes.float32, name='runtime')
1767
1768
1769def _read_variable_value(v):
1770  """Read the value of a variable if it is variable."""
1771  if isinstance(v, variables.Variable):
1772    return v.read_value()
1773  return v
1774
1775
1776def _function_register(func, *args, **kwargs):
1777  """Register a specialization of a `Function` into the graph.
1778
1779  This won't actually call the function with the inputs; it only puts the
1780  function definition into the graph. Registering the function with different
1781  input params will result in multiple versions of the function in the graph.
1782
1783  Args:
1784    func: the `Function` instance generated by a @defun.
1785    *args: input arguments for the Python function.
1786    **kwargs: input keyword arguments for the Python function.
1787
1788  Returns:
1789    a `ConcreteFunction` object specialized to inputs and execution context.
1790
1791  Raises:
1792    ValueError: When the input function is not a defun wrapped python function.
1793  """
1794  concrete_func = func.get_concrete_function(*args, **kwargs)
1795  concrete_func.add_to_graph()
1796  concrete_func.add_gradient_functions_to_graph()
1797  return concrete_func
1798