| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """CTC (Connectionist Temporal Classification) Operations.""" |
| |
| import uuid |
| |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import device |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import function |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import tensor_shape |
| |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import array_ops_stack |
| from tensorflow.python.ops import custom_gradient |
| from tensorflow.python.ops import functional_ops |
| from tensorflow.python.ops import gen_ctc_ops |
| from tensorflow.python.ops import inplace_ops |
| from tensorflow.python.ops import linalg_ops |
| from tensorflow.python.ops import map_fn |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import nn_ops |
| from tensorflow.python.ops import sparse_ops |
| from tensorflow.python.ops.nn_grad import _BroadcastMul |
| from tensorflow.python.util import deprecation |
| from tensorflow.python.util import dispatch |
| from tensorflow.python.util import nest |
| from tensorflow.python.util.tf_export import tf_export |
| |
| _DEFUN_API_NAME_ATTRIBUTE = "api_implements" |
| _DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device" |
| _CPU_DEVICE_NAME = "CPU" |
| _GPU_DEVICE_NAME = "GPU" |
| |
| |
| def _get_context_device_type(): |
| """Parses the current context and returns the device type, eg CPU/GPU.""" |
| current_device = context.context().device_name |
| if current_device is None: |
| return None |
| return device.DeviceSpec.from_string(current_device).device_type |
| |
| |
| def _generate_defun_backend(unique_api_name, preferred_device, func): |
| function_attributes = { |
| _DEFUN_API_NAME_ATTRIBUTE: unique_api_name, |
| _DEFUN_DEVICE_ATTRIBUTE: preferred_device, |
| } |
| return def_function.function( |
| func=func, experimental_attributes=function_attributes, autograph=False) |
| |
| # pylint: disable=protected-access, invalid-name |
| @tf_export(v1=["nn.ctc_loss"]) |
| @dispatch.add_dispatch_support |
| def ctc_loss(labels, |
| inputs=None, |
| sequence_length=None, |
| preprocess_collapse_repeated=False, |
| ctc_merge_repeated=True, |
| ignore_longer_outputs_than_inputs=False, |
| time_major=True, |
| logits=None): |
| """Computes the CTC (Connectionist Temporal Classification) Loss. |
| |
| This op implements the CTC loss as presented in (Graves et al., 2006). |
| |
| Input requirements: |
| |
| ``` |
| sequence_length(b) <= time for all b |
| |
| max(labels.indices(labels.indices[:, 1] == b, 2)) |
| <= sequence_length(b) for all b. |
| ``` |
| |
| Notes: |
| |
  This op performs the softmax operation for you, so inputs should be, e.g.,
  linear projections of outputs by an LSTM.
| |
| The `inputs` Tensor's innermost dimension size, `num_classes`, represents |
| `num_labels + 1` classes, where num_labels is the number of true labels, and |
| the largest value `(num_classes - 1)` is reserved for the blank label. |
| |
| For example, for a vocabulary containing 3 labels `[a, b, c]`, |
| `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`. |
| |
| Regarding the arguments `preprocess_collapse_repeated` and |
| `ctc_merge_repeated`: |
| |
| If `preprocess_collapse_repeated` is True, then a preprocessing step runs |
| before loss calculation, wherein repeated labels passed to the loss |
| are merged into single labels. This is useful if the training labels come |
| from, e.g., forced alignments and therefore have unnecessary repetitions. |
| |
| If `ctc_merge_repeated` is set False, then deep within the CTC calculation, |
| repeated non-blank labels will not be merged and are interpreted |
| as individual labels. This is a simplified (non-standard) version of CTC. |
| |
| Here is a table of the (roughly) expected first order behavior: |
| |
| * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True` |
| |
| Classical CTC behavior: Outputs true repeated classes with blanks in |
| between, and can also output repeated classes with no blanks in |
| between that need to be collapsed by the decoder. |
| |
| * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False` |
| |
| Never learns to output repeated classes, as they are collapsed |
| in the input labels before training. |
| |
| * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False` |
| |
| Outputs repeated classes with blanks in between, but generally does not |
| require the decoder to collapse/merge repeated classes. |
| |
| * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True` |
| |
| Untested. Very likely will not learn to output repeated classes. |
| |
  The `ignore_longer_outputs_than_inputs` option allows specifying the behavior
  of the CTCLoss when dealing with sequences that have longer outputs than
  inputs. If true, the CTCLoss will simply return zero gradient for those
  items; otherwise an InvalidArgument error is raised, stopping training.
| |
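  For example, a minimal illustrative sketch (random logits, a single batch
  element, and arbitrary shapes; not a canonical usage pattern):

  >>> batch_size, max_time, num_classes = 1, 4, 3
  >>> logits = tf.random.uniform([max_time, batch_size, num_classes])
  >>> labels = tf.sparse.SparseTensor(indices=[[0, 0], [0, 1]],
  ...                                 values=[0, 1],
  ...                                 dense_shape=[batch_size, 2])
  >>> loss = tf.compat.v1.nn.ctc_loss(labels, logits,
  ...                                 sequence_length=[max_time])
  >>> loss.shape
  TensorShape([1])
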
| Args: |
| labels: An `int32` `SparseTensor`. |
| `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id |
| for (batch b, time t). `labels.values[i]` must take on values in `[0, |
| num_labels)`. See `core/ops/ctc_ops.cc` for more details. |
| inputs: 3-D `float` `Tensor`. |
| If time_major == False, this will be a `Tensor` shaped: `[batch_size, |
| max_time, num_classes]`. |
| If time_major == True (default), this will be a `Tensor` shaped: |
| `[max_time, batch_size, num_classes]`. The logits. |
| sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence |
| lengths. |
| preprocess_collapse_repeated: Boolean. Default: False. If True, repeated |
| labels are collapsed prior to the CTC calculation. |
| ctc_merge_repeated: Boolean. Default: True. |
| ignore_longer_outputs_than_inputs: Boolean. Default: False. If True, |
| sequences with longer outputs than inputs will be ignored. |
| time_major: The shape format of the `inputs` Tensors. If True, these |
| `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False, |
| these `Tensors` must be shaped `[batch_size, max_time, num_classes]`. |
| Using `time_major = True` (default) is a bit more efficient because it |
| avoids transposes at the beginning of the ctc_loss calculation. However, |
      most TensorFlow data is batch-major, so this function also accepts
      inputs in batch-major form.
| logits: Alias for inputs. |
| |
| Returns: |
| A 1-D `float` `Tensor`, size `[batch]`, containing the negative log |
| probabilities. |
| |
| Raises: |
| TypeError: if labels is not a `SparseTensor`. |
| |
| References: |
| Connectionist Temporal Classification - Labeling Unsegmented Sequence Data |
| with Recurrent Neural Networks: |
| [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) |
| ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) |
| """ |
| return _ctc_loss_impl( |
| labels, |
| inputs, |
| sequence_length, |
| preprocess_collapse_repeated, |
| ctc_merge_repeated, |
| ignore_longer_outputs_than_inputs, |
| time_major, |
| logits, |
| use_cudnn=False) |
| |
| |
| def _ctc_loss_impl(labels, |
| inputs=None, |
| sequence_length=None, |
| preprocess_collapse_repeated=False, |
| ctc_merge_repeated=True, |
| ignore_longer_outputs_than_inputs=False, |
| time_major=True, |
| logits=None, |
| use_cudnn=False): |
| # Helper function of ctc_loss with one additional param: |
| # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank |
| # index has to be 0. |
| |
  # The op's second output tensor contains the gradients. We use it in
  # _CTCLossGrad() below.
| if not isinstance(labels, sparse_tensor.SparseTensor): |
| raise TypeError("Expected argument `labels` to be a SparseTensor. " |
| f"Received labels={labels} of type: " |
| f"{type(labels).__name__}") |
| |
| # For internal calculations, we transpose to [time, batch, num_classes] |
| inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs", |
| inputs) |
| |
| inputs = ops.convert_to_tensor(inputs, name="logits") |
| if not time_major: |
| inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N) |
| |
| orig_dtype = inputs.dtype |
| if orig_dtype in (dtypes.float16, dtypes.bfloat16): |
| inputs = math_ops.cast(inputs, dtypes.float32) |
| |
| # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the |
| # blank index to be 0, but v1 views it as the last index. |
| if use_cudnn: |
| ctc_loss_func = gen_ctc_ops.ctc_loss_v2 |
| else: |
| ctc_loss_func = gen_ctc_ops.ctc_loss |
| |
| loss, _ = ctc_loss_func( |
| inputs, |
| labels.indices, |
| labels.values, |
| sequence_length, |
| preprocess_collapse_repeated=preprocess_collapse_repeated, |
| ctc_merge_repeated=ctc_merge_repeated, |
| ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs) |
| |
| if orig_dtype in (dtypes.float16, dtypes.bfloat16): |
| loss = math_ops.cast(loss, orig_dtype) |
| |
| return loss |
| |
| # pylint: disable=unused-argument |
| def _CTCLossGradImpl(op, grad_loss, _): |
| # Outputs are: loss, grad |
| # |
| # Currently there is no way to take the second derivative of this op |
| # due to the fused implementation's interaction with tf.gradients(), |
| # so we make sure we prevent silently incorrect results by raising |
| # an error if the second derivative is requested via prevent_gradient. |
  grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1],
      message="Currently there is no way to take the second "
      "derivative of ctc_loss due to the fused implementation's interaction "
      "with tf.gradients()")
| # Return gradient for inputs and None for |
| # labels_indices, labels_values and sequence_length |
| return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None] |
| |
| |
| # pylint: disable=unused-argument |
| @ops.RegisterGradient("CTCLoss") |
| def _CTCLossGrad(op, grad_loss, _): |
| """The derivative provided by CTC Loss. |
| |
| Args: |
| op: the CTCLoss op. |
| grad_loss: The backprop for cost. |
| |
| Returns: |
| The CTC Loss gradient. |
| """ |
| return _CTCLossGradImpl(op, grad_loss, _) |
| |
| |
| # pylint: disable=unused-argument |
| @ops.RegisterGradient("CTCLossV2") |
| def _CTCLossV2Grad(op, grad_loss, _): |
| """The derivative provided by CTC Loss V2. |
| |
| Args: |
| op: the CTCLossV2 op. |
| grad_loss: The backprop for cost. |
| |
| Returns: |
| The CTC Loss V2 gradient. |
| """ |
| return _CTCLossGradImpl(op, grad_loss, _) |
| |
| |
| @tf_export("nn.ctc_greedy_decoder") |
| @dispatch.add_dispatch_support |
| def ctc_greedy_decoder(inputs, |
| sequence_length, |
| merge_repeated=True, |
| blank_index=None): |
| """Performs greedy decoding on the logits given in input (best path). |
| |
| Given a tensor as `inputs`, the `blank_index` parameter defines the class |
| index of the blank symbol. |
| |
| For example: |
| |
| If `blank_index` is equal to 1: |
| |
| >>> inf = float("inf") |
| >>> logits = tf.constant([[[ 0., -inf, -inf], |
| ... [ -2.3, -inf, -0.1]], |
| ... [[ -inf, -0.5, -inf], |
| ... [ -inf, -inf, -0.1]], |
| ... [[ -inf, -inf, -inf], |
| ... [ -0.1, -inf, -2.3]]]) |
| >>> seq_lens = tf.constant([2, 3]) |
| >>> outputs = tf.nn.ctc_greedy_decoder( |
| ... logits, |
| ... seq_lens, |
| ... blank_index=1) |
| |
| Notes: |
| |
| - Unlike `ctc_beam_search_decoder`, `ctc_greedy_decoder` considers blanks |
| as regular elements when computing the probability of a sequence. |
  - Default `blank_index` is `(num_classes - 1)`, unless overridden.
| |
| If `merge_repeated` is `True`, merge repeated classes in output. |
| This means that if consecutive logits' maximum indices are the same, |
| only the first of these is emitted. The sequence `A B B * B * B` (where '*' |
| is the blank label) becomes |
| |
| * `A B B B` if `merge_repeated=True`. |
| * `A B B B B` if `merge_repeated=False`. |
| |
| Args: |
| inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`. |
| The logits. |
| sequence_length: 1-D `int32` vector containing sequence lengths, having size |
| `[batch_size]`. |
| merge_repeated: Boolean. Default: True. |
| blank_index: (Optional). Default: `num_classes - 1`. Define the class index |
| to use for the blank label. Negative values will start from num_classes, |
| ie, -1 will reproduce the ctc_greedy_decoder behavior of using |
| num_classes - 1 for the blank symbol, which corresponds to the default. |
| |
| Returns: |
| A tuple `(decoded, neg_sum_logits)` where |
| |
    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:
| |
| `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`. |
| The rows store: `[batch, time]`. |
| |
| `decoded.values`: Values vector, size `(total_decoded_outputs)`. |
| The vector stores the decoded classes. |
| |
| `decoded.dense_shape`: Shape vector, size `(2)`. |
| The shape values are: `[batch_size, max_decoded_length]` |
| |
| neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the |
| sequence found, the negative of the sum of the greatest logit at each |
| timeframe. |
| """ |
| |
| outputs = gen_ctc_ops.ctc_greedy_decoder( |
| inputs, |
| sequence_length, |
| merge_repeated=merge_repeated, |
| blank_index=blank_index) |
| (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs |
| return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val, |
| decoded_shape)], log_probabilities) |
| |
| |
| @tf_export(v1=["nn.ctc_beam_search_decoder"]) |
| @dispatch.add_dispatch_support |
| def ctc_beam_search_decoder(inputs, |
| sequence_length, |
| beam_width=100, |
| top_paths=1, |
| merge_repeated=True): |
| """Performs beam search decoding on the logits given in input. |
| |
| **Note** Although in general greedy search is a special case of beam-search |
| with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs |
| from `ctc_greedy_decoder` in the treatment of blanks when computing the |
| probability of a sequence: |
| - `ctc_beam_search_decoder` treats blanks as sequence termination |
| - `ctc_greedy_decoder` treats blanks as regular elements |
| |
| If `merge_repeated` is `True`, merge repeated classes in the output beams. |
| This means that if consecutive entries in a beam are the same, |
| only the first of these is emitted. That is, when the sequence is |
| `A B B * B * B` (where '*' is the blank label), the return value is: |
| |
| * `A B` if `merge_repeated = True`. |
| * `A B B B` if `merge_repeated = False`. |
| |
| Args: |
| inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`. |
| The logits. |
| sequence_length: 1-D `int32` vector containing sequence lengths, having size |
| `[batch_size]`. |
| beam_width: An int scalar >= 0 (beam search beam width). |
| top_paths: An int scalar >= 0, <= beam_width (controls output size). |
| merge_repeated: Boolean. Default: True. |
| |
| Returns: |
| A tuple `(decoded, log_probabilities)` where |
| |
| decoded: A list of length top_paths, where `decoded[j]` |
| is a `SparseTensor` containing the decoded outputs: |
| |
| `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)` |
| The rows store: [batch, time]. |
| |
| `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`. |
| The vector stores the decoded classes for beam j. |
| |
| `decoded[j].dense_shape`: Shape vector, size `(2)`. |
| The shape values are: `[batch_size, max_decoded_length[j]]`. |
| |
    log_probabilities: A `float` matrix `(batch_size x top_paths)` containing
      sequence log-probabilities.
| """ |
| |
| decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = ( |
| gen_ctc_ops.ctc_beam_search_decoder( |
| inputs, |
| sequence_length, |
| beam_width=beam_width, |
| top_paths=top_paths, |
| merge_repeated=merge_repeated)) |
| |
| return ([ |
| sparse_tensor.SparseTensor(ix, val, shape) |
| for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes) |
| ], log_probabilities) |
| |
| |
| @tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"]) |
| @dispatch.add_dispatch_support |
| def ctc_beam_search_decoder_v2(inputs, |
| sequence_length, |
| beam_width=100, |
| top_paths=1): |
| """Performs beam search decoding on the logits given in input. |
| |
| **Note** Although in general greedy search is a special case of beam-search |
| with `top_paths=1` and `beam_width=1`, `ctc_beam_search_decoder` differs |
| from `ctc_greedy_decoder` in the treatment of blanks when computing the |
| probability of a sequence: |
| - `ctc_beam_search_decoder` treats blanks as sequence termination |
| - `ctc_greedy_decoder` treats blanks as regular elements |
| |
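  For example, an illustrative sketch with random logits (shapes and
  hyperparameters are arbitrary):

  >>> logits = tf.random.uniform([5, 2, 4])  # [max_time, batch_size, classes]
  >>> sequence_length = tf.constant([5, 5])
  >>> decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(
  ...     logits, sequence_length, beam_width=10, top_paths=2)
  >>> len(decoded)
  2
  >>> log_probabilities.shape
  TensorShape([2, 2])
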
| Args: |
| inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`. |
| The logits. |
| sequence_length: 1-D `int32` vector containing sequence lengths, having size |
| `[batch_size]`. |
| beam_width: An int scalar >= 0 (beam search beam width). |
| top_paths: An int scalar >= 0, <= beam_width (controls output size). |
| |
| Returns: |
| A tuple `(decoded, log_probabilities)` where |
| |
| decoded: A list of length top_paths, where `decoded[j]` |
| is a `SparseTensor` containing the decoded outputs: |
| |
| `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`; |
| The rows store: `[batch, time]`. |
| |
| `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`. |
| The vector stores the decoded classes for beam `j`. |
| |
| `decoded[j].dense_shape`: Shape vector, size `(2)`. |
| The shape values are: `[batch_size, max_decoded_length[j]]`. |
| |
    log_probabilities: A `float` matrix `[batch_size, top_paths]` containing
      sequence log-probabilities.
| """ |
| |
| # Note, merge_repeated is an invalid optimization that is removed from the |
| # public API: it returns low probability paths. |
| return ctc_beam_search_decoder( |
| inputs, |
| sequence_length=sequence_length, |
| beam_width=beam_width, |
| top_paths=top_paths, |
| merge_repeated=False) |
| |
| |
| ops.NotDifferentiable("CTCGreedyDecoder") |
| ops.NotDifferentiable("CTCBeamSearchDecoder") |
| |
| |
| def _ctc_state_trans(label_seq): |
| """Computes CTC alignment model transition matrix. |
| |
| Args: |
| label_seq: tensor of shape [batch_size, max_seq_length] |
| |
| Returns: |
| tensor of shape [batch_size, states, states] with a state transition matrix |
| computed for each sequence of the batch. |
| """ |
| |
| with ops.name_scope("ctc_state_trans"): |
| label_seq = ops.convert_to_tensor(label_seq, name="label_seq") |
| batch_size = _get_dim(label_seq, 0) |
| num_labels = _get_dim(label_seq, 1) |
| |
| num_label_states = num_labels + 1 |
| num_states = 2 * num_label_states |
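    # State layout: states [0, num_label_states) are label states (state 0 is
    # the start state; state i corresponds to label i - 1 having been
    # emitted), and states [num_label_states, 2 * num_label_states) are the
    # blank states that follow each label state.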
| |
| label_states = math_ops.range(num_label_states) |
| blank_states = label_states + num_label_states |
| |
| # Start state to first label. |
| start_to_label = [[1, 0]] |
| |
| # Blank to label transitions. |
| blank_to_label = array_ops_stack.stack( |
| [label_states[1:], blank_states[:-1]], 1) |
| |
| # Label to blank transitions. |
| label_to_blank = array_ops_stack.stack([blank_states, label_states], 1) |
| |
| # Scatter transitions that don't depend on sequence. |
| indices = array_ops.concat([start_to_label, blank_to_label, label_to_blank], |
| 0) |
| values = array_ops.ones([_get_dim(indices, 0)]) |
| trans = array_ops.scatter_nd( |
| indices, values, shape=[num_states, num_states]) |
| trans += linalg_ops.eye(num_states) # Self-loops. |
| |
| # Label to label transitions. Disallow transitions between repeated labels |
| # with no blank state in between. |
| batch_idx = array_ops.zeros_like(label_states[2:]) |
| indices = array_ops_stack.stack( |
| [batch_idx, label_states[2:], label_states[1:-1]], 1) |
| indices = array_ops.tile( |
| array_ops.expand_dims(indices, 0), [batch_size, 1, 1]) |
| batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0] |
| indices += array_ops.expand_dims(batch_idx, 1) |
| repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:]) |
| values = 1.0 - math_ops.cast(repeats, dtypes.float32) |
| batched_shape = [batch_size, num_states, num_states] |
| label_to_label = array_ops.scatter_nd(indices, values, batched_shape) |
| |
| return array_ops.expand_dims(trans, 0) + label_to_label |
| |
| |
| def ctc_state_log_probs(seq_lengths, max_seq_length): |
| """Computes CTC alignment initial and final state log probabilities. |
| |
  Creates the initial/final state values directly as log values to avoid
  having to take a float64 log on TPU (where it does not exist).
| |
| Args: |
| seq_lengths: int tensor of shape [batch_size], seq lengths in the batch. |
| max_seq_length: int, max sequence length possible. |
| |
| Returns: |
| initial_state_log_probs, final_state_log_probs |
| """ |
| |
| batch_size = _get_dim(seq_lengths, 0) |
| num_label_states = max_seq_length + 1 |
| num_duration_states = 2 |
| num_states = num_duration_states * num_label_states |
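  # A finite stand-in for log(0): log(1e-307) is about -706.9, which is safely
  # representable in float32.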
| log_0 = math_ops.cast( |
| math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32) |
| |
| initial_state_log_probs = array_ops.one_hot( |
| indices=array_ops.zeros([batch_size], dtype=dtypes.int32), |
| depth=num_states, |
| on_value=0.0, |
| off_value=log_0, |
| axis=1) |
| |
| label_final_state_mask = array_ops.one_hot( |
| seq_lengths, depth=num_label_states, axis=0) |
| duration_final_state_mask = array_ops.ones( |
| [num_duration_states, 1, batch_size]) |
| final_state_mask = duration_final_state_mask * label_final_state_mask |
| final_state_log_probs = (1.0 - final_state_mask) * log_0 |
| final_state_log_probs = array_ops.reshape(final_state_log_probs, |
| [num_states, batch_size]) |
| |
| return initial_state_log_probs, array_ops.transpose(final_state_log_probs) |
| |
| |
| def _ilabel_to_state(labels, num_labels, ilabel_log_probs): |
| """Project ilabel log probs to state log probs.""" |
| |
| num_label_states = _get_dim(labels, 1) |
| blank = ilabel_log_probs[:, :, :1] |
| blank = array_ops.tile(blank, [1, 1, num_label_states + 1]) |
| one_hot = array_ops.one_hot(labels, depth=num_labels) |
| one_hot = array_ops.expand_dims(one_hot, axis=0) |
| ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2) |
| state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3) |
| state_log_probs = array_ops.concat([state_log_probs, blank], axis=2) |
| return array_ops.pad( |
| state_log_probs, [[0, 0], [0, 0], [1, 0]], |
| constant_values=math_ops.log(0.0)) |
| |
| |
| def _state_to_olabel(labels, num_labels, states): |
| """Sum state log probs to ilabel log probs.""" |
| |
| num_label_states = _get_dim(labels, 1) + 1 |
| label_states = states[:, :, 1:num_label_states] |
| blank_states = states[:, :, num_label_states:] |
| one_hot = array_ops.one_hot( |
| labels - 1, |
| depth=(num_labels - 1), |
| on_value=0.0, |
| off_value=math_ops.log(0.0)) |
| one_hot = array_ops.expand_dims(one_hot, axis=0) |
| label_states = array_ops.expand_dims(label_states, axis=3) |
| label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2) |
| blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True) |
| return array_ops.concat([blank_olabels, label_olabels], axis=-1) |
| |
| |
| # pylint: disable=redefined-outer-name |
| def _state_to_olabel_unique(labels, num_labels, states, unique): |
| """Sum state log probs to ilabel log probs using unique label indices.""" |
| |
| num_label_states = _get_dim(labels, 1) + 1 |
| label_states = states[:, :, 1:num_label_states] |
| blank_states = states[:, :, num_label_states:] |
| |
| unique_y, unique_idx = unique |
| mul_reduce = _sum_states(unique_idx, label_states) |
| |
| num_frames = _get_dim(states, 0) |
| batch_size = _get_dim(states, 1) |
| num_states = num_label_states - 1 |
| batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0]) |
| batch_state_major = array_ops.reshape(batch_state_major, |
| [batch_size * num_states, num_frames]) |
| batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels |
| indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1) |
| indices = array_ops.reshape(indices, [-1, 1]) |
| scatter = array_ops.scatter_nd( |
| indices=indices, |
| updates=batch_state_major, |
| shape=[batch_size * num_labels, num_frames]) |
| scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames]) |
| |
| mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool) |
| mask = array_ops.scatter_nd( |
| indices=indices, |
| updates=mask, |
| shape=[batch_size * num_labels, num_frames]) |
| mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames]) |
| |
| scatter = array_ops.where( |
| mask, scatter, |
| array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0))) |
| |
| label_olabels = array_ops.transpose(scatter, [2, 0, 1]) |
| label_olabels = label_olabels[:, :, 1:] |
| |
| blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True) |
| |
| return array_ops.concat([blank_olabels, label_olabels], axis=-1) |
| |
| |
| def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None): |
| """Computes the CTC loss and gradients. |
| |
  Most users will want `ctc_loss` instead.

  This function returns the computed gradient; it does not have a gradient
  of its own defined.
| |
| Args: |
| logits: tensor of shape [frames, batch_size, num_labels] |
| labels: tensor of shape [batch_size, max_label_seq_length] |
| label_length: tensor of shape [batch_size] Length of reference label |
| sequence in labels. |
| logit_length: tensor of shape [batch_size] Length of input sequence in |
| logits. |
| unique: (optional) unique label indices as computed by unique(labels) If |
| supplied, enables an implementation that is faster and more memory |
| efficient on TPU. |
| |
| Returns: |
| loss: tensor of shape [batch_size] |
| gradient: tensor of shape [frames, batch_size, num_labels] |
| """ |
| |
| num_labels = _get_dim(logits, 2) |
| max_label_seq_length = _get_dim(labels, 1) |
| |
| ilabel_log_probs = nn_ops.log_softmax(logits) |
| state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs) |
| state_trans_probs = _ctc_state_trans(labels) |
| initial_state_log_probs, final_state_log_probs = ctc_state_log_probs( |
| label_length, max_label_seq_length) |
| fwd_bwd_log_probs, log_likelihood = _forward_backward_log( |
| state_trans_log_probs=math_ops.log(state_trans_probs), |
| initial_state_log_probs=initial_state_log_probs, |
| final_state_log_probs=final_state_log_probs, |
| observed_log_probs=state_log_probs, |
| sequence_length=logit_length) |
| |
| if unique: |
| olabel_log_probs = _state_to_olabel_unique(labels, num_labels, |
| fwd_bwd_log_probs, unique) |
| else: |
| olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs) |
| |
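  # The CTC gradient w.r.t. the logits: the softmax posterior over output
  # labels minus the expected label occupancy from the forward-backward pass.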
| grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs) |
| |
  # Apply the sequence mask to the gradient. It is enough to apply the mask
  # only to ilabel_log_probs because olabel_log_probs already accounts for the
  # mask. However, applying it to the gradient as well is safe and clean.
| max_logit_length = _get_dim(logits, 0) |
| logit_mask = array_ops.sequence_mask(logit_length, max_logit_length, |
| dtypes.float32) |
| logit_mask = array_ops.transpose(logit_mask, perm=[1, 0]) |
| logit_mask = array_ops.expand_dims(logit_mask, axis=2) |
| grad *= logit_mask |
| |
| loss = -log_likelihood |
| return loss, grad |
| |
| |
| def _ctc_loss_grad(op, grad_loss, _): |
| grad = op.outputs[1] |
| grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad] |
| grad += [None] * (len(op.inputs) - len(grad)) |
| return grad |
| |
| |
| def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major, |
| blank_index): |
| part_before = logits[:, :, :blank_index] |
| part_after = logits[:, :, blank_index + 1:] |
| part_blank = logits[:, :, blank_index:blank_index + 1] |
| logits = array_ops.concat([part_before, part_after, part_blank], axis=2) |
| labels = sparse_tensor.SparseTensor( |
| labels.indices, |
| array_ops.where(labels.values < blank_index, labels.values, |
| labels.values - 1), labels.dense_shape) |
| return _ctc_loss_impl( |
| labels=labels, |
| inputs=logits, |
| sequence_length=logit_length, |
| time_major=logits_time_major, |
| use_cudnn=False) |
| |
| |
| def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major, |
| blank_index): |
| part_before = logits[:, :, :blank_index] |
| part_after = logits[:, :, blank_index + 1:] |
| part_blank = logits[:, :, blank_index:blank_index + 1] |
| logits = array_ops.concat([part_blank, part_before, part_after], axis=2) |
| labels = sparse_tensor.SparseTensor( |
| labels.indices, |
| array_ops.where(labels.values < blank_index, labels.values + 1, |
| labels.values), labels.dense_shape) |
| return _ctc_loss_impl( |
| labels=labels, |
| inputs=logits, |
| sequence_length=logit_length, |
| time_major=logits_time_major, |
| use_cudnn=True) |
| |
| |
| def _ctc_loss_shape(op): |
| return [op.inputs[2].get_shape(), op.inputs[0].get_shape()] |
| |
| |
| # pylint: disable=protected-access, invalid-name |
| @tf_export(v1=["nn.ctc_loss_v2"]) |
| @dispatch.add_dispatch_support |
| def ctc_loss_v2(labels, |
| logits, |
| label_length, |
| logit_length, |
| logits_time_major=True, |
| unique=None, |
| blank_index=None, |
| name=None): |
| """Computes CTC (Connectionist Temporal Classification) loss. |
| |
| This op implements the CTC loss as presented in (Graves et al., 2006). |
| |
| Notes: |
| |
| - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss |
| setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True |
| - Labels may be supplied as either a dense, zero-padded tensor with a |
| vector of label sequence lengths OR as a SparseTensor. |
| - On TPU and GPU: Only dense padded labels are supported. |
| - On CPU: Caller may use SparseTensor or dense padded labels but calling with |
| a SparseTensor will be significantly faster. |
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.
| |
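  For example, a minimal dense-label sketch (random logits; shapes and lengths
  are illustrative only):

  >>> labels = tf.constant([[1, 2, 1, 0, 0]])  # zero-padded
  >>> logits = tf.random.uniform([10, 1, 4])   # [frames, batch, num_labels]
  >>> loss = tf.compat.v1.nn.ctc_loss_v2(labels, logits,
  ...                                    label_length=[3],
  ...                                    logit_length=[10])
  >>> loss.shape
  TensorShape([1])
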
| Args: |
| labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor |
| logits: tensor of shape [frames, batch_size, num_labels], if |
| logits_time_major == False, shape is [batch_size, frames, num_labels]. |
| label_length: tensor of shape [batch_size], None if labels is SparseTensor |
| Length of reference label sequence in labels. |
| logit_length: tensor of shape [batch_size] Length of input sequence in |
| logits. |
| logits_time_major: (optional) If True (default), logits is shaped [time, |
| batch, logits]. If False, shape is [batch, time, logits] |
| unique: (optional) Unique label indices as computed by |
| ctc_unique_labels(labels). If supplied, enable a faster, memory efficient |
| implementation on TPU. |
| blank_index: (optional) Set the class index to use for the blank label. |
| Negative values will start from num_classes, ie, -1 will reproduce the |
| ctc_loss behavior of using num_classes - 1 for the blank symbol. There is |
| some memory/performance overhead to switching from the default of 0 as an |
| additional shifted copy of the logits may be created. |
| name: A name for this `Op`. Defaults to "ctc_loss_dense". |
| |
| Returns: |
| loss: tensor of shape [batch_size], negative log probabilities. |
| |
| References: |
| Connectionist Temporal Classification - Labeling Unsegmented Sequence Data |
| with Recurrent Neural Networks: |
| [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) |
| ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) |
| """ |
| if isinstance(labels, sparse_tensor.SparseTensor): |
| if blank_index is None: |
| raise ValueError( |
| "Argument `blank_index` must be provided when labels is a " |
| "SparseTensor.") |
| |
| if blank_index < 0: |
| blank_index += _get_dim(logits, 2) |
| |
| if blank_index != _get_dim(logits, 2) - 1: |
| logits = array_ops.concat([ |
| logits[:, :, :blank_index], |
| logits[:, :, blank_index + 1:], |
| logits[:, :, blank_index:blank_index + 1], |
| ], |
| axis=2) |
| labels = sparse_tensor.SparseTensor( |
| labels.indices, |
| array_ops.where(labels.values < blank_index, labels.values, |
| labels.values - 1), labels.dense_shape) |
| |
| return ctc_loss( |
| labels=labels, |
| inputs=logits, |
| sequence_length=logit_length, |
| time_major=logits_time_major) |
| |
| if blank_index is None: |
| blank_index = 0 |
| |
| return ctc_loss_dense( |
| labels=labels, |
| logits=logits, |
| label_length=label_length, |
| logit_length=logit_length, |
| logits_time_major=logits_time_major, |
| unique=unique, |
| blank_index=blank_index, |
| name=name) |
| |
| |
| @tf_export("nn.ctc_loss", v1=[]) |
| @dispatch.add_dispatch_support |
| def ctc_loss_v3(labels, |
| logits, |
| label_length, |
| logit_length, |
| logits_time_major=True, |
| unique=None, |
| blank_index=None, |
| name=None): |
| """Computes CTC (Connectionist Temporal Classification) loss. |
| |
| This op implements the CTC loss as presented in |
| [Graves et al., 2006](https://www.cs.toronto.edu/~graves/icml_2006.pdf) |
| |
| Connectionist temporal classification (CTC) is a type of neural network output |
| and associated scoring function, for training recurrent neural networks (RNNs) |
| such as LSTM networks to tackle sequence problems where the timing is |
| variable. It can be used for tasks like on-line handwriting recognition or |
| recognizing phones in speech audio. CTC refers to the outputs and scoring, and |
| is independent of the underlying neural network structure. |
| |
| Notes: |
| |
  - This op performs the softmax operation for you, so `logits` should be
    e.g. linear projections of outputs by an LSTM.
| - Outputs true repeated classes with blanks in between, and can also output |
| repeated classes with no blanks in between that need to be collapsed by the |
| decoder. |
| - `labels` may be supplied as either a dense, zero-padded `Tensor` with a |
| vector of label sequence lengths OR as a `SparseTensor`. |
| - On TPU: Only dense padded `labels` are supported. |
| - On CPU and GPU: Caller may use `SparseTensor` or dense padded `labels` |
| but calling with a `SparseTensor` will be significantly faster. |
| - Default blank label is `0` instead of `num_labels - 1` (where `num_labels` |
| is the innermost dimension size of `logits`), unless overridden by |
| `blank_index`. |
| |
| >>> tf.random.set_seed(50) |
| >>> batch_size = 8 |
| >>> num_labels = 6 |
| >>> max_label_length = 5 |
| >>> num_frames = 12 |
| >>> labels = tf.random.uniform([batch_size, max_label_length], |
| ... minval=1, maxval=num_labels, dtype=tf.int64) |
| >>> logits = tf.random.uniform([num_frames, batch_size, num_labels]) |
| >>> label_length = tf.random.uniform([batch_size], minval=2, |
| ... maxval=max_label_length, dtype=tf.int64) |
| >>> label_mask = tf.sequence_mask(label_length, maxlen=max_label_length, |
| ... dtype=label_length.dtype) |
| >>> labels *= label_mask |
| >>> logit_length = [num_frames] * batch_size |
| >>> with tf.GradientTape() as t: |
| ... t.watch(logits) |
| ... ref_loss = tf.nn.ctc_loss( |
| ... labels=labels, |
| ... logits=logits, |
| ... label_length=label_length, |
| ... logit_length=logit_length, |
| ... blank_index=0) |
| >>> ref_grad = t.gradient(ref_loss, logits) |
| |
| Args: |
| labels: `Tensor` of shape `[batch_size, max_label_seq_length]` or |
| `SparseTensor`. |
| logits: `Tensor` of shape `[frames, batch_size, num_labels]`. If |
| `logits_time_major == False`, shape is `[batch_size, frames, num_labels]`. |
| label_length: `Tensor` of shape `[batch_size]`. None, if `labels` is a |
| `SparseTensor`. Length of reference label sequence in `labels`. |
| logit_length: `Tensor` of shape `[batch_size]`. Length of input sequence in |
| `logits`. |
| logits_time_major: (optional) If True (default), `logits` is shaped [frames, |
| batch_size, num_labels]. If False, shape is |
| `[batch_size, frames, num_labels]`. |
| unique: (optional) Unique label indices as computed by |
| `ctc_unique_labels(labels)`. If supplied, enable a faster, memory |
| efficient implementation on TPU. |
| blank_index: (optional) Set the class index to use for the blank label. |
| Negative values will start from `num_labels`, ie, `-1` will reproduce the |
| ctc_loss behavior of using `num_labels - 1` for the blank symbol. There is |
| some memory/performance overhead to switching from the default of 0 as an |
| additional shifted copy of `logits` may be created. |
| name: A name for this `Op`. Defaults to "ctc_loss_dense". |
| |
| Returns: |
| loss: A 1-D `float Tensor` of shape `[batch_size]`, containing negative log |
| probabilities. |
| |
| Raises: |
| ValueError: Argument `blank_index` must be provided when `labels` is a |
| `SparseTensor`. |
| |
| References: |
| Connectionist Temporal Classification - Labeling Unsegmented Sequence Data |
| with Recurrent Neural Networks: |
| [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) |
| ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) |
| |
| https://en.wikipedia.org/wiki/Connectionist_temporal_classification |
| """ |
| if isinstance(labels, sparse_tensor.SparseTensor): |
| if blank_index is None: |
| raise ValueError( |
| "Argument `blank_index` must be provided when `labels` is a " |
| "`SparseTensor`.") |
| |
| if blank_index < 0: |
| blank_index += _get_dim(logits, 2) |
| |
| logits = ops.convert_to_tensor(logits, name="logits") |
| |
| params = { |
| "labels": labels, |
| "logits": logits, |
| "logit_length": logit_length, |
| "logits_time_major": logits_time_major, |
| "blank_index": blank_index |
| } |
| |
| if context.executing_eagerly(): |
| device_type = _get_context_device_type() |
| can_use_gpu = ( |
| # Either user specified GPU or unspecified but GPU is available. |
| (device_type == _GPU_DEVICE_NAME or |
| (device_type is None and context.num_gpus() > 0))) |
    # Under eager context, check the device placement and prefer the cuDNN
    # implementation when a GPU is available.
| if can_use_gpu: |
| res = _ctc_loss_op_cudnn(**params) |
| else: |
| res = _ctc_loss_op_standard(**params) |
| else: |
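    # In graph mode, register both a CPU (standard) and a GPU (cuDNN) variant
    # of the loss under the same `api_implements` attribute, so that grappler
    # can later select the function matching the actual device placement.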
| api_name = "ctc_loss_" + str(uuid.uuid4()) |
| ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME, |
| _ctc_loss_op_standard) |
| ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME, |
| _ctc_loss_op_cudnn) |
| res = ctc_loss_op_standard(**params) |
| concrete_func = ctc_loss_op_cudnn.get_concrete_function(**params) |
| concrete_func.add_to_graph() |
| concrete_func.add_gradient_functions_to_graph() |
| return res |
| |
| if blank_index is None: |
| blank_index = 0 |
| |
| return ctc_loss_dense( |
| labels=labels, |
| logits=logits, |
| label_length=label_length, |
| logit_length=logit_length, |
| logits_time_major=logits_time_major, |
| unique=unique, |
| blank_index=blank_index, |
| name=name) |
| |
| |
| def ctc_loss_dense(labels, |
| logits, |
| label_length, |
| logit_length, |
| logits_time_major=True, |
| unique=None, |
| blank_index=0, |
| name=None): |
| """Computes CTC (Connectionist Temporal Classification) loss. |
| |
| This op implements the CTC loss as presented in (Graves et al., 2006), |
| using the batched forward backward algorithm described in (Sim et al., 2017). |
| |
| Notes: |
| Significant differences from `tf.compat.v1.nn.ctc_loss`: |
    Supports GPU and TPU (`tf.compat.v1.nn.ctc_loss` supports CPU only):
      For batched operations, GPU and TPU are significantly faster than using
      `ctc_loss` on CPU.
      This implementation also runs on CPU, but is significantly slower than
      `ctc_loss`.
    Blank label is 0 rather than num_classes - 1, unless overridden by
      blank_index.
    Logits and labels are dense arrays with padding rather than SparseTensor.
    The only mode supported is the same as:
      preprocess_collapse_repeated=False, ctc_merge_repeated=True
      To collapse labels, the caller can preprocess the label sequences first.
| |
  The dense implementation supports both CPU, GPU and TPU. A fast path is
  provided that significantly improves memory use for large vocabulary if the
  caller preprocesses label sequences to get unique label indices on the CPU
  (e.g. in the data input pipeline) using ctc_ops.unique and supplies the
  result via the optional "unique" kwarg. This is especially useful for TPU
  and GPU but also works on CPU; see the sketch below.
| |
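  A sketch of that fast path (note this is a module-level helper rather than a
  public symbol; shapes and values are illustrative):

  >>> labels = tf.constant([[1, 2, 2, 1]])
  >>> logits = tf.random.uniform([8, 1, 4])  # [frames, batch, num_labels]
  >>> unique = tf.nn.ctc_unique_labels(labels)
  >>> loss = ctc_loss_dense(labels, logits, label_length=[4],
  ...                       logit_length=[8], unique=unique)
  >>> loss.shape
  TensorShape([1])
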
| Args: |
| labels: tensor of shape [batch_size, max_label_seq_length] |
| logits: tensor of shape [frames, batch_size, num_labels], if |
| logits_time_major == False, shape is [batch_size, frames, num_labels]. |
| label_length: tensor of shape [batch_size] Length of reference label |
| sequence in labels. |
| logit_length: tensor of shape [batch_size] Length of input sequence in |
| logits. |
| logits_time_major: (optional) If True (default), logits is shaped [time, |
| batch, logits]. If False, shape is [batch, time, logits] |
| unique: (optional) Unique label indices as computed by unique(labels). If |
| supplied, enable a faster, memory efficient implementation on TPU. |
| blank_index: (optional) Set the class index to use for the blank label. |
| Negative values will start from num_classes, ie, -1 will reproduce the |
| ctc_loss behavior of using num_classes - 1 for the blank symbol. There is |
| some memory/performance overhead to switching from the default of 0 as an |
| additional shifted copy of the logits may be created. |
| name: A name for this `Op`. Defaults to "ctc_loss_dense". |
| |
| Returns: |
| loss: tensor of shape [batch_size], negative log probabilities. |
| |
| References: |
| Connectionist Temporal Classification - Labeling Unsegmented Sequence Data |
| with Recurrent Neural Networks: |
| [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) |
| ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) |
| Improving the efficiency of forward-backward algorithm using batched |
| computation in TensorFlow: |
| [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944) |
| ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf)) |
| """ |
| |
| with ops.name_scope(name, "ctc_loss_dense", |
| [logits, labels, label_length, logit_length]): |
| logits = ops.convert_to_tensor(logits, name="logits") |
| labels = ops.convert_to_tensor(labels, name="labels") |
| label_length = ops.convert_to_tensor(label_length, name="label_length") |
| logit_length = ops.convert_to_tensor(logit_length, name="logit_length") |
| |
| orig_dtype = logits.dtype |
| if orig_dtype in (dtypes.float16, dtypes.bfloat16): |
| logits = math_ops.cast(logits, dtypes.float32) |
| |
| if not logits_time_major: |
| logits = array_ops.transpose(logits, perm=[1, 0, 2]) |
| |
| if blank_index != 0: |
| if blank_index < 0: |
| blank_index += _get_dim(logits, 2) |
| logits = array_ops.concat([ |
| logits[:, :, blank_index:blank_index + 1], |
| logits[:, :, :blank_index], |
| logits[:, :, blank_index + 1:], |
| ], |
| axis=2) |
| labels = array_ops.where(labels < blank_index, labels + 1, labels) |
| |
| args = [logits, labels, label_length, logit_length] |
| |
| if unique: |
| unique_y, unique_idx = unique |
| if blank_index != 0: |
| unique_y = array_ops.where(unique_y < blank_index, unique_y + 1, |
| unique_y) |
| label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1 |
| max_label_length = _get_dim(unique_y, 1) |
| label_mask = array_ops.sequence_mask(label_mask_len, max_label_length) |
| unique_y = array_ops.where(label_mask, unique_y, |
| array_ops.zeros_like(unique_y)) |
| args.extend([unique_y, unique_idx]) |
| |
| @custom_gradient.custom_gradient |
| def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t, |
| *unique_t): |
| """Compute CTC loss.""" |
| logits_t.set_shape(logits.shape) |
| labels_t.set_shape(labels.shape) |
| label_length_t.set_shape(label_length.shape) |
| logit_length_t.set_shape(logit_length.shape) |
| kwargs = dict( |
| logits=logits_t, |
| labels=labels_t, |
| label_length=label_length_t, |
| logit_length=logit_length_t) |
| if unique_t: |
| kwargs["unique"] = unique_t |
| result = ctc_loss_and_grad(**kwargs) |
| def grad(grad_loss): |
| grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]] |
| grad += [None] * (len(args) - len(grad)) |
| return grad |
| |
| return result[0], grad |
| |
| loss = compute_ctc_loss(*args) |
| if orig_dtype in (dtypes.float16, dtypes.bfloat16): |
| loss = math_ops.cast(loss, orig_dtype) |
| return loss |
| |
| |
| @tf_export("nn.collapse_repeated") |
| @dispatch.add_dispatch_support |
| def collapse_repeated(labels, seq_length, name=None): |
| """Merge repeated labels into single labels. |
| |
| Args: |
| labels: Tensor of shape [batch, max value in seq_length] |
| seq_length: Tensor of shape [batch], sequence length of each batch element. |
| name: A name for this `Op`. Defaults to "collapse_repeated_labels". |
| |
| Returns: |
| A tuple `(collapsed_labels, new_seq_length)` where |
| |
| collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated |
| labels collapsed and padded to max_seq_length, eg: |
| `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]` |
| |
| new_seq_length: int tensor of shape [batch] with new sequence lengths. |
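
  For example (illustrative values; `0` is the padding label):

  >>> labels = tf.constant([[1, 1, 2, 2, 1], [1, 2, 3, 4, 5]])
  >>> seq_length = tf.constant([5, 5])
  >>> collapsed, new_seq_length = tf.nn.collapse_repeated(labels, seq_length)
  >>> collapsed.numpy()
  array([[1, 2, 1, 0, 0],
         [1, 2, 3, 4, 5]], dtype=int32)
  >>> new_seq_length.numpy()
  array([3, 5], dtype=int32)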
| """ |
| |
| with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]): |
| labels = ops.convert_to_tensor(labels, name="labels") |
| seq_length = ops.convert_to_tensor(seq_length, name="seq_length") |
| |
| # Mask labels that don't equal previous label. |
| label_mask = array_ops.concat([ |
| array_ops.ones_like(labels[:, :1], dtypes.bool), |
| math_ops.not_equal(labels[:, 1:], labels[:, :-1]) |
| ], |
| axis=1) |
| |
| # Filter labels that aren't in the original sequence. |
| maxlen = _get_dim(labels, 1) |
| seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen) |
| label_mask = math_ops.logical_and(label_mask, seq_mask) |
| |
| # Count masks for new sequence lengths. |
| new_seq_len = math_ops.reduce_sum( |
| math_ops.cast(label_mask, dtypes.int32), axis=1) |
| |
| # Mask indexes based on sequence length mask. |
| new_maxlen = math_ops.reduce_max(new_seq_len) |
| idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen) |
| |
| # Flatten everything and mask out labels to keep and sparse indices. |
| flat_labels = array_ops.reshape(labels, [-1]) |
| flat_label_mask = array_ops.reshape(label_mask, [-1]) |
| flat_idx_mask = array_ops.reshape(idx_mask, [-1]) |
| idx = math_ops.range(_get_dim(flat_idx_mask, 0)) |
| |
| # Scatter to flat shape. |
| flat = array_ops.scatter_nd( |
| indices=array_ops.expand_dims( |
| array_ops.boolean_mask(idx, flat_idx_mask), axis=1), |
| updates=array_ops.boolean_mask(flat_labels, flat_label_mask), |
| shape=array_ops.shape(flat_idx_mask)) |
| |
| # Reshape back to square batch. |
| batch_size = _get_dim(labels, 0) |
| new_shape = [batch_size, new_maxlen] |
| return (array_ops.reshape(flat, new_shape), |
| math_ops.cast(new_seq_len, seq_length.dtype)) |
| |
| |
| def dense_labels_to_sparse(dense, length): |
| """Convert dense labels with sequence lengths to sparse tensor. |
| |
| Args: |
| dense: tensor of shape [batch, max_length] |
| length: int tensor of shape [batch] The length of each sequence in dense. |
| |
| Returns: |
| tf.sparse.SparseTensor with values only for the valid elements of sequences. |
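
  For example, a small sketch (this is a module-level helper, not a public
  symbol; values are illustrative):

  >>> dense = tf.constant([[1, 2, 0], [3, 0, 0]])
  >>> length = tf.constant([2, 1])
  >>> sparse = dense_labels_to_sparse(dense, length)
  >>> sparse.values.numpy()
  array([1, 2, 3], dtype=int32)
  >>> sparse.dense_shape.numpy()
  array([2, 2])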
| """ |
| |
| flat_values = array_ops.reshape(dense, [-1]) |
| flat_indices = math_ops.range( |
| array_ops.shape(flat_values, out_type=dtypes.int64)[0]) |
| mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1]) |
| flat_mask = array_ops.reshape(mask, [-1]) |
| indices = array_ops.expand_dims( |
| array_ops.boolean_mask(flat_indices, flat_mask), 1) |
| values = array_ops.boolean_mask(flat_values, flat_mask) |
| sparse = sparse_tensor.SparseTensor( |
| indices=indices, |
| values=math_ops.cast(values, dtypes.int32), |
| dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64)) |
| reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense)) |
| max_length = math_ops.reduce_max(length) |
| return sparse_tensor.SparseTensor( |
| indices=reshaped.indices, |
| values=reshaped.values, |
| dense_shape=[ |
| math_ops.cast(reshaped.dense_shape[0], dtypes.int64), |
| math_ops.cast(max_length, dtypes.int64) |
| ]) |
| |
| |
| @tf_export("nn.ctc_unique_labels") |
| @dispatch.add_dispatch_support |
| def ctc_unique_labels(labels, name=None): |
| """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`. |
| |
  For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be
  used to preprocess labels in the input pipeline for better speed/memory use
  when computing the ctc loss on TPU.
| |
| Example: |
| ctc_unique_labels([[3, 4, 4, 3]]) -> |
| unique labels padded with 0: [[3, 4, 0, 0]] |
| indices of original labels in unique: [0, 1, 1, 0] |
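
  The same call as a doctest-style sketch:

  >>> unique, idx = tf.nn.ctc_unique_labels(tf.constant([[3, 4, 4, 3]]))
  >>> unique.numpy()
  array([[3, 4, 0, 0]])
  >>> idx.numpy()
  array([[0, 1, 1, 0]], dtype=int32)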
| |
| Args: |
| labels: tensor of shape [batch_size, max_label_length] padded with 0. |
| name: A name for this `Op`. Defaults to "ctc_unique_labels". |
| |
| Returns: |
| tuple of |
| - unique labels, tensor of shape `[batch_size, max_label_length]` |
| - indices into unique labels, shape `[batch_size, max_label_length]` |
| """ |
| |
| with ops.name_scope(name, "ctc_unique_labels", [labels]): |
| labels = ops.convert_to_tensor(labels, name="labels") |
| |
| def _unique(x): |
| u = array_ops.unique(x) |
| y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]]) |
| y = math_ops.cast(y, dtypes.int64) |
| return [y, u.idx] |
| |
| return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32]) |
| |
| |
| def _sum_states(idx, states): |
| """Take logsumexp for each unique state out of all label states. |
| |
| Args: |
| idx: tensor of shape [batch, label_length] For each sequence, indices into a |
| set of unique labels as computed by calling unique. |
| states: tensor of shape [frames, batch, label_length] Log probabilities for |
| each label state. |
| |
| Returns: |
| tensor of shape [frames, batch_size, label_length], log probabilities summed |
| for each unique label of the sequence. |
| """ |
| |
| with ops.name_scope("sum_states"): |
| idx = ops.convert_to_tensor(idx, name="idx") |
| num_states = _get_dim(states, 2) |
| states = array_ops.expand_dims(states, axis=2) |
| one_hot = array_ops.one_hot( |
| idx, |
| depth=num_states, |
| on_value=0.0, |
| off_value=math_ops.log(0.0), |
| axis=1) |
| return math_ops.reduce_logsumexp(states + one_hot, axis=-1) |
| |
| |
| def _forward_backward_log(state_trans_log_probs, initial_state_log_probs, |
| final_state_log_probs, observed_log_probs, |
| sequence_length): |
| """Forward-backward algorithm computed in log domain. |
| |
| Args: |
    state_trans_log_probs: tensor of shape [states, states] or, if the
      transition matrix differs per batch element, [batch_size, states, states]
| initial_state_log_probs: tensor of shape [batch_size, states] |
| final_state_log_probs: tensor of shape [batch_size, states] |
| observed_log_probs: tensor of shape [frames, batch_size, states] |
| sequence_length: tensor of shape [batch_size] |
| |
| Returns: |
| forward backward log probabilities: tensor of shape [frames, batch, states] |
| log_likelihood: tensor of shape [batch_size] |
| |
| Raises: |
| ValueError: If state_trans_log_probs has unknown or incorrect rank. |
| """ |
| |
| if state_trans_log_probs.shape.ndims == 2: |
| perm = [1, 0] |
| elif state_trans_log_probs.shape.ndims == 3: |
| perm = [0, 2, 1] |
| else: |
| raise ValueError( |
| "Rank of argument `state_trans_log_probs` must be known and equal to " |
| f"2 or 3. Received state_trans_log_probs={state_trans_log_probs} of " |
| f"rank {state_trans_log_probs.shape.ndims}") |
| |
| bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm) |
| batch_size = _get_dim(observed_log_probs, 1) |
| |
| def _forward(state_log_prob, obs_log_prob): |
| state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast. |
| state_log_prob += state_trans_log_probs |
| state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1) |
| state_log_prob += obs_log_prob |
| log_prob_sum = math_ops.reduce_logsumexp( |
| state_log_prob, axis=-1, keepdims=True) |
| state_log_prob -= log_prob_sum |
| return state_log_prob |
| |
| fwd = _scan( |
| _forward, observed_log_probs, initial_state_log_probs, inclusive=True) |
| |
| def _backward(accs, elems): |
| """Calculate log probs and cumulative sum masked for sequence length.""" |
| state_log_prob, cum_log_sum = accs |
| obs_log_prob, mask = elems |
| state_log_prob += obs_log_prob |
| state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast. |
| state_log_prob += bwd_state_trans_log_probs |
| state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1) |
| |
| log_prob_sum = math_ops.reduce_logsumexp( |
| state_log_prob, axis=-1, keepdims=True) |
| state_log_prob -= log_prob_sum |
| |
| cum_log_sum += array_ops.squeeze(log_prob_sum, axis=[-1]) * mask |
| batched_mask = array_ops.expand_dims(mask, axis=1) |
| out = state_log_prob * batched_mask |
| out += final_state_log_probs * (1.0 - batched_mask) |
| return out, cum_log_sum |
| |
| zero_log_sum = array_ops.zeros([batch_size]) |
| maxlen = _get_dim(observed_log_probs, 0) |
| mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32) |
| mask = array_ops.transpose(mask, perm=[1, 0]) |
| |
| bwd, cum_log_sum = _scan( |
| _backward, (observed_log_probs, mask), |
| (final_state_log_probs, zero_log_sum), |
| reverse=True, |
| inclusive=True) |
| |
| fwd_bwd_log_probs = fwd[1:] + bwd[1:] |
| fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp( |
| fwd_bwd_log_probs, axis=2, keepdims=True) |
| fwd_bwd_log_probs -= fwd_bwd_log_probs_sum |
| fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2)) |
| |
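  # Each backward step is normalized per frame, so the total log likelihood is
  # the backward value of the start state at frame 0 plus the accumulated
  # per-frame normalizers.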
| log_likelihood = bwd[0, :, 0] + cum_log_sum[0] |
| |
| return fwd_bwd_log_probs, log_likelihood |
| |
| |
| # TODO(tombagby): This is currently faster for the ctc implementation than using |
| # functional_ops.scan, but could be replaced by that or something similar if |
| # things change. |
| def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False): |
| """Repeatedly applies callable `fn` to a sequence of elements. |
| |
  Implemented with functional_ops.While; TPU friendly; no gradient is defined.
| |
| This is similar to functional_ops.scan but significantly faster on tpu/gpu |
| for the forward backward use case. |
| |
| Examples: |
| scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0] |
| |
| Multiple accumulators: |
| scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0)) |
| |
| Multiple inputs: |
| scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0) |
| |
| Args: |
| fn: callable, fn(accumulators, element) return new accumulator values. The |
| (possibly nested) sequence of accumulators is the same as `initial` and |
| the return value must have the same structure. |
| elems: A (possibly nested) tensor which will be unpacked along the first |
| dimension. The resulting slices will be the second argument to fn. The |
| first dimension of all nested input tensors must be the same. |
| initial: A tensor or (possibly nested) sequence of tensors with initial |
| values for the accumulators. |
    reverse: (optional) If True, scans and outputs elements in reverse order.
| inclusive: (optional) True includes the initial accumulator values in the |
| output. Length of output will be len(elem sequence) + 1. Not meaningful if |
| final_only is True. |
| final_only: (optional) When True, return only the final accumulated values, |
| not the concatenation of accumulated values for each input. |
| |
| Returns: |
| A (possibly nested) sequence of tensors with the results of applying fn |
| to tensors unpacked from elems and previous accumulator values. |
| """ |
| |
| flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)] |
| num_elems = array_ops.shape(flat_elems[0])[0] |
| pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x) |
| flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)] |
| pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x) |
| accum_dtypes = [x.dtype for x in flat_initial] |
| num_accums = len(flat_initial) |
| |
| # Types for counter, [outputs], [accumulators] loop arguments. |
| if final_only: |
| loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes |
| else: |
| loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes |
| |
| # TODO(tombagby): Update to tfe.defun |
| def cond(i, num_elems, *args): |
| del args |
| return i >= 0 if reverse else i < num_elems |
| |
| # The loop *args are [output tensors] + [accumulator tensors] which must |
| # be paired. Each output corresponds to one accumulator. |
| def body(i, num_elems, *args): |
| """Loop body.""" |
| i.set_shape([]) |
| if final_only: |
| accum = args |
| else: |
| out, accum = args[:num_accums], args[num_accums:] |
| slices = [array_ops.gather(e, i) for e in flat_elems] |
| accum = fn(pack(accum), pack_elems(slices)) |
| flat_accum = nest.flatten(accum) |
| if final_only: |
| new_out = [] |
| else: |
| update_i = i + 1 if inclusive and not reverse else i |
| new_out = [ |
| inplace_ops.alias_inplace_update(x, update_i, y) |
| for x, y in zip(out, flat_accum) |
| ] |
| i = i - 1 if reverse else i + 1 |
| return [i, num_elems] + new_out + flat_accum |
| |
| init_i = ( |
| array_ops.shape(flat_elems[0])[0] - |
| 1 if reverse else constant_op.constant(0, dtype=dtypes.int32)) |
| outputs = [] |
| if not final_only: |
| num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0) |
| for initial_accum in flat_initial: |
| out_shape = array_ops.concat( |
| [[num_outputs], array_ops.shape(initial_accum)], 0) |
| out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True) |
| if inclusive: |
| out = inplace_ops.alias_inplace_add(out, init_i + (1 if reverse else 0), |
| initial_accum) |
| outputs.append(out) |
| loop_in = [init_i, num_elems] + outputs + flat_initial |
| hostmem = [ |
| i for i, x in enumerate(loop_in) |
| if x.dtype.base_dtype in (dtypes.int32, dtypes.int64) |
| ] |
| |
| if context.executing_eagerly(): |
| loop_results = loop_in |
| while cond(*loop_results): |
| loop_results = body(*loop_results) |
| else: |
| # TODO(tombagby): Update to while_v2. |
| cond = function.Defun(*loop_dtypes)(cond) |
| body = function.Defun(*loop_dtypes)(body) |
| loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem) |
| out = loop_results[2:num_accums + 2] |
| return pack(out) |
| |
| |
| def _get_dim(tensor, i): |
| """Get value of tensor shape[i] preferring static value if available.""" |
| return tensor_shape.dimension_value( |
| tensor.shape[i]) or array_ops.shape(tensor)[i] |