# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CTC (Connectionist Temporal Classification) Operations."""

import uuid

from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_eager

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops import inplace_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
_CPU_DEVICE_NAME = "CPU"
_GPU_DEVICE_NAME = "GPU"


def _get_context_device_type():
  """Parse the current context and return the device type, eg CPU/GPU."""
  current_device = context.context().device_name
  if current_device is None:
    return None
  return device.DeviceSpec.from_string(current_device).device_type


def _generate_defun_backend(unique_api_name, preferred_device, func):
  function_attributes = {
      _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
      _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
  }
  return function_eager.defun_with_attributes(
      func=func, attributes=function_attributes, autograph=False)


# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
@dispatch.add_dispatch_support
def ctc_loss(labels,
             inputs=None,
             sequence_length=None,
             preprocess_collapse_repeated=False,
             ctc_merge_repeated=True,
             ignore_longer_outputs_than_inputs=False,
             time_major=True,
             logits=None):
  """Computes the CTC (Connectionist Temporal Classification) Loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Input requirements:

  ```
  sequence_length(b) <= time for all b

  max(labels.indices(labels.indices[:, 1] == b, 2))
    <= sequence_length(b) for all b.
  ```

  Notes:

  This class performs the softmax operation for you, so inputs should
  be e.g. linear projections of outputs by an LSTM.

  The `inputs` Tensor's innermost dimension size, `num_classes`, represents
  `num_labels + 1` classes, where num_labels is the number of true labels, and
  the largest value `(num_classes - 1)` is reserved for the blank label.

  For example, for a vocabulary containing 3 labels `[a, b, c]`,
  `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.

  Regarding the arguments `preprocess_collapse_repeated` and
  `ctc_merge_repeated`:

  If `preprocess_collapse_repeated` is True, then a preprocessing step runs
  before loss calculation, wherein repeated labels passed to the loss
  are merged into single labels. This is useful if the training labels come
  from, e.g., forced alignments and therefore have unnecessary repetitions.

  If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
  repeated non-blank labels will not be merged and are interpreted
  as individual labels. This is a simplified (non-standard) version of CTC.

  Here is a table of the (roughly) expected first order behavior:

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`

    Classical CTC behavior: Outputs true repeated classes with blanks in
    between, and can also output repeated classes with no blanks in
    between that need to be collapsed by the decoder.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`

    Never learns to output repeated classes, as they are collapsed
    in the input labels before training.

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`

    Outputs repeated classes with blanks in between, but generally does not
    require the decoder to collapse/merge repeated classes.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`

    Untested. Very likely will not learn to output repeated classes.

  The `ignore_longer_outputs_than_inputs` option allows you to specify the
  behavior of the CTCLoss when dealing with sequences that have longer outputs
  than inputs. If true, the CTCLoss will simply return zero gradient for those
  items, otherwise an InvalidArgument error is returned, stopping training.

  Args:
    labels: An `int32` `SparseTensor`.
      `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id
      for (batch b, time t). `labels.values[i]` must take on values in `[0,
      num_labels)`. See `core/ops/ctc_ops.cc` for more details.
    inputs: 3-D `float` `Tensor`.
      If time_major == False, this will be a `Tensor` shaped: `[batch_size,
      max_time, num_classes]`.
      If time_major == True (default), this will be a `Tensor` shaped:
      `[max_time, batch_size, num_classes]`. The logits.
    sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence
      lengths.
    preprocess_collapse_repeated: Boolean. Default: False. If True, repeated
      labels are collapsed prior to the CTC calculation.
    ctc_merge_repeated: Boolean. Default: True.
    ignore_longer_outputs_than_inputs: Boolean. Default: False. If True,
      sequences with longer outputs than inputs will be ignored.
    time_major: The shape format of the `inputs` Tensors. If True, these
      `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False,
      these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
      Using `time_major = True` (default) is a bit more efficient because it
      avoids transposes at the beginning of the ctc_loss calculation.
      However, most TensorFlow data is batch-major, so this function also
      accepts inputs in batch-major form.
    logits: Alias for inputs.

  Returns:
    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
    probabilities.

  Raises:
    TypeError: if labels is not a `SparseTensor`.

  References:
    Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
    with Recurrent Neural Networks:
      [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
      ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  return _ctc_loss_impl(
      labels,
      inputs,
      sequence_length,
      preprocess_collapse_repeated,
      ctc_merge_repeated,
      ignore_longer_outputs_than_inputs,
      time_major,
      logits,
      use_cudnn=False)


def _ctc_loss_impl(labels,
                   inputs=None,
                   sequence_length=None,
                   preprocess_collapse_repeated=False,
                   ctc_merge_repeated=True,
                   ignore_longer_outputs_than_inputs=False,
                   time_major=True,
                   logits=None,
                   use_cudnn=False):
  # Helper function of ctc_loss with one additional param:
  # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
  #   index has to be 0.

  # The second, third, etc output tensors contain the gradients. We use them in
  # _CTCLossGrad() below.
  if not isinstance(labels, sparse_tensor.SparseTensor):
    raise TypeError("Expected argument `labels` to be a SparseTensor. "
                    f"Received labels={labels} of type: "
                    f"{type(labels).__name__}")

  # For internal calculations, we transpose to [time, batch, num_classes]
  inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
                                                  inputs)

  inputs = ops.convert_to_tensor(inputs, name="logits")
  if not time_major:
    inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)

  orig_dtype = inputs.dtype
  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
    inputs = math_ops.cast(inputs, dtypes.float32)

  # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
  # blank index to be 0, but v1 views it as the last index.
  if use_cudnn:
    ctc_loss_func = gen_ctc_ops.ctc_loss_v2
  else:
    ctc_loss_func = gen_ctc_ops.ctc_loss

  loss, _ = ctc_loss_func(
      inputs,
      labels.indices,
      labels.values,
      sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated,
      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)

  if orig_dtype in (dtypes.float16, dtypes.bfloat16):
    loss = math_ops.cast(loss, orig_dtype)

  return loss


# pylint: disable=unused-argument
def _CTCLossGradImpl(op, grad_loss, _):
  # Outputs are: loss, grad
  #
  # Currently there is no way to take the second derivative of this op
  # due to the fused implementation's interaction with tf.gradients(),
  # so we make sure we prevent silently incorrect results by raising
  # an error if the second derivative is requested via prevent_gradient.
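  #
  # op.outputs[1] holds the gradient of the loss w.r.t. the logits, already
  # computed by the fused CTCLoss kernel in the forward pass; all that remains
  # here is to scale it by the incoming per-example grad_loss via
  # _BroadcastMul.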
  grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1],
      message="Currently there is no way to take the second "
      " derivative of ctc_loss due to the fused implementation's interaction "
      " with tf.gradients()")
  # Return gradient for inputs and None for
  # labels_indices, labels_values and sequence_length
  return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]


# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
  """The derivative provided by CTC Loss.

  Args:
    op: the CTCLoss op.
    grad_loss: The backprop for cost.

  Returns:
    The CTC Loss gradient.
  """
  return _CTCLossGradImpl(op, grad_loss, _)


# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
  """The derivative provided by CTC Loss V2.

  Args:
    op: the CTCLossV2 op.
    grad_loss: The backprop for cost.

  Returns:
    The CTC Loss V2 gradient.
  """
  return _CTCLossGradImpl(op, grad_loss, _)


@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs,
                       sequence_length,
                       merge_repeated=True,
                       blank_index=None):
  """Performs greedy decoding on the logits given in input (best path).

  Given a tensor as `inputs`, the `blank_index` parameter defines the class
  index of the blank symbol.

  For example:

  If `blank_index` is equal to 1:

  >>> inf = float("inf")
  >>> logits = tf.constant([[[   0., -inf, -inf],
  ...                        [ -2.3, -inf, -0.1]],
  ...                       [[ -inf, -0.5, -inf],
  ...                        [ -inf, -inf, -0.1]],
  ...                       [[ -inf, -inf, -inf],
  ...                        [ -0.1, -inf, -2.3]]])
  >>> seq_lens = tf.constant([2, 3])
  >>> outputs = tf.nn.ctc_greedy_decoder(
  ...     logits,
  ...     seq_lens,
  ...     blank_index=1)

  Notes:

  - Unlike `ctc_beam_search_decoder`, `ctc_greedy_decoder` considers blanks
    as regular elements when computing the probability of a sequence.
  - Default `blank_index` is `(num_classes - 1)`, unless overridden.

  If `merge_repeated` is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted. The sequence `A B B * B * B` (where '*'
  is the blank label) becomes

    * `A B B B` if `merge_repeated=True`.
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having
      size `[batch_size]`.
    merge_repeated: Boolean. Default: True.
    blank_index: (Optional). Default: `num_classes - 1`. Define the class index
      to use for the blank label. Negative values will start from num_classes,
      ie, -1 will reproduce the ctc_greedy_decoder behavior of using
      num_classes - 1 for the blank symbol, which corresponds to the default.

  Returns:
    A tuple `(decoded, neg_sum_logits)` where

    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:

      `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
        The rows store: `[batch, time]`.

      `decoded.values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.

      `decoded.dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`

    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
      sequence found, the negative of the sum of the greatest logit at each
      timeframe.
  """

  outputs = gen_ctc_ops.ctc_greedy_decoder(
      inputs,
      sequence_length,
      merge_repeated=merge_repeated,
      blank_index=blank_index)
  (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
  return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val,
                                      decoded_shape)], log_probabilities)


@tf_export(v1=["nn.ctc_beam_search_decoder"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder(inputs,
                            sequence_length,
                            beam_width=100,
                            top_paths=1,
                            merge_repeated=True):
  """Performs beam search decoding on the logits given in input.

  **Note** Although in general greedy search is a special case of beam-search
  with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs
  from `ctc_greedy_decoder` in the treatment of blanks when computing the
  probability of a sequence:
    - `ctc_beam_search_decoder` treats blanks as sequence termination
    - `ctc_greedy_decoder` treats blanks as regular elements

  If `merge_repeated` is `True`, merge repeated classes in the output beams.
  This means that if consecutive entries in a beam are the same,
  only the first of these is emitted. That is, when the sequence is
  `A B B * B * B` (where '*' is the blank label), the return value is:

    * `A B` if `merge_repeated = True`.
    * `A B B B` if `merge_repeated = False`.

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having
      size `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).
    merge_repeated: Boolean. Default: True.

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
        The rows store: [batch, time].

      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
        The vector stores the decoded classes for beam j.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probability: A `float` matrix `(batch_size x top_paths)` containing
      sequence log-probabilities.
  """

  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
      gen_ctc_ops.ctc_beam_search_decoder(
          inputs,
          sequence_length,
          beam_width=beam_width,
          top_paths=top_paths,
          merge_repeated=merge_repeated))

  return ([
      sparse_tensor.SparseTensor(ix, val, shape)
      for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes)
  ], log_probabilities)


@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder_v2(inputs,
                               sequence_length,
                               beam_width=100,
                               top_paths=1):
  """Performs beam search decoding on the logits given in input.

  **Note** Although in general greedy search is a special case of beam-search
  with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs
  from `ctc_greedy_decoder` in the treatment of blanks when computing the
  probability of a sequence:
    - `ctc_beam_search_decoder` treats blanks as sequence termination
    - `ctc_greedy_decoder` treats blanks as regular elements

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having
      size `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`;
        The rows store: `[batch, time]`.

      `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`.
        The vector stores the decoded classes for beam `j`.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probability: A `float` matrix `[batch_size, top_paths]` containing
      sequence log-probabilities.
  """

  # Note, merge_repeated is an invalid optimization that is removed from the
  # public API: it returns low probability paths.
  return ctc_beam_search_decoder(
      inputs,
      sequence_length=sequence_length,
      beam_width=beam_width,
      top_paths=top_paths,
      merge_repeated=False)


ops.NotDifferentiable("CTCGreedyDecoder")
ops.NotDifferentiable("CTCBeamSearchDecoder")


def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    tensor of shape [batch_size, states, states] with a state transition matrix
    computed for each sequence of the batch.
  """

  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    num_label_states = num_labels + 1
    num_states = 2 * num_label_states

    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat(
        [start_to_label, blank_to_label, label_to_blank], 0)
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
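    #
    # State layout used here: index 0 is the start state, indices
    # 1..num_labels are label states (label state k reads label_seq[:, k - 1]),
    # and indices num_labels + 1 .. 2 * num_labels + 1 are the blank states
    # following each of those (transition matrices are indexed
    # [to_state, from_state]). The scatter below adds a transition from label
    # state k to label state k + 1 only where the two consecutive labels
    # differ, so repeated labels have to pass through a blank state.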
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack(
        [batch_idx, label_states[2:], label_states[1:-1]], 1)
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)
    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    return array_ops.expand_dims(trans, 0) + label_to_label


def ctc_state_log_probs(seq_lengths, max_seq_length):
  """Computes CTC alignment initial and final state log probabilities.

  Create the initial/final state values directly as log values to avoid
  having to take a float64 log on TPU (where it does not exist).

  Args:
    seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
    max_seq_length: int, max sequence length possible.

  Returns:
    initial_state_log_probs, final_state_log_probs
  """

  batch_size = _get_dim(seq_lengths, 0)
  num_label_states = max_seq_length + 1
  num_duration_states = 2
  num_states = num_duration_states * num_label_states
  log_0 = math_ops.cast(
      math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32)

  initial_state_log_probs = array_ops.one_hot(
      indices=array_ops.zeros([batch_size], dtype=dtypes.int32),
      depth=num_states,
      on_value=0.0,
      off_value=log_0,
      axis=1)

  label_final_state_mask = array_ops.one_hot(
      seq_lengths, depth=num_label_states, axis=0)
  duration_final_state_mask = array_ops.ones(
      [num_duration_states, 1, batch_size])
  final_state_mask = duration_final_state_mask * label_final_state_mask
  final_state_log_probs = (1.0 - final_state_mask) * log_0
  final_state_log_probs = array_ops.reshape(final_state_log_probs,
                                            [num_states, batch_size])

  return initial_state_log_probs, array_ops.transpose(final_state_log_probs)


def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
  """Project ilabel log probs to state log probs."""

  num_label_states = _get_dim(labels, 1)
  blank = ilabel_log_probs[:, :, :1]
  blank = array_ops.tile(blank, [1, 1, num_label_states + 1])
  one_hot = array_ops.one_hot(labels, depth=num_labels)
  one_hot = array_ops.expand_dims(one_hot, axis=0)
  ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2)
  state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3)
  state_log_probs = array_ops.concat([state_log_probs, blank], axis=2)
  return array_ops.pad(
      state_log_probs, [[0, 0], [0, 0], [1, 0]],
      constant_values=math_ops.log(0.0))


def _state_to_olabel(labels, num_labels, states):
  """Sum state log probs to olabel log probs."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]
  one_hot = array_ops.one_hot(
      labels - 1,
      depth=(num_labels - 1),
      on_value=0.0,
      off_value=math_ops.log(0.0))
  one_hot = array_ops.expand_dims(one_hot, axis=0)
  label_states = array_ops.expand_dims(label_states, axis=3)
  label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2)
  blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)
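  # Layout of the result along the last axis: index 0 is the blank output
  # label, indices 1..num_labels - 1 are the true labels (the dense CTC path
  # assumes blank_index == 0 at this point).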
  return array_ops.concat([blank_olabels, label_olabels], axis=-1)


# pylint: disable=redefined-outer-name
def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sum state log probs to olabel log probs using unique label indices."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]

  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  num_frames = _get_dim(states, 0)
  batch_size = _get_dim(states, 1)
  num_states = num_label_states - 1
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(batch_state_major,
                                        [batch_size * num_states, num_frames])
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])

  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
  mask = array_ops.scatter_nd(
      indices=indices,
      updates=mask,
      shape=[batch_size * num_labels, num_frames])
  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])

  scatter = array_ops.where(
      mask, scatter,
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))

  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]

  blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)


def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
  """Computes the CTC loss and gradients.

  Most users will want fwd_bwd.ctc_loss

  This function returns the computed gradient; it does not have a gradient
  of its own defined.

  Args:
    logits: tensor of shape [frames, batch_size, num_labels]
    labels: tensor of shape [batch_size, max_label_seq_length]
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    unique: (optional) unique label indices as computed by unique(labels). If
      supplied, enables an implementation that is faster and more memory
      efficient on TPU.

  Returns:
    loss: tensor of shape [batch_size]
    gradient: tensor of shape [frames, batch_size, num_labels]
  """

  num_labels = _get_dim(logits, 2)
  max_label_seq_length = _get_dim(labels, 1)

  ilabel_log_probs = nn_ops.log_softmax(logits)
  state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
  state_trans_probs = _ctc_state_trans(labels)
  initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
      label_length, max_label_seq_length)
  fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
      state_trans_log_probs=math_ops.log(state_trans_probs),
      initial_state_log_probs=initial_state_log_probs,
      final_state_log_probs=final_state_log_probs,
      observed_log_probs=state_log_probs,
      sequence_length=logit_length)

  if unique:
    olabel_log_probs = _state_to_olabel_unique(labels, num_labels,
                                               fwd_bwd_log_probs, unique)
  else:
    olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)

  grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)

  # Applies the sequence mask for the gradient. It is enough to apply the mask
  # only for ilabel_log_probs because olabel_log_probs already accounts for the
  # mask. However, it is safer and cleaner to apply it to the whole gradient.
  max_logit_length = _get_dim(logits, 0)
  logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
                                       dtypes.float32)
  logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
  logit_mask = array_ops.expand_dims(logit_mask, axis=2)
  grad *= logit_mask

  loss = -log_likelihood
  return loss, grad


def _ctc_loss_grad(op, grad_loss, _):
  grad = op.outputs[1]
  grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad]
  grad += [None] * (len(op.inputs) - len(grad))
  return grad


def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
                          blank_index):
  part_before = logits[:, :, :blank_index]
  part_after = logits[:, :, blank_index + 1:]
  part_blank = logits[:, :, blank_index:blank_index + 1]
  logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
  labels = sparse_tensor.SparseTensor(
      labels.indices,
      array_ops.where(labels.values < blank_index, labels.values,
                      labels.values - 1), labels.dense_shape)
  return _ctc_loss_impl(
      labels=labels,
      inputs=logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=False)


def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
                       blank_index):
  part_before = logits[:, :, :blank_index]
  part_after = logits[:, :, blank_index + 1:]
  part_blank = logits[:, :, blank_index:blank_index + 1]
  logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
  labels = sparse_tensor.SparseTensor(
      labels.indices,
      array_ops.where(labels.values < blank_index, labels.values + 1,
                      labels.values), labels.dense_shape)
  return _ctc_loss_impl(
      labels=labels,
      inputs=logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=True)


def _ctc_loss_shape(op):
  return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]


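# The two helpers above rearrange the logits so the blank class sits where each
# kernel expects it: the standard CTCLoss kernel wants the blank as the last
# class, while the cuDNN kernel wants it at index 0; label values are shifted
# to match. A rough sketch of the column permutation, assuming a hypothetical
# blank_index of 2 with num_classes = 5:
#
#   standard: [c0, c1, blank, c3, c4] -> [c0, c1, c3, c4, blank]
#   cudnn:    [c0, c1, blank, c3, c4] -> [blank, c0, c1, c3, c4]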


# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss_v2"])
@dispatch.add_dispatch_support
def ctc_loss_v2(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor.
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits].
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels). If supplied, enables a faster, memory-efficient
      implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, ie, -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
    Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
    with Recurrent Neural Networks:
      [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
      ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "Argument `blank_index` must be provided when labels is a "
          "SparseTensor.")

    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    if blank_index != _get_dim(logits, 2) - 1:
      logits = array_ops.concat([
          logits[:, :, :blank_index],
          logits[:, :, blank_index + 1:],
          logits[:, :, blank_index:blank_index + 1],
      ],
                                axis=2)
      labels = sparse_tensor.SparseTensor(
          labels.indices,
          array_ops.where(labels.values < blank_index, labels.values,
                          labels.values - 1), labels.dense_shape)

    return ctc_loss(
        labels=labels,
        inputs=logits,
        sequence_length=logit_length,
        time_major=logits_time_major)

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      unique=unique,
      blank_index=blank_index,
      name=name)


@tf_export("nn.ctc_loss", v1=[])
@dispatch.add_dispatch_support
def ctc_loss_v3(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor.
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits].
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels). If supplied, enables a faster, memory-efficient
      implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, ie, -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
    Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
    with Recurrent Neural Networks:
      [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
      ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "Argument `blank_index` must be provided when labels is a "
          "SparseTensor.")

    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    logits = ops.convert_to_tensor(logits, name="logits")

    params = {
        "labels": labels,
        "logits": logits,
        "logit_length": logit_length,
        "logits_time_major": logits_time_major,
        "blank_index": blank_index
    }

    if context.executing_eagerly():
      device_type = _get_context_device_type()
      can_use_gpu = (
          # Either user specified GPU or unspecified but GPU is available.
          (device_type == _GPU_DEVICE_NAME or
           (device_type is None and context.num_gpus() > 0)))
      # Under eager context, check the device placement and prefer the cuDNN
      # implementation when a GPU can be used.
      if can_use_gpu:
        res = _ctc_loss_op_cudnn(**params)
      else:
        res = _ctc_loss_op_standard(**params)
    else:
      api_name = "ctc_loss_" + str(uuid.uuid4())
      ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
                                                     _ctc_loss_op_standard)
      ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
                                                  _ctc_loss_op_cudnn)
      res = ctc_loss_op_standard(**params)
      function_eager.register(ctc_loss_op_cudnn, **params)
    return res

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      unique=unique,
      blank_index=blank_index,
      name=name)


def ctc_loss_dense(labels,
                   logits,
                   label_length,
                   logit_length,
                   logits_time_major=True,
                   unique=None,
                   blank_index=0,
                   name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006),
  using the batched forward backward algorithm described in (Sim et al., 2017).

  Notes:
    Significant differences from tf.compat.v1.nn.ctc_loss:
      Supports GPU and TPU (tf.compat.v1.nn.ctc_loss supports CPU only):
        For batched operations, GPU and TPU are significantly faster than using
        ctc_loss on CPU.
        This implementation also runs on CPU, but significantly slower than
        ctc_loss.
      Blank label is 0 rather than num_classes - 1, unless overridden by
        blank_index.
      Logits and labels are dense arrays with padding rather than SparseTensor.
      The only mode supported is the same as:
        preprocess_collapse_repeated=False, ctc_merge_repeated=True
        To collapse labels, the caller can preprocess label sequence first.

  The dense implementation supports CPU, GPU and TPU. A fast path is provided
  that significantly improves memory use for large vocabulary if the caller
  preprocesses label sequences to get unique label indices on the CPU (eg. in
  the data input pipeline) using ctc_ops.unique and supplies it in the optional
  "unique" kwarg. This is especially useful for TPU and GPU but also works when
  used on CPU.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length]
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits].
    unique: (optional) Unique label indices as computed by unique(labels). If
      supplied, enables a faster, memory efficient implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, ie, -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
    Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
    with Recurrent Neural Networks:
      [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
      ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
    Improving the efficiency of forward-backward algorithm using batched
    computation in TensorFlow:
      [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944)
      ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf))
  """

  with ops.name_scope(name, "ctc_loss_dense",
                      [logits, labels, label_length, logit_length]):
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    label_length = ops.convert_to_tensor(label_length, name="label_length")
    logit_length = ops.convert_to_tensor(logit_length, name="logit_length")

    orig_dtype = logits.dtype
    if orig_dtype in (dtypes.float16, dtypes.bfloat16):
      logits = math_ops.cast(logits, dtypes.float32)

    if not logits_time_major:
      logits = array_ops.transpose(logits, perm=[1, 0, 2])

    if blank_index != 0:
      if blank_index < 0:
        blank_index += _get_dim(logits, 2)
      logits = array_ops.concat([
          logits[:, :, blank_index:blank_index + 1],
          logits[:, :, :blank_index],
          logits[:, :, blank_index + 1:],
      ],
                                axis=2)
      labels = array_ops.where(labels < blank_index, labels + 1, labels)

    args = [logits, labels, label_length, logit_length]

    if unique:
      unique_y, unique_idx = unique
      if blank_index != 0:
        unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
                                   unique_y)
      label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
      max_label_length = _get_dim(unique_y, 1)
      label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
      unique_y = array_ops.where(label_mask, unique_y,
                                 array_ops.zeros_like(unique_y))
      args.extend([unique_y, unique_idx])

    @custom_gradient.custom_gradient
    def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t,
                         *unique_t):
      """Compute CTC loss."""
      logits_t.set_shape(logits.shape)
      labels_t.set_shape(labels.shape)
      label_length_t.set_shape(label_length.shape)
      logit_length_t.set_shape(logit_length.shape)
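      # Re-assert the statically known shapes on the arguments; tracing through
      # custom_gradient can drop them, and the shape-dependent helpers below
      # (e.g. _get_dim) prefer static dimensions when available.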
      kwargs = dict(
          logits=logits_t,
          labels=labels_t,
          label_length=label_length_t,
          logit_length=logit_length_t)
      if unique_t:
        kwargs["unique"] = unique_t
      result = ctc_loss_and_grad(**kwargs)

      def grad(grad_loss):
        grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]]
        grad += [None] * (len(args) - len(grad))
        return grad

      return result[0], grad

    loss = compute_ctc_loss(*args)
    if orig_dtype in (dtypes.float16, dtypes.bfloat16):
      loss = math_ops.cast(loss, orig_dtype)
    return loss


@tf_export("nn.collapse_repeated")
@dispatch.add_dispatch_support
def collapse_repeated(labels, seq_length, name=None):
  """Merge repeated labels into single labels.

  Args:
    labels: Tensor of shape [batch, max value in seq_length]
    seq_length: Tensor of shape [batch], sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    A tuple `(collapsed_labels, new_seq_length)` where

    collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
      labels collapsed and padded to max_seq_length, eg:
      `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`

    new_seq_length: int tensor of shape [batch] with new sequence lengths.
  """

  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
    labels = ops.convert_to_tensor(labels, name="labels")
    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

    # Mask labels that don't equal previous label.
    label_mask = array_ops.concat([
        array_ops.ones_like(labels[:, :1], dtypes.bool),
        math_ops.not_equal(labels[:, 1:], labels[:, :-1])
    ],
                                  axis=1)

    # Filter labels that aren't in the original sequence.
    maxlen = _get_dim(labels, 1)
    seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
    label_mask = math_ops.logical_and(label_mask, seq_mask)

    # Count masks for new sequence lengths.
    new_seq_len = math_ops.reduce_sum(
        math_ops.cast(label_mask, dtypes.int32), axis=1)

    # Mask indexes based on sequence length mask.
    new_maxlen = math_ops.reduce_max(new_seq_len)
    idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)

    # Flatten everything and mask out labels to keep and sparse indices.
    flat_labels = array_ops.reshape(labels, [-1])
    flat_label_mask = array_ops.reshape(label_mask, [-1])
    flat_idx_mask = array_ops.reshape(idx_mask, [-1])
    idx = math_ops.range(_get_dim(flat_idx_mask, 0))

    # Scatter to flat shape.
    flat = array_ops.scatter_nd(
        indices=array_ops.expand_dims(
            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
        updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
        shape=array_ops.shape(flat_idx_mask))

    # Reshape back to square batch.
    batch_size = _get_dim(labels, 0)
    new_shape = [batch_size, new_maxlen]
    return (array_ops.reshape(flat, new_shape),
            math_ops.cast(new_seq_len, seq_length.dtype))


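# A minimal usage sketch for `collapse_repeated` (the tensor values here are
# illustrative only, mirroring the docstring example above):
#
#   labels = tf.constant([[1, 1, 2, 2, 1],
#                         [1, 2, 3, 4, 5]])
#   seq_length = tf.constant([5, 5])
#   collapsed, new_len = tf.nn.collapse_repeated(labels, seq_length)
#   # collapsed -> [[1, 2, 1, 0, 0], [1, 2, 3, 4, 5]], new_len -> [3, 5]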


def dense_labels_to_sparse(dense, length):
  """Convert dense labels with sequence lengths to sparse tensor.

  Args:
    dense: tensor of shape [batch, max_length]
    length: int tensor of shape [batch] The length of each sequence in dense.

  Returns:
    tf.sparse.SparseTensor with values only for the valid elements of
    sequences.
  """

  flat_values = array_ops.reshape(dense, [-1])
  flat_indices = math_ops.range(
      array_ops.shape(flat_values, out_type=dtypes.int64)[0])
  mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1])
  flat_mask = array_ops.reshape(mask, [-1])
  indices = array_ops.expand_dims(
      array_ops.boolean_mask(flat_indices, flat_mask), 1)
  values = array_ops.boolean_mask(flat_values, flat_mask)
  sparse = sparse_tensor.SparseTensor(
      indices=indices,
      values=math_ops.cast(values, dtypes.int32),
      dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64))
  reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense))
  max_length = math_ops.reduce_max(length)
  return sparse_tensor.SparseTensor(
      indices=reshaped.indices,
      values=reshaped.values,
      dense_shape=[
          math_ops.cast(reshaped.dense_shape[0], dtypes.int64),
          math_ops.cast(max_length, dtypes.int64)
      ])


@tf_export("nn.ctc_unique_labels")
@dispatch.add_dispatch_support
def ctc_unique_labels(labels, name=None):
  """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`.

  For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be
  used to preprocess labels in the input pipeline for better speed/memory use
  when computing the ctc loss on TPU.

  Example:
    ctc_unique_labels([[3, 4, 4, 3]]) ->
      unique labels padded with 0: [[3, 4, 0, 0]]
      indices of original labels in unique: [0, 1, 1, 0]

  Args:
    labels: tensor of shape [batch_size, max_label_length] padded with 0.
    name: A name for this `Op`. Defaults to "ctc_unique_labels".

  Returns:
    tuple of
      - unique labels, tensor of shape `[batch_size, max_label_length]`
      - indices into unique labels, shape `[batch_size, max_label_length]`
  """

  with ops.name_scope(name, "ctc_unique_labels", [labels]):
    labels = ops.convert_to_tensor(labels, name="labels")

    def _unique(x):
      u = array_ops.unique(x)
      y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]])
      y = math_ops.cast(y, dtypes.int64)
      return [y, u.idx]

    return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32])


def _sum_states(idx, states):
  """Take logsumexp for each unique state out of all label states.

  Args:
    idx: tensor of shape [batch, label_length] For each sequence, indices into
      a set of unique labels as computed by calling unique.
    states: tensor of shape [frames, batch, label_length] Log probabilities for
      each label state.

  Returns:
    tensor of shape [frames, batch_size, label_length], log probabilities
      summed for each unique label of the sequence.
  """

  with ops.name_scope("sum_states"):
    idx = ops.convert_to_tensor(idx, name="idx")
    num_states = _get_dim(states, 2)
    states = array_ops.expand_dims(states, axis=2)
    one_hot = array_ops.one_hot(
        idx,
        depth=num_states,
        on_value=0.0,
        off_value=math_ops.log(0.0),
        axis=1)
    return math_ops.reduce_logsumexp(states + one_hot, axis=-1)


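# For reference, the per-frame forward update implemented inside
# `_forward_backward_log` below is, in log domain (with A the [to, from] state
# transition matrix and p_t the observation log probs at frame t):
#
#   alpha_t(s) = logsumexp_{s'}(alpha_{t-1}(s') + log A[s, s']) + log p_t(s)
#
# Each frame is renormalized (its logsumexp over states subtracted) for
# numerical stability; the backward pass accumulates its normalizers in
# `cum_log_sum` so the total sequence log likelihood can be recovered.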


def _forward_backward_log(state_trans_log_probs, initial_state_log_probs,
                          final_state_log_probs, observed_log_probs,
                          sequence_length):
  """Forward-backward algorithm computed in log domain.

  Args:
    state_trans_log_probs: tensor of shape [states, states] or if different
      transition matrix per batch [batch_size, states, states]
    initial_state_log_probs: tensor of shape [batch_size, states]
    final_state_log_probs: tensor of shape [batch_size, states]
    observed_log_probs: tensor of shape [frames, batch_size, states]
    sequence_length: tensor of shape [batch_size]

  Returns:
    forward backward log probabilities: tensor of shape [frames, batch, states]
    log_likelihood: tensor of shape [batch_size]

  Raises:
    ValueError: If state_trans_log_probs has unknown or incorrect rank.
  """

  if state_trans_log_probs.shape.ndims == 2:
    perm = [1, 0]
  elif state_trans_log_probs.shape.ndims == 3:
    perm = [0, 2, 1]
  else:
    raise ValueError(
        "Rank of argument `state_trans_log_probs` must be known and equal to "
        f"2 or 3. Received state_trans_log_probs={state_trans_log_probs} of "
        f"rank {state_trans_log_probs.shape.ndims}")

  bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm)
  batch_size = _get_dim(observed_log_probs, 1)

  def _forward(state_log_prob, obs_log_prob):
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
    state_log_prob += obs_log_prob
    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum
    return state_log_prob

  fwd = _scan(
      _forward, observed_log_probs, initial_state_log_probs, inclusive=True)

  def _backward(accs, elems):
    """Calculate log probs and cumulative sum masked for sequence length."""
    state_log_prob, cum_log_sum = accs
    obs_log_prob, mask = elems
    state_log_prob += obs_log_prob
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += bwd_state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)

    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum

    cum_log_sum += array_ops.squeeze(log_prob_sum, axis=[-1]) * mask
    batched_mask = array_ops.expand_dims(mask, axis=1)
    out = state_log_prob * batched_mask
    out += final_state_log_probs * (1.0 - batched_mask)
    return out, cum_log_sum

  zero_log_sum = array_ops.zeros([batch_size])
  maxlen = _get_dim(observed_log_probs, 0)
  mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32)
  mask = array_ops.transpose(mask, perm=[1, 0])

  bwd, cum_log_sum = _scan(
      _backward, (observed_log_probs, mask),
      (final_state_log_probs, zero_log_sum),
      reverse=True,
      inclusive=True)

  fwd_bwd_log_probs = fwd[1:] + bwd[1:]
  fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp(
      fwd_bwd_log_probs, axis=2, keepdims=True)
  fwd_bwd_log_probs -= fwd_bwd_log_probs_sum
  fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2))

  log_likelihood = bwd[0, :, 0] + cum_log_sum[0]

  return fwd_bwd_log_probs, log_likelihood


# TODO(tombagby): This is currently faster for the ctc implementation than
# using functional_ops.scan, but could be replaced by that or something
# similar if things change.
def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
  """Repeatedly applies callable `fn` to a sequence of elements.

  Implemented by functional_ops.While, tpu friendly, no gradient.

  This is similar to functional_ops.scan but significantly faster on tpu/gpu
  for the forward backward use case.

  Examples:
    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]

    Multiple accumulators:
      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))

    Multiple inputs:
      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)

  Args:
    fn: callable, fn(accumulators, element) returns new accumulator values. The
      (possibly nested) sequence of accumulators is the same as `initial` and
      the return value must have the same structure.
    elems: A (possibly nested) tensor which will be unpacked along the first
      dimension. The resulting slices will be the second argument to fn. The
      first dimension of all nested input tensors must be the same.
    initial: A tensor or (possibly nested) sequence of tensors with initial
      values for the accumulators.
    reverse: (optional) True scans and outputs elems in reverse order.
    inclusive: (optional) True includes the initial accumulator values in the
      output. Length of output will be len(elem sequence) + 1. Not meaningful
      if final_only is True.
    final_only: (optional) When True, return only the final accumulated values,
      not the concatenation of accumulated values for each input.

  Returns:
    A (possibly nested) sequence of tensors with the results of applying fn
    to tensors unpacked from elems and previous accumulator values.
  """

  flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
  num_elems = array_ops.shape(flat_elems[0])[0]
  pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
  flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
  pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
  accum_dtypes = [x.dtype for x in flat_initial]
  num_accums = len(flat_initial)

  # Types for counter, [outputs], [accumulators] loop arguments.
  if final_only:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
  else:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes

  # TODO(tombagby): Update to tfe.defun
  def cond(i, num_elems, *args):
    del args
    return i >= 0 if reverse else i < num_elems

  # The loop *args are [output tensors] + [accumulator tensors] which must
  # be paired. Each output corresponds to one accumulator.
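  #
  # Each iteration gathers slice i of the flattened inputs, applies `fn` to the
  # packed accumulators and slices, writes the new accumulator values into the
  # preallocated output buffers via an in-place alias update, and then steps i
  # forward (or backward when reverse=True).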
  def body(i, num_elems, *args):
    """Loop body."""
    i.set_shape([])
    if final_only:
      accum = args
    else:
      out, accum = args[:num_accums], args[num_accums:]
    slices = [array_ops.gather(e, i) for e in flat_elems]
    accum = fn(pack(accum), pack_elems(slices))
    flat_accum = nest.flatten(accum)
    if final_only:
      new_out = []
    else:
      update_i = i + 1 if inclusive and not reverse else i
      new_out = [
          inplace_ops.alias_inplace_update(x, update_i, y)
          for x, y in zip(out, flat_accum)
      ]
    i = i - 1 if reverse else i + 1
    return [i, num_elems] + new_out + flat_accum

  init_i = (
      array_ops.shape(flat_elems[0])[0] -
      1 if reverse else constant_op.constant(0, dtype=dtypes.int32))
  outputs = []
  if not final_only:
    num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
    for initial_accum in flat_initial:
      out_shape = array_ops.concat(
          [[num_outputs], array_ops.shape(initial_accum)], 0)
      out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
      if inclusive:
        out = inplace_ops.alias_inplace_add(out, init_i + (1 if reverse else 0),
                                            initial_accum)
      outputs.append(out)
  loop_in = [init_i, num_elems] + outputs + flat_initial
  hostmem = [
      i for i, x in enumerate(loop_in)
      if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
  ]

  if context.executing_eagerly():
    loop_results = loop_in
    while cond(*loop_results):
      loop_results = body(*loop_results)
  else:
    # TODO(tombagby): Update to while_v2.
    cond = function.Defun(*loop_dtypes)(cond)
    body = function.Defun(*loop_dtypes)(body)
    loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
  out = loop_results[2:num_accums + 2]
  return pack(out)


def _get_dim(tensor, i):
  """Get value of tensor shape[i] preferring static value if available."""
  return tensor_shape.dimension_value(
      tensor.shape[i]) or array_ops.shape(tensor)[i]