1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""CTC (Connectionist Temporal Classification) Operations."""
16
17import uuid
18
19from tensorflow.python.eager import context
20from tensorflow.python.eager import function as function_eager
21
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import device
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import function
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import sparse_tensor
28from tensorflow.python.framework import tensor_shape
29
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import custom_gradient
32from tensorflow.python.ops import functional_ops
33from tensorflow.python.ops import gen_ctc_ops
34from tensorflow.python.ops import inplace_ops
35from tensorflow.python.ops import linalg_ops
36from tensorflow.python.ops import map_fn
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import nn_ops
39from tensorflow.python.ops import sparse_ops
40from tensorflow.python.ops.nn_grad import _BroadcastMul
41from tensorflow.python.util import deprecation
42from tensorflow.python.util import dispatch
43from tensorflow.python.util import nest
44from tensorflow.python.util.tf_export import tf_export
45
46_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
47_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
48_CPU_DEVICE_NAME = "CPU"
49_GPU_DEVICE_NAME = "GPU"
50
51
52def _get_context_device_type():
53  """Parse the current context and return the device type, eg CPU/GPU."""
54  current_device = context.context().device_name
55  if current_device is None:
56    return None
57  return device.DeviceSpec.from_string(current_device).device_type
58
59
60def _generate_defun_backend(unique_api_name, preferred_device, func):
61  function_attributes = {
62      _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
63      _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
64  }
65  return function_eager.defun_with_attributes(
66      func=func, attributes=function_attributes, autograph=False)
67
68# pylint: disable=protected-access, invalid-name
69@tf_export(v1=["nn.ctc_loss"])
70@dispatch.add_dispatch_support
71def ctc_loss(labels,
72             inputs=None,
73             sequence_length=None,
74             preprocess_collapse_repeated=False,
75             ctc_merge_repeated=True,
76             ignore_longer_outputs_than_inputs=False,
77             time_major=True,
78             logits=None):
79  """Computes the CTC (Connectionist Temporal Classification) Loss.
80
81  This op implements the CTC loss as presented in (Graves et al., 2006).
82
83  Input requirements:
84
85  ```
86  sequence_length(b) <= time for all b
87
88  max(labels.indices(labels.indices[:, 1] == b, 2))
89    <= sequence_length(b) for all b.
90  ```
91
92  Notes:
93
94  This op performs the softmax operation for you, so inputs should
95  be e.g. linear projections of outputs by an LSTM.
96
97  The `inputs` Tensor's innermost dimension size, `num_classes`, represents
98  `num_labels + 1` classes, where num_labels is the number of true labels, and
99  the largest value `(num_classes - 1)` is reserved for the blank label.
100
101  For example, for a vocabulary containing 3 labels `[a, b, c]`,
102  `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.
103
104  Regarding the arguments `preprocess_collapse_repeated` and
105  `ctc_merge_repeated`:
106
107  If `preprocess_collapse_repeated` is True, then a preprocessing step runs
108  before loss calculation, wherein repeated labels passed to the loss
109  are merged into single labels.  This is useful if the training labels come
110  from, e.g., forced alignments and therefore have unnecessary repetitions.
111
112  If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
113  repeated non-blank labels will not be merged and are interpreted
114  as individual labels.  This is a simplified (non-standard) version of CTC.
115
116  Here is a table of the (roughly) expected first order behavior:
117
118  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`
119
120    Classical CTC behavior: Outputs true repeated classes with blanks in
121    between, and can also output repeated classes with no blanks in
122    between that need to be collapsed by the decoder.
123
124  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`
125
126    Never learns to output repeated classes, as they are collapsed
127    in the input labels before training.
128
129  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`
130
131    Outputs repeated classes with blanks in between, but generally does not
132    require the decoder to collapse/merge repeated classes.
133
134  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`
135
136    Untested.  Very likely will not learn to output repeated classes.
137
138  The `ignore_longer_outputs_than_inputs` option allows specifying the behavior
139  of the CTCLoss when dealing with sequences that have longer outputs than
140  inputs. If true, the CTCLoss will simply return zero gradient for those
141  items; otherwise, an InvalidArgument error is returned, stopping training.
142
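  For a minimal usage sketch (the shapes, variable names and label ids below
  are illustrative only, not prescribed by this op):

  ```
  # batch_size=1, max_time=3, num_classes=4 (3 true labels + blank at index 3).
  logits = tf.random.normal([3, 1, 4])       # [max_time, batch_size, num_classes]
  seq_len = tf.constant([3], dtype=tf.int32)
  labels = tf.SparseTensor(indices=[[0, 0], [0, 1]], values=[0, 1],
                           dense_shape=[1, 2])
  loss = tf.compat.v1.nn.ctc_loss(labels, inputs=logits,
                                  sequence_length=seq_len)  # shape [batch_size]
  ```
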
143  Args:
144    labels: An `int32` `SparseTensor`.
145      `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id
146        for (batch b, time t). `labels.values[i]` must take on values in `[0,
147        num_labels)`. See `core/ops/ctc_ops.cc` for more details.
148    inputs: 3-D `float` `Tensor`.
149      If time_major == False, this will be a `Tensor` shaped: `[batch_size,
150        max_time, num_classes]`.
151      If time_major == True (default), this will be a `Tensor` shaped:
152        `[max_time, batch_size, num_classes]`. The logits.
153    sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence
154      lengths.
155    preprocess_collapse_repeated: Boolean.  Default: False. If True, repeated
156      labels are collapsed prior to the CTC calculation.
157    ctc_merge_repeated: Boolean.  Default: True.
158    ignore_longer_outputs_than_inputs: Boolean. Default: False. If True,
159      sequences with longer outputs than inputs will be ignored.
160    time_major: The shape format of the `inputs` Tensors. If True, these
161      `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False,
162      these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
163      Using `time_major = True` (default) is a bit more efficient because it
164      avoids transposes at the beginning of the ctc_loss calculation.  However,
165      most TensorFlow data is batch-major, so this function also accepts
166      inputs in batch-major form.
167    logits: Alias for inputs.
168
169  Returns:
170    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
171      probabilities.
172
173  Raises:
174    TypeError: if labels is not a `SparseTensor`.
175
176  References:
177      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
178      with Recurrent Neural Networks:
179        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
180        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
181  """
182  return _ctc_loss_impl(
183      labels,
184      inputs,
185      sequence_length,
186      preprocess_collapse_repeated,
187      ctc_merge_repeated,
188      ignore_longer_outputs_than_inputs,
189      time_major,
190      logits,
191      use_cudnn=False)
192
193
194def _ctc_loss_impl(labels,
195                   inputs=None,
196                   sequence_length=None,
197                   preprocess_collapse_repeated=False,
198                   ctc_merge_repeated=True,
199                   ignore_longer_outputs_than_inputs=False,
200                   time_major=True,
201                   logits=None,
202                   use_cudnn=False):
203  # Helper function of ctc_loss with one additional param:
204  # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
205  #   index has to be 0.
206
207  # The second output tensor contains the gradients.  We use it in
208  # _CTCLossGrad() below.
209  if not isinstance(labels, sparse_tensor.SparseTensor):
210    raise TypeError("Expected argument `labels` to be a SparseTensor. "
211                    f"Received labels={labels} of type: "
212                    f"{type(labels).__name__}")
213
214  # For internal calculations, we transpose to [time, batch, num_classes]
215  inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
216                                                  inputs)
217
218  inputs = ops.convert_to_tensor(inputs, name="logits")
219  if not time_major:
220    inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)
221
222  orig_dtype = inputs.dtype
223  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
224    inputs = math_ops.cast(inputs, dtypes.float32)
225
226  # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
227  # blank index to be 0, but v1 views it as the last index.
228  if use_cudnn:
229    ctc_loss_func = gen_ctc_ops.ctc_loss_v2
230  else:
231    ctc_loss_func = gen_ctc_ops.ctc_loss
232
233  loss, _ = ctc_loss_func(
234      inputs,
235      labels.indices,
236      labels.values,
237      sequence_length,
238      preprocess_collapse_repeated=preprocess_collapse_repeated,
239      ctc_merge_repeated=ctc_merge_repeated,
240      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)
241
242  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
243    loss = math_ops.cast(loss, orig_dtype)
244
245  return loss
246
247# pylint: disable=unused-argument
248def _CTCLossGradImpl(op, grad_loss, _):
249  # Outputs are: loss, grad
250  #
251  # Currently there is no way to take the second derivative of this op
252  # due to the fused implementation's interaction with tf.gradients(),
253  # so we make sure we prevent silently incorrect results by raising an error
254  # (via prevent_gradient) if the second derivative is requested.
255  grad_without_gradient = array_ops.prevent_gradient(
256      op.outputs[1],
257      message="Currently there is no way to take the second "
258      " derivative of ctc_loss due to the fused implementation's interaction "
259      " with tf.gradients()")
260  # Return gradient for inputs and None for
261  # labels_indices, labels_values and sequence_length
262  return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
263
264
265# pylint: disable=unused-argument
266@ops.RegisterGradient("CTCLoss")
267def _CTCLossGrad(op, grad_loss, _):
268  """The derivative provided by CTC Loss.
269
270  Args:
271     op: the CTCLoss op.
272     grad_loss: The backprop for cost.
273
274  Returns:
275     The CTC Loss gradient.
276  """
277  return _CTCLossGradImpl(op, grad_loss, _)
278
279
280# pylint: disable=unused-argument
281@ops.RegisterGradient("CTCLossV2")
282def _CTCLossV2Grad(op, grad_loss, _):
283  """The derivative provided by CTC Loss V2.
284
285  Args:
286     op: the CTCLossV2 op.
287     grad_loss: The backprop for cost.
288
289  Returns:
290     The CTC Loss V2 gradient.
291  """
292  return _CTCLossGradImpl(op, grad_loss, _)
293
294
295@tf_export("nn.ctc_greedy_decoder")
296@dispatch.add_dispatch_support
297def ctc_greedy_decoder(inputs,
298                       sequence_length,
299                       merge_repeated=True,
300                       blank_index=None):
301  """Performs greedy decoding on the logits given in input (best path).
302
303  Given a tensor as `inputs`, the `blank_index` parameter defines the class
304  index of the blank symbol.
305
306  For example:
307
308  If `blank_index` is equal to 1:
309
310  >>> inf = float("inf")
311  >>> logits = tf.constant([[[   0., -inf, -inf],
312  ...                        [ -2.3, -inf, -0.1]],
313  ...                       [[ -inf, -0.5, -inf],
314  ...                        [ -inf, -inf, -0.1]],
315  ...                       [[ -inf, -inf, -inf],
316  ...                        [ -0.1, -inf, -2.3]]])
317  >>> seq_lens = tf.constant([2, 3])
318  >>> outputs = tf.nn.ctc_greedy_decoder(
319  ...     logits,
320  ...     seq_lens,
321  ...     blank_index=1)
322
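  The decoded result can be densified for inspection with `tf.sparse.to_dense`,
  e.g. `tf.sparse.to_dense(outputs[0][0])` (shown only as a usage hint; the
  resulting values depend on the logits above).
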
323  Notes:
324
325  - Unlike `ctc_beam_search_decoder`, `ctc_greedy_decoder` considers blanks
326    as regular elements when computing the probability of a sequence.
327  - Default `blank_index` is `(num_classes - 1)`, unless overridden.
328
329  If `merge_repeated` is `True`, merge repeated classes in output.
330  This means that if consecutive logits' maximum indices are the same,
331  only the first of these is emitted.  The sequence `A B B * B * B` (where '*'
332  is the blank label) becomes
333
334    * `A B B B` if `merge_repeated=True`.
335    * `A B B B B` if `merge_repeated=False`.
336
337  Args:
338    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
339      The logits.
340    sequence_length: 1-D `int32` vector containing sequence lengths, having size
341      `[batch_size]`.
342    merge_repeated: Boolean.  Default: True.
343    blank_index: (Optional). Default: `num_classes - 1`. Defines the class index
344      to use for the blank label. Negative values will start from num_classes,
345      i.e., -1 will reproduce the ctc_greedy_decoder behavior of using
346      num_classes - 1 for the blank symbol, which corresponds to the default.
347
348  Returns:
349    A tuple `(decoded, neg_sum_logits)` where
350
351    decoded: A single-element list. `decoded[0]`
352      is a `SparseTensor` containing the decoded outputs s.t.:
353
354      `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
355        The rows store: `[batch, time]`.
356
357      `decoded.values`: Values vector, size `(total_decoded_outputs)`.
358        The vector stores the decoded classes.
359
360      `decoded.dense_shape`: Shape vector, size `(2)`.
361        The shape values are: `[batch_size, max_decoded_length]`
362
363    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
364        sequence found, the negative of the sum of the greatest logit at each
365        timeframe.
366  """
367
368  outputs = gen_ctc_ops.ctc_greedy_decoder(
369      inputs,
370      sequence_length,
371      merge_repeated=merge_repeated,
372      blank_index=blank_index)
373  (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
374  return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val,
375                                      decoded_shape)], log_probabilities)
376
377
378@tf_export(v1=["nn.ctc_beam_search_decoder"])
379@dispatch.add_dispatch_support
380def ctc_beam_search_decoder(inputs,
381                            sequence_length,
382                            beam_width=100,
383                            top_paths=1,
384                            merge_repeated=True):
385  """Performs beam search decoding on the logits given in input.
386
387  **Note** Although in general greedy search is a special case of beam-search
388  with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs
389  from `ctc_greedy_decoder` in the treatment of blanks when computing the
390  probability of a sequence:
391    - `ctc_beam_search_decoder` treats blanks as sequence termination
392    - `ctc_greedy_decoder` treats blanks as regular elements
393
394  If `merge_repeated` is `True`, merge repeated classes in the output beams.
395  This means that if consecutive entries in a beam are the same,
396  only the first of these is emitted.  That is, when the sequence is
397  `A B B * B * B` (where '*' is the blank label), the return value is:
398
399    * `A B` if `merge_repeated = True`.
400    * `A B B B` if `merge_repeated = False`.
401
402  Args:
403    inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`.
404      The logits.
405    sequence_length: 1-D `int32` vector containing sequence lengths, having size
406      `[batch_size]`.
407    beam_width: An int scalar >= 0 (beam search beam width).
408    top_paths: An int scalar >= 0, <= beam_width (controls output size).
409    merge_repeated: Boolean.  Default: True.
410
411  Returns:
412    A tuple `(decoded, log_probabilities)` where
413
414    decoded: A list of length top_paths, where `decoded[j]`
415      is a `SparseTensor` containing the decoded outputs:
416
417      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
418        The rows store: [batch, time].
419
420      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
421        The vector stores the decoded classes for beam j.
422
423      `decoded[j].dense_shape`: Shape vector, size `(2)`.
424        The shape values are: `[batch_size, max_decoded_length[j]]`.
425
426    log_probabilities: A `float` matrix `(batch_size x top_paths)` containing
427        sequence log-probabilities.
428  """
429
430  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
431      gen_ctc_ops.ctc_beam_search_decoder(
432          inputs,
433          sequence_length,
434          beam_width=beam_width,
435          top_paths=top_paths,
436          merge_repeated=merge_repeated))
437
438  return ([
439      sparse_tensor.SparseTensor(ix, val, shape)
440      for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes)
441  ], log_probabilities)
442
443
444@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
445@dispatch.add_dispatch_support
446def ctc_beam_search_decoder_v2(inputs,
447                               sequence_length,
448                               beam_width=100,
449                               top_paths=1):
450  """Performs beam search decoding on the logits given in input.
451
452  **Note** Although in general greedy search is a special case of beam-search
453  with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs
454  from `ctc_greedy_decoder` in the treatment of blanks when computing the
455  probability of a sequence:
456    - `ctc_beam_search_decoder` treats blanks as sequence termination
457    - `ctc_greedy_decoder` treats blanks as regular elements
458
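  A small illustrative call (the shapes and variable names below are
  hypothetical, with time-major `logits` as described in Args):

  ```
  logits = tf.random.normal([50, 2, 10])    # [max_time, batch_size, num_classes]
  seq_lens = tf.constant([50, 40], dtype=tf.int32)
  decoded, log_probs = tf.nn.ctc_beam_search_decoder(
      logits, seq_lens, beam_width=100, top_paths=3)
  # decoded is a list of 3 SparseTensors; log_probs has shape [2, 3].
  ```
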
459  Args:
460    inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`.
461      The logits.
462    sequence_length: 1-D `int32` vector containing sequence lengths, having size
463      `[batch_size]`.
464    beam_width: An int scalar >= 0 (beam search beam width).
465    top_paths: An int scalar >= 0, <= beam_width (controls output size).
466
467  Returns:
468    A tuple `(decoded, log_probabilities)` where
469
470    decoded: A list of length top_paths, where `decoded[j]`
471      is a `SparseTensor` containing the decoded outputs:
472
473      `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`;
474        The rows store: `[batch, time]`.
475
476      `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`.
477        The vector stores the decoded classes for beam `j`.
478
479      `decoded[j].dense_shape`: Shape vector, size `(2)`.
480        The shape values are: `[batch_size, max_decoded_length[j]]`.
481
482    log_probabilities: A `float` matrix `[batch_size, top_paths]` containing
483        sequence log-probabilities.
484  """
485
486  # Note, merge_repeated is an invalid optimization that is removed from the
487  # public API: it returns low probability paths.
488  return ctc_beam_search_decoder(
489      inputs,
490      sequence_length=sequence_length,
491      beam_width=beam_width,
492      top_paths=top_paths,
493      merge_repeated=False)
494
495
496ops.NotDifferentiable("CTCGreedyDecoder")
497ops.NotDifferentiable("CTCBeamSearchDecoder")
498
499
500def _ctc_state_trans(label_seq):
501  """Compute CTC alignment model transition matrix.
502
503  Args:
504    label_seq: tensor of shape [batch_size, max_seq_length]
505
506  Returns:
507    tensor of shape [batch_size, states, states] with a state transition matrix
508    computed for each sequence of the batch.
509  """
510
511  with ops.name_scope("ctc_state_trans"):
512    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
513    batch_size = _get_dim(label_seq, 0)
514    num_labels = _get_dim(label_seq, 1)
515
516    num_label_states = num_labels + 1
517    num_states = 2 * num_label_states
518
519    label_states = math_ops.range(num_label_states)
520    blank_states = label_states + num_label_states
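    # Note on the state layout used below (a reading of the construction, not
    # an extra contract): state 0 doubles as the start state, states
    # 1..num_labels line up with the positions of `label_seq`, and state
    # `num_label_states + i` is the blank state that follows label state i.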
521
522    # Start state to first label.
523    start_to_label = [[1, 0]]
524
525    # Blank to label transitions.
526    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)
527
528    # Label to blank transitions.
529    label_to_blank = array_ops.stack([blank_states, label_states], 1)
530
531    # Scatter transitions that don't depend on sequence.
532    indices = array_ops.concat([start_to_label, blank_to_label, label_to_blank],
533                               0)
534    values = array_ops.ones([_get_dim(indices, 0)])
535    trans = array_ops.scatter_nd(
536        indices, values, shape=[num_states, num_states])
537    trans += linalg_ops.eye(num_states)  # Self-loops.
538
539    # Label to label transitions. Disallow transitions between repeated labels
540    # with no blank state in between.
541    batch_idx = array_ops.zeros_like(label_states[2:])
542    indices = array_ops.stack([batch_idx, label_states[2:], label_states[1:-1]],
543                              1)
544    indices = array_ops.tile(
545        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
546    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
547    indices += array_ops.expand_dims(batch_idx, 1)
548    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
549    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
550    batched_shape = [batch_size, num_states, num_states]
551    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)
552
553    return array_ops.expand_dims(trans, 0) + label_to_label
554
555
556def ctc_state_log_probs(seq_lengths, max_seq_length):
557  """Computes CTC alignment initial and final state log probabilities.
558
559  Create the initial/final state values directly as log values to avoid
560  having to take a float64 log on TPU (where it is not available).
561
562  Args:
563    seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
564    max_seq_length: int, max sequence length possible.
565
566  Returns:
567    initial_state_log_probs, final_state_log_probs
568  """
569
570  batch_size = _get_dim(seq_lengths, 0)
571  num_label_states = max_seq_length + 1
572  num_duration_states = 2
573  num_states = num_duration_states * num_label_states
574  log_0 = math_ops.cast(
575      math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32)
576
577  initial_state_log_probs = array_ops.one_hot(
578      indices=array_ops.zeros([batch_size], dtype=dtypes.int32),
579      depth=num_states,
580      on_value=0.0,
581      off_value=log_0,
582      axis=1)
583
584  label_final_state_mask = array_ops.one_hot(
585      seq_lengths, depth=num_label_states, axis=0)
586  duration_final_state_mask = array_ops.ones(
587      [num_duration_states, 1, batch_size])
588  final_state_mask = duration_final_state_mask * label_final_state_mask
589  final_state_log_probs = (1.0 - final_state_mask) * log_0
590  final_state_log_probs = array_ops.reshape(final_state_log_probs,
591                                            [num_states, batch_size])
592
593  return initial_state_log_probs, array_ops.transpose(final_state_log_probs)
594
595
596def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
597  """Project ilabel log probs to state log probs."""
598
599  num_label_states = _get_dim(labels, 1)
600  blank = ilabel_log_probs[:, :, :1]
601  blank = array_ops.tile(blank, [1, 1, num_label_states + 1])
602  one_hot = array_ops.one_hot(labels, depth=num_labels)
603  one_hot = array_ops.expand_dims(one_hot, axis=0)
604  ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2)
605  state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3)
606  state_log_probs = array_ops.concat([state_log_probs, blank], axis=2)
607  return array_ops.pad(
608      state_log_probs, [[0, 0], [0, 0], [1, 0]],
609      constant_values=math_ops.log(0.0))
610
611
612def _state_to_olabel(labels, num_labels, states):
613  """Sum state log probs to ilabel log probs."""
614
615  num_label_states = _get_dim(labels, 1) + 1
616  label_states = states[:, :, 1:num_label_states]
617  blank_states = states[:, :, num_label_states:]
618  one_hot = array_ops.one_hot(
619      labels - 1,
620      depth=(num_labels - 1),
621      on_value=0.0,
622      off_value=math_ops.log(0.0))
623  one_hot = array_ops.expand_dims(one_hot, axis=0)
624  label_states = array_ops.expand_dims(label_states, axis=3)
625  label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2)
626  blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)
627  return array_ops.concat([blank_olabels, label_olabels], axis=-1)
628
629
630# pylint: disable=redefined-outer-name
631def _state_to_olabel_unique(labels, num_labels, states, unique):
632  """Sum state log probs to ilabel log probs using unique label indices."""
633
634  num_label_states = _get_dim(labels, 1) + 1
635  label_states = states[:, :, 1:num_label_states]
636  blank_states = states[:, :, num_label_states:]
637
638  unique_y, unique_idx = unique
639  mul_reduce = _sum_states(unique_idx, label_states)
640
641  num_frames = _get_dim(states, 0)
642  batch_size = _get_dim(states, 1)
643  num_states = num_label_states - 1
644  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
645  batch_state_major = array_ops.reshape(batch_state_major,
646                                        [batch_size * num_states, num_frames])
647  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
648  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
649  indices = array_ops.reshape(indices, [-1, 1])
650  scatter = array_ops.scatter_nd(
651      indices=indices,
652      updates=batch_state_major,
653      shape=[batch_size * num_labels, num_frames])
654  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])
655
656  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
657  mask = array_ops.scatter_nd(
658      indices=indices,
659      updates=mask,
660      shape=[batch_size * num_labels, num_frames])
661  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])
662
663  scatter = array_ops.where(
664      mask, scatter,
665      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))
666
667  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
668  label_olabels = label_olabels[:, :, 1:]
669
670  blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)
671
672  return array_ops.concat([blank_olabels, label_olabels], axis=-1)
673
674
675def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
676  """Computes the CTC loss and gradients.
677
678  Most users will want ctc_loss_dense (which wraps this function) instead.
679
680  This function returns the computed gradient; it does not have a gradient
681  of its own defined.
682
683  Args:
684    logits: tensor of shape [frames, batch_size, num_labels]
685    labels: tensor of shape [batch_size, max_label_seq_length]
686    label_length: tensor of shape [batch_size] Length of reference label
687      sequence in labels.
688    logit_length: tensor of shape [batch_size] Length of input sequence in
689      logits.
690    unique: (optional) unique label indices as computed by unique(labels). If
691      supplied, enables an implementation that is faster and more memory
692      efficient on TPU.
693
694  Returns:
695    loss: tensor of shape [batch_size]
696    gradient: tensor of shape [frames, batch_size, num_labels]
697  """
698
699  num_labels = _get_dim(logits, 2)
700  max_label_seq_length = _get_dim(labels, 1)
701
702  ilabel_log_probs = nn_ops.log_softmax(logits)
703  state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
704  state_trans_probs = _ctc_state_trans(labels)
705  initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
706      label_length, max_label_seq_length)
707  fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
708      state_trans_log_probs=math_ops.log(state_trans_probs),
709      initial_state_log_probs=initial_state_log_probs,
710      final_state_log_probs=final_state_log_probs,
711      observed_log_probs=state_log_probs,
712      sequence_length=logit_length)
713
714  if unique:
715    olabel_log_probs = _state_to_olabel_unique(labels, num_labels,
716                                               fwd_bwd_log_probs, unique)
717  else:
718    olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)
719
720  grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)
721
722  # Applies the sequence mask for the gradient. It would be enough to apply the
723  # mask only to ilabel_log_probs because olabel_log_probs already takes the
724  # mask into account, but it is safer and cleaner to apply it to the gradient.
725  max_logit_length = _get_dim(logits, 0)
726  logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
727                                       dtypes.float32)
728  logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
729  logit_mask = array_ops.expand_dims(logit_mask, axis=2)
730  grad *= logit_mask
731
732  loss = -log_likelihood
733  return loss, grad
734
735
736def _ctc_loss_grad(op, grad_loss, _):
737  grad = op.outputs[1]
738  grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad]
739  grad += [None] * (len(op.inputs) - len(grad))
740  return grad
741
742
743def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
744                          blank_index):
745  part_before = logits[:, :, :blank_index]
746  part_after = logits[:, :, blank_index + 1:]
747  part_blank = logits[:, :, blank_index:blank_index + 1]
748  logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
749  labels = sparse_tensor.SparseTensor(
750      labels.indices,
751      array_ops.where(labels.values < blank_index, labels.values,
752                      labels.values - 1), labels.dense_shape)
753  return _ctc_loss_impl(
754      labels=labels,
755      inputs=logits,
756      sequence_length=logit_length,
757      time_major=logits_time_major,
758      use_cudnn=False)
759
760
761def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
762                       blank_index):
763  part_before = logits[:, :, :blank_index]
764  part_after = logits[:, :, blank_index + 1:]
765  part_blank = logits[:, :, blank_index:blank_index + 1]
766  logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
767  labels = sparse_tensor.SparseTensor(
768      labels.indices,
769      array_ops.where(labels.values < blank_index, labels.values + 1,
770                      labels.values), labels.dense_shape)
771  return _ctc_loss_impl(
772      labels=labels,
773      inputs=logits,
774      sequence_length=logit_length,
775      time_major=logits_time_major,
776      use_cudnn=True)
777
778
779def _ctc_loss_shape(op):
780  return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]
781
782
783# pylint: disable=protected-access, invalid-name
784@tf_export(v1=["nn.ctc_loss_v2"])
785@dispatch.add_dispatch_support
786def ctc_loss_v2(labels,
787                logits,
788                label_length,
789                logit_length,
790                logits_time_major=True,
791                unique=None,
792                blank_index=None,
793                name=None):
794  """Computes CTC (Connectionist Temporal Classification) loss.
795
796  This op implements the CTC loss as presented in (Graves et al., 2006).
797
798  Notes:
799
800  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
801    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
802  - Labels may be supplied as either a dense, zero-padded tensor with a
803    vector of label sequence lengths OR as a SparseTensor.
804  - On TPU and GPU: Only dense padded labels are supported.
805  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
806    a SparseTensor will be significantly faster.
807  - Default blank label is 0 rather than num_classes - 1, unless overridden by
808    blank_index.
809
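  A minimal sketch with `SparseTensor` labels (hypothetical shapes and values;
  `blank_index` must be given explicitly when labels are sparse):

  ```
  logits = tf.random.normal([50, 2, 10])    # [frames, batch_size, num_labels]
  # Zeros in the dense matrix are dropped by from_dense; with blank_index=0
  # they are padding/blank anyway.
  labels = tf.sparse.from_dense(tf.constant([[1, 2, 3, 0], [2, 4, 0, 0]]))
  logit_length = tf.constant([50, 50], dtype=tf.int32)
  loss = tf.compat.v1.nn.ctc_loss_v2(labels, logits, None, logit_length,
                                     blank_index=0)
  ```
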
810  Args:
811    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
812    logits: tensor of shape [frames, batch_size, num_labels], if
813      logits_time_major == False, shape is [batch_size, frames, num_labels].
814    label_length: tensor of shape [batch_size], None if labels is SparseTensor
815      Length of reference label sequence in labels.
816    logit_length: tensor of shape [batch_size] Length of input sequence in
817      logits.
818    logits_time_major: (optional) If True (default), logits is shaped [time,
819      batch, logits]. If False, shape is [batch, time, logits]
820    unique: (optional) Unique label indices as computed by
821      ctc_unique_labels(labels).  If supplied, enable a faster, memory efficient
822      implementation on TPU.
823    blank_index: (optional) Set the class index to use for the blank label.
824      Negative values will start from num_classes, i.e., -1 will reproduce the
825      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
826      some memory/performance overhead to switching from the default of 0 as an
827      additional shifted copy of the logits may be created.
828    name: A name for this `Op`. Defaults to "ctc_loss_dense".
829
830  Returns:
831    loss: tensor of shape [batch_size], negative log probabilities.
832
833  References:
834      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
835      with Recurrent Neural Networks:
836        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
837        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
838  """
839  if isinstance(labels, sparse_tensor.SparseTensor):
840    if blank_index is None:
841      raise ValueError(
842          "Argument `blank_index` must be provided when labels is a "
843          "SparseTensor.")
844
845    if blank_index < 0:
846      blank_index += _get_dim(logits, 2)
847
848    if blank_index != _get_dim(logits, 2) - 1:
849      logits = array_ops.concat([
850          logits[:, :, :blank_index],
851          logits[:, :, blank_index + 1:],
852          logits[:, :, blank_index:blank_index + 1],
853      ],
854                                axis=2)
855      labels = sparse_tensor.SparseTensor(
856          labels.indices,
857          array_ops.where(labels.values < blank_index, labels.values,
858                          labels.values - 1), labels.dense_shape)
859
860    return ctc_loss(
861        labels=labels,
862        inputs=logits,
863        sequence_length=logit_length,
864        time_major=logits_time_major)
865
866  if blank_index is None:
867    blank_index = 0
868
869  return ctc_loss_dense(
870      labels=labels,
871      logits=logits,
872      label_length=label_length,
873      logit_length=logit_length,
874      logits_time_major=logits_time_major,
875      unique=unique,
876      blank_index=blank_index,
877      name=name)
878
879
880@tf_export("nn.ctc_loss", v1=[])
881@dispatch.add_dispatch_support
882def ctc_loss_v3(labels,
883                logits,
884                label_length,
885                logit_length,
886                logits_time_major=True,
887                unique=None,
888                blank_index=None,
889                name=None):
890  """Computes CTC (Connectionist Temporal Classification) loss.
891
892  This op implements the CTC loss as presented in (Graves et al., 2006).
893
894  Notes:
895
896  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
897    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
898  - Labels may be supplied as either a dense, zero-padded tensor with a
899    vector of label sequence lengths OR as a SparseTensor.
900  - On TPU and GPU: Only dense padded labels are supported.
901  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
902    a SparseTensor will be significantly faster.
903  - Default blank label is 0 rather than num_classes - 1, unless overridden by
904    blank_index.
905
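  A minimal dense-label sketch (the shapes and lengths below are made up for
  illustration; with `blank_index=0`, true label ids are assumed to lie in
  `[1, num_labels)`):

  ```
  batch_size, max_time, num_labels, max_label_len = 2, 50, 10, 5
  logits = tf.random.normal([max_time, batch_size, num_labels])
  labels = tf.random.uniform([batch_size, max_label_len],
                             minval=1, maxval=num_labels, dtype=tf.int64)
  label_length = tf.constant([5, 4], dtype=tf.int32)
  logit_length = tf.constant([max_time, max_time], dtype=tf.int32)
  loss = tf.nn.ctc_loss(labels, logits, label_length, logit_length,
                        logits_time_major=True, blank_index=0)  # [batch_size]
  ```
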
906  Args:
907    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
908    logits: tensor of shape [frames, batch_size, num_labels], if
909      logits_time_major == False, shape is [batch_size, frames, num_labels].
910    label_length: tensor of shape [batch_size], None if labels is SparseTensor
911      Length of reference label sequence in labels.
912    logit_length: tensor of shape [batch_size] Length of input sequence in
913      logits.
914    logits_time_major: (optional) If True (default), logits is shaped [time,
915      batch, logits]. If False, shape is [batch, time, logits]
916    unique: (optional) Unique label indices as computed by
917      ctc_unique_labels(labels).  If supplied, enable a faster, memory efficient
918      implementation on TPU.
919    blank_index: (optional) Set the class index to use for the blank label.
920      Negative values will start from num_classes, i.e., -1 will reproduce the
921      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
922      some memory/performance overhead to switching from the default of 0 as an
923      additional shifted copy of the logits may be created.
924    name: A name for this `Op`. Defaults to "ctc_loss_dense".
925
926  Returns:
927    loss: tensor of shape [batch_size], negative log probabilities.
928
929  References:
930      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
931      with Recurrent Neural Networks:
932        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
933        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
934  """
935  if isinstance(labels, sparse_tensor.SparseTensor):
936    if blank_index is None:
937      raise ValueError(
938          "Argument `blank_index` must be provided when labels is a "
939          "SparseTensor.")
940
941    if blank_index < 0:
942      blank_index += _get_dim(logits, 2)
943
944    logits = ops.convert_to_tensor(logits, name="logits")
945
946    params = {
947        "labels": labels,
948        "logits": logits,
949        "logit_length": logit_length,
950        "logits_time_major": logits_time_major,
951        "blank_index": blank_index
952    }
953
954    if context.executing_eagerly():
955      device_type = _get_context_device_type()
956      can_use_gpu = (
957          # Either user specified GPU or unspecified but GPU is available.
958          (device_type == _GPU_DEVICE_NAME or
959           (device_type is None and context.num_gpus() > 0)))
960      # Under eager context, prefer the cuDNN implementation when a GPU is available.
961      if can_use_gpu:
962        res = _ctc_loss_op_cudnn(**params)
963      else:
964        res = _ctc_loss_op_standard(**params)
965    else:
966      api_name = "ctc_loss_" + str(uuid.uuid4())
967      ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
968                                                     _ctc_loss_op_standard)
969      ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
970                                                  _ctc_loss_op_cudnn)
971      res = ctc_loss_op_standard(**params)
972      function_eager.register(ctc_loss_op_cudnn, **params)
973    return res
974
975  if blank_index is None:
976    blank_index = 0
977
978  return ctc_loss_dense(
979      labels=labels,
980      logits=logits,
981      label_length=label_length,
982      logit_length=logit_length,
983      logits_time_major=logits_time_major,
984      unique=unique,
985      blank_index=blank_index,
986      name=name)
987
988
989def ctc_loss_dense(labels,
990                   logits,
991                   label_length,
992                   logit_length,
993                   logits_time_major=True,
994                   unique=None,
995                   blank_index=0,
996                   name=None):
997  """Computes CTC (Connectionist Temporal Classification) loss.
998
999  This op implements the CTC loss as presented in (Graves et al., 2006),
1000  using the batched forward backward algorithm described in (Sim et al., 2017).
1001
1002  Notes:
1003    Significant differences from tf.compat.v1.nn.ctc_loss:
1004      Supports GPU and TPU (tf.compat.v1.nn.ctc_loss supports CPU only):
1005        For batched operations, GPU and TPU are significantly faster than using
1006        ctc_loss on CPU.
1007        This implementation runs on CPU, but is significantly slower than ctc_loss.
1008      Blank label is 0 rather than num_classes - 1, unless overridden by blank_index.
1009      Logits and labels are dense arrays with padding rather than SparseTensor.
1010      The only mode supported is the same as:
1011        preprocess_collapse_repeated=False, ctc_merge_repeated=True
1012        To collapse labels, the caller can preprocess the label sequence first.
1013
1014    The dense implementation supports CPU, GPU and TPU. A fast path is
1015    provided that significantly improves memory use for large vocabulary if the
1016    caller preprocesses label sequences to get unique label indices on the CPU
1017    (e.g. in the data input pipeline) using ctc_ops.unique and supplies them via
1018    the optional "unique" kwarg. This is especially useful for TPU and GPU but
1019    also works if used on CPU.
1020
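  As a sketch of that fast path (reusing the hypothetical dense-label inputs
  `labels`, `logits`, `label_length` and `logit_length` from the
  `tf.nn.ctc_loss` docstring):

  ```
  unique = tf.nn.ctc_unique_labels(labels)
  loss = tf.nn.ctc_loss(labels, logits, label_length, logit_length,
                        blank_index=0, unique=unique)
  ```
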
1021  Args:
1022    labels: tensor of shape [batch_size, max_label_seq_length]
1023    logits: tensor of shape [frames, batch_size, num_labels], if
1024      logits_time_major == False, shape is [batch_size, frames, num_labels].
1025    label_length: tensor of shape [batch_size] Length of reference label
1026      sequence in labels.
1027    logit_length: tensor of shape [batch_size] Length of input sequence in
1028      logits.
1029    logits_time_major: (optional) If True (default), logits is shaped [time,
1030      batch, logits]. If False, shape is [batch, time, logits]
1031    unique: (optional) Unique label indices as computed by unique(labels). If
1032      supplied, enable a faster, memory efficient implementation on TPU.
1033    blank_index: (optional) Set the class index to use for the blank label.
1034      Negative values will start from num_classes, i.e., -1 will reproduce the
1035      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
1036      some memory/performance overhead to switching from the default of 0 as an
1037      additional shifted copy of the logits may be created.
1038    name: A name for this `Op`. Defaults to "ctc_loss_dense".
1039
1040  Returns:
1041    loss: tensor of shape [batch_size], negative log probabilities.
1042
1043  References:
1044      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
1045      with Recurrent Neural Networks:
1046        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
1047        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
1048      Improving the efficiency of forward-backward algorithm using batched
1049      computation in TensorFlow:
1050        [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944)
1051        ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf))
1052  """
1053
1054  with ops.name_scope(name, "ctc_loss_dense",
1055                      [logits, labels, label_length, logit_length]):
1056    logits = ops.convert_to_tensor(logits, name="logits")
1057    labels = ops.convert_to_tensor(labels, name="labels")
1058    label_length = ops.convert_to_tensor(label_length, name="label_length")
1059    logit_length = ops.convert_to_tensor(logit_length, name="logit_length")
1060
1061    orig_dtype = logits.dtype
1062    if orig_dtype in (dtypes.float16, dtypes.bfloat16):
1063      logits = math_ops.cast(logits, dtypes.float32)
1064
1065    if not logits_time_major:
1066      logits = array_ops.transpose(logits, perm=[1, 0, 2])
1067
1068    if blank_index != 0:
1069      if blank_index < 0:
1070        blank_index += _get_dim(logits, 2)
1071      logits = array_ops.concat([
1072          logits[:, :, blank_index:blank_index + 1],
1073          logits[:, :, :blank_index],
1074          logits[:, :, blank_index + 1:],
1075      ],
1076                                axis=2)
1077      labels = array_ops.where(labels < blank_index, labels + 1, labels)
1078
1079    args = [logits, labels, label_length, logit_length]
1080
1081    if unique:
1082      unique_y, unique_idx = unique
1083      if blank_index != 0:
1084        unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
1085                                   unique_y)
1086        label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
1087        max_label_length = _get_dim(unique_y, 1)
1088        label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
1089        unique_y = array_ops.where(label_mask, unique_y,
1090                                   array_ops.zeros_like(unique_y))
1091      args.extend([unique_y, unique_idx])
1092
1093    @custom_gradient.custom_gradient
1094    def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t,
1095                         *unique_t):
1096      """Compute CTC loss."""
1097      logits_t.set_shape(logits.shape)
1098      labels_t.set_shape(labels.shape)
1099      label_length_t.set_shape(label_length.shape)
1100      logit_length_t.set_shape(logit_length.shape)
1101      kwargs = dict(
1102          logits=logits_t,
1103          labels=labels_t,
1104          label_length=label_length_t,
1105          logit_length=logit_length_t)
1106      if unique_t:
1107        kwargs["unique"] = unique_t
1108      result = ctc_loss_and_grad(**kwargs)
1109      def grad(grad_loss):
1110        grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]]
1111        grad += [None] * (len(args) - len(grad))
1112        return grad
1113
1114      return result[0], grad
1115
1116    loss = compute_ctc_loss(*args)
1117    if orig_dtype in (dtypes.float16, dtypes.bfloat16):
1118      loss = math_ops.cast(loss, orig_dtype)
1119    return loss
1120
1121
1122@tf_export("nn.collapse_repeated")
1123@dispatch.add_dispatch_support
1124def collapse_repeated(labels, seq_length, name=None):
1125  """Merge repeated labels into single labels.
1126
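  For example (an illustrative sketch; `3` and `4` stand for arbitrary label
  ids):

  ```
  labels = tf.constant([[3, 3, 4, 4, 4]])
  seq_length = tf.constant([5])
  collapsed, new_len = tf.nn.collapse_repeated(labels, seq_length)
  # collapsed == [[3, 4]], new_len == [2]
  ```
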
1127  Args:
1128    labels: Tensor of shape [batch, max value in seq_length]
1129    seq_length: Tensor of shape [batch], sequence length of each batch element.
1130    name: A name for this `Op`. Defaults to "collapse_repeated_labels".
1131
1132  Returns:
1133    A tuple `(collapsed_labels, new_seq_length)` where
1134
1135    collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
1136    labels collapsed and padded to max_seq_length, eg:
1137    `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`
1138
1139    new_seq_length: int tensor of shape [batch] with new sequence lengths.
1140  """
1141
1142  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
1143    labels = ops.convert_to_tensor(labels, name="labels")
1144    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")
1145
1146    # Mask labels that don't equal previous label.
1147    label_mask = array_ops.concat([
1148        array_ops.ones_like(labels[:, :1], dtypes.bool),
1149        math_ops.not_equal(labels[:, 1:], labels[:, :-1])
1150    ],
1151                                  axis=1)
1152
1153    # Filter labels that aren't in the original sequence.
1154    maxlen = _get_dim(labels, 1)
1155    seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
1156    label_mask = math_ops.logical_and(label_mask, seq_mask)
1157
1158    # Count masks for new sequence lengths.
1159    new_seq_len = math_ops.reduce_sum(
1160        math_ops.cast(label_mask, dtypes.int32), axis=1)
1161
1162    # Mask indexes based on sequence length mask.
1163    new_maxlen = math_ops.reduce_max(new_seq_len)
1164    idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)
1165
1166    # Flatten everything and use masks to pick kept labels and scatter indices.
1167    flat_labels = array_ops.reshape(labels, [-1])
1168    flat_label_mask = array_ops.reshape(label_mask, [-1])
1169    flat_idx_mask = array_ops.reshape(idx_mask, [-1])
1170    idx = math_ops.range(_get_dim(flat_idx_mask, 0))
1171
1172    # Scatter to flat shape.
1173    flat = array_ops.scatter_nd(
1174        indices=array_ops.expand_dims(
1175            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
1176        updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
1177        shape=array_ops.shape(flat_idx_mask))
1178
1179    # Reshape back to square batch.
1180    batch_size = _get_dim(labels, 0)
1181    new_shape = [batch_size, new_maxlen]
1182    return (array_ops.reshape(flat, new_shape),
1183            math_ops.cast(new_seq_len, seq_length.dtype))
1184
1185
1186def dense_labels_to_sparse(dense, length):
1187  """Convert dense labels with sequence lengths to sparse tensor.
1188
1189  Args:
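  For example (illustrative values only):

  ```
  dense_labels_to_sparse(dense=[[1, 2, 0]], length=[2])
  # -> SparseTensor(indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
  ```
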
1190    dense: tensor of shape [batch, max_length]
1191    length: int tensor of shape [batch] The length of each sequence in dense.
1192
1193  Returns:
1194    tf.sparse.SparseTensor with values only for the valid elements of sequences.
1195  """
1196
1197  flat_values = array_ops.reshape(dense, [-1])
1198  flat_indices = math_ops.range(
1199      array_ops.shape(flat_values, out_type=dtypes.int64)[0])
1200  mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1])
1201  flat_mask = array_ops.reshape(mask, [-1])
1202  indices = array_ops.expand_dims(
1203      array_ops.boolean_mask(flat_indices, flat_mask), 1)
1204  values = array_ops.boolean_mask(flat_values, flat_mask)
1205  sparse = sparse_tensor.SparseTensor(
1206      indices=indices,
1207      values=math_ops.cast(values, dtypes.int32),
1208      dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64))
1209  reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense))
1210  max_length = math_ops.reduce_max(length)
1211  return sparse_tensor.SparseTensor(
1212      indices=reshaped.indices,
1213      values=reshaped.values,
1214      dense_shape=[
1215          math_ops.cast(reshaped.dense_shape[0], dtypes.int64),
1216          math_ops.cast(max_length, dtypes.int64)
1217      ])
1218
1219
1220@tf_export("nn.ctc_unique_labels")
1221@dispatch.add_dispatch_support
1222def ctc_unique_labels(labels, name=None):
1223  """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`.
1224
1225  For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be
1226  used to preprocess labels in the input pipeline for better speed/memory use
1227  when computing the ctc loss on TPU.
1228
1229  Example:
1230    ctc_unique_labels([[3, 4, 4, 3]]) ->
1231      unique labels padded with 0: [[3, 4, 0, 0]]
1232      indices of original labels in unique: [0, 1, 1, 0]
1233
1234  Args:
1235    labels: tensor of shape [batch_size, max_label_length] padded with 0.
1236    name: A name for this `Op`. Defaults to "ctc_unique_labels".
1237
1238  Returns:
1239    tuple of
1240      - unique labels, tensor of shape `[batch_size, max_label_length]`
1241      - indices into unique labels, shape `[batch_size, max_label_length]`
1242  """
1243
1244  with ops.name_scope(name, "ctc_unique_labels", [labels]):
1245    labels = ops.convert_to_tensor(labels, name="labels")
1246
1247    def _unique(x):
1248      u = array_ops.unique(x)
1249      y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]])
1250      y = math_ops.cast(y, dtypes.int64)
1251      return [y, u.idx]
1252
1253    return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32])
1254
1255
1256def _sum_states(idx, states):
1257  """Take logsumexp for each unique state out of all label states.
1258
1259  Args:
1260    idx: tensor of shape [batch, label_length] For each sequence, indices into a
1261      set of unique labels as computed by calling unique.
1262    states: tensor of shape [frames, batch, label_length] Log probabilities for
1263      each label state.
1264
1265  Returns:
1266    tensor of shape [frames, batch_size, label_length], log probabilities summed
1267      for each unique label of the sequence.
1268  """
1269
1270  with ops.name_scope("sum_states"):
1271    idx = ops.convert_to_tensor(idx, name="idx")
1272    num_states = _get_dim(states, 2)
1273    states = array_ops.expand_dims(states, axis=2)
1274    one_hot = array_ops.one_hot(
1275        idx,
1276        depth=num_states,
1277        on_value=0.0,
1278        off_value=math_ops.log(0.0),
1279        axis=1)
1280    return math_ops.reduce_logsumexp(states + one_hot, axis=-1)
1281
1282
1283def _forward_backward_log(state_trans_log_probs, initial_state_log_probs,
1284                          final_state_log_probs, observed_log_probs,
1285                          sequence_length):
1286  """Forward-backward algorithm computed in log domain.
1287
1288  Args:
1289    state_trans_log_probs: tensor of shape [states, states] or if different
1290      transition matrix per batch [batch_size, states, states]
1291    initial_state_log_probs: tensor of shape [batch_size, states]
1292    final_state_log_probs: tensor of shape [batch_size, states]
1293    observed_log_probs: tensor of shape [frames, batch_size, states]
1294    sequence_length: tensor of shape [batch_size]
1295
1296  Returns:
1297    forward backward log probabilities: tensor of shape [frames, batch, states]
1298    log_likelihood: tensor of shape [batch_size]
1299
1300  Raises:
1301    ValueError: If state_trans_log_probs has unknown or incorrect rank.
1302  """
1303
1304  if state_trans_log_probs.shape.ndims == 2:
1305    perm = [1, 0]
1306  elif state_trans_log_probs.shape.ndims == 3:
1307    perm = [0, 2, 1]
1308  else:
1309    raise ValueError(
1310        "Rank of argument `state_trans_log_probs` must be known and equal to "
1311        f"2 or 3. Received state_trans_log_probs={state_trans_log_probs} of "
1312        f"rank {state_trans_log_probs.shape.ndims}")
1313
1314  bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm)
1315  batch_size = _get_dim(observed_log_probs, 1)
1316
1317  def _forward(state_log_prob, obs_log_prob):
1318    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
1319    state_log_prob += state_trans_log_probs
1320    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
1321    state_log_prob += obs_log_prob
1322    log_prob_sum = math_ops.reduce_logsumexp(
1323        state_log_prob, axis=-1, keepdims=True)
1324    state_log_prob -= log_prob_sum
1325    return state_log_prob
1326
1327  fwd = _scan(
1328      _forward, observed_log_probs, initial_state_log_probs, inclusive=True)
1329
1330  def _backward(accs, elems):
1331    """Calculate log probs and cumulative sum masked for sequence length."""
1332    state_log_prob, cum_log_sum = accs
1333    obs_log_prob, mask = elems
1334    state_log_prob += obs_log_prob
1335    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
1336    state_log_prob += bwd_state_trans_log_probs
1337    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
1338
1339    log_prob_sum = math_ops.reduce_logsumexp(
1340        state_log_prob, axis=-1, keepdims=True)
1341    state_log_prob -= log_prob_sum
1342
1343    cum_log_sum += array_ops.squeeze(log_prob_sum, axis=[-1]) * mask
1344    batched_mask = array_ops.expand_dims(mask, axis=1)
1345    out = state_log_prob * batched_mask
1346    out += final_state_log_probs * (1.0 - batched_mask)
1347    return out, cum_log_sum
1348
1349  zero_log_sum = array_ops.zeros([batch_size])
1350  maxlen = _get_dim(observed_log_probs, 0)
1351  mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32)
1352  mask = array_ops.transpose(mask, perm=[1, 0])
1353
1354  bwd, cum_log_sum = _scan(
1355      _backward, (observed_log_probs, mask),
1356      (final_state_log_probs, zero_log_sum),
1357      reverse=True,
1358      inclusive=True)
1359
1360  fwd_bwd_log_probs = fwd[1:] + bwd[1:]
1361  fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp(
1362      fwd_bwd_log_probs, axis=2, keepdims=True)
1363  fwd_bwd_log_probs -= fwd_bwd_log_probs_sum
1364  fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2))
1365
1366  log_likelihood = bwd[0, :, 0] + cum_log_sum[0]
1367
1368  return fwd_bwd_log_probs, log_likelihood
1369
1370
1371# TODO(tombagby): This is currently faster for the ctc implementation than using
1372# functional_ops.scan, but could be replaced by that or something similar if
1373# things change.
1374def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
1375  """Repeatedly applies callable `fn` to a sequence of elements.
1376
1377  Implemented by functional_ops.While; TPU friendly; no gradient.
1378
1379  This is similar to functional_ops.scan but significantly faster on TPU/GPU
1380  for the forward-backward use case.
1381
1382  Examples:
1383    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]
1384
1385    Multiple accumulators:
1386      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))
1387
1388    Multiple inputs:
1389      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)
1390
1391  Args:
1392    fn: callable, fn(accumulators, element) return new accumulator values. The
1393      (possibly nested) sequence of accumulators is the same as `initial` and
1394      the return value must have the same structure.
1395    elems: A (possibly nested) tensor which will be unpacked along the first
1396      dimension. The resulting slices will be the second argument to fn. The
1397      first dimension of all nested input tensors must be the same.
1398    initial: A tensor or (possibly nested) sequence of tensors with initial
1399      values for the accumulators.
1400    reverse: (optional) If True, scan and output elems in reverse order.
1401    inclusive: (optional) True includes the initial accumulator values in the
1402      output. Length of output will be len(elem sequence) + 1. Not meaningful if
1403      final_only is True.
1404    final_only: (optional) When True, return only the final accumulated values,
1405      not the concatenation of accumulated values for each input.
1406
1407  Returns:
1408    A (possibly nested) sequence of tensors with the results of applying fn
1409    to tensors unpacked from elems and previous accumulator values.
1410  """
1411
1412  flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
1413  num_elems = array_ops.shape(flat_elems[0])[0]
1414  pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
1415  flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
1416  pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
1417  accum_dtypes = [x.dtype for x in flat_initial]
1418  num_accums = len(flat_initial)
1419
1420  # Types for counter, [outputs], [accumulators] loop arguments.
1421  if final_only:
1422    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
1423  else:
1424    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes
1425
1426  # TODO(tombagby): Update to tfe.defun
1427  def cond(i, num_elems, *args):
1428    del args
1429    return i >= 0 if reverse else i < num_elems
1430
1431  # The loop *args are [output tensors] + [accumulator tensors] which must
1432  # be paired. Each output corresponds to one accumulator.
1433  def body(i, num_elems, *args):
1434    """Loop body."""
1435    i.set_shape([])
1436    if final_only:
1437      accum = args
1438    else:
1439      out, accum = args[:num_accums], args[num_accums:]
1440    slices = [array_ops.gather(e, i) for e in flat_elems]
1441    accum = fn(pack(accum), pack_elems(slices))
1442    flat_accum = nest.flatten(accum)
1443    if final_only:
1444      new_out = []
1445    else:
1446      update_i = i + 1 if inclusive and not reverse else i
1447      new_out = [
1448          inplace_ops.alias_inplace_update(x, update_i, y)
1449          for x, y in zip(out, flat_accum)
1450      ]
1451    i = i - 1 if reverse else i + 1
1452    return [i, num_elems] + new_out + flat_accum
1453
1454  init_i = (
1455      array_ops.shape(flat_elems[0])[0] -
1456      1 if reverse else constant_op.constant(0, dtype=dtypes.int32))
1457  outputs = []
1458  if not final_only:
1459    num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
1460    for initial_accum in flat_initial:
1461      out_shape = array_ops.concat(
1462          [[num_outputs], array_ops.shape(initial_accum)], 0)
1463      out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
1464      if inclusive:
1465        out = inplace_ops.alias_inplace_add(out, init_i + (1 if reverse else 0),
1466                                            initial_accum)
1467      outputs.append(out)
1468  loop_in = [init_i, num_elems] + outputs + flat_initial
1469  hostmem = [
1470      i for i, x in enumerate(loop_in)
1471      if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
1472  ]
1473
1474  if context.executing_eagerly():
1475    loop_results = loop_in
1476    while cond(*loop_results):
1477      loop_results = body(*loop_results)
1478  else:
1479    # TODO(tombagby): Update to while_v2.
1480    cond = function.Defun(*loop_dtypes)(cond)
1481    body = function.Defun(*loop_dtypes)(body)
1482    loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
1483  out = loop_results[2:num_accums + 2]
1484  return pack(out)
1485
1486
1487def _get_dim(tensor, i):
1488  """Get value of tensor shape[i] preferring static value if available."""
1489  return tensor_shape.dimension_value(
1490      tensor.shape[i]) or array_ops.shape(tensor)[i]
1491