
ssnt-loss

A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction" https://arxiv.org/abs/1609.08194.

Usage

There are two versions: a normal version and a memory-efficient version. They should give the same output; please let me know if they don't.

def ssnt_loss_mem(
    log_probs: Tensor,
    targets: Tensor,
    source_lengths: Tensor,
    target_lengths: Tensor,
    emit_logits: Optional[Tensor] = None,
    emit_probs: Optional[Tensor] = None,
    neg_inf: float = -1e4,
    reduction="none",
    fastemit_lambda=0
):
    """The memory efficient implementation concatenates along the targets
    dimension to reduce wasted computation on padding positions.

    N is the minibatch size
    T is the maximum number of output labels
    S is the maximum number of input frames
    V is the size of the label vocabulary
    T_flat is the sum of the lengths of all output label sequences

    Assuming the original tensor has shape (N, T, ...), it should be reduced to
    (T_flat, ...). This can be obtained by using a target mask.
    For example:
        >>> target_mask = targets.ne(pad)   # (N, T)
        >>> targets = targets[target_mask]  # (T_flat,)
        >>> log_probs = log_probs[target_mask]  # (T_flat, S, V)

    Args:
        log_probs (Tensor): (T_flat, S, V) Word prediction log-probs, should be output of log_softmax.
        targets (Tensor): (T_flat,) target labels for all samples in the minibatch.
        source_lengths (Tensor): (N,) Length of the source frames for each sample in the minibatch.
        target_lengths (Tensor): (N,) Length of the target labels for each sample in the minibatch.
        emit_logits, emit_probs (Tensor, optional): (T_flat, S) Emission logits (before sigmoid) or
            probs (after sigmoid). If both are provided, the logits are used.
        neg_inf (float, optional): The constant representing -inf used for masking.
            Default: -1e4
        reduction (string, optional): Specifies the reduction applied to the output.
            Supports "mean" and "sum". Default: "none".
        fastemit_lambda (float, optional): Scale the gradient of emission paths to
            encourage low latency. https://arxiv.org/pdf/2010.11148.pdf
            Default: 0
    """

Minimal example

python example.py
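
For reference, here is a minimal, hedged sketch of the call pattern. This is not the repository's example.py; the import path from ssnt_loss import ssnt_loss_mem, the shapes, and the random inputs are assumptions, while the argument layout follows the docstring above.

# Minimal usage sketch; import path, shapes and inputs are assumptions.
import torch
from ssnt_loss import ssnt_loss_mem  # assumed import path

N, T, S, V, pad = 2, 5, 7, 32, 0
target_lengths = torch.tensor([5, 3])
source_lengths = torch.tensor([7, 6])

# targets padded to the maximum length T with `pad`
targets = torch.randint(1, V, (N, T))
targets[1, 3:] = pad

# word prediction log-probs and emission logits for every (target, source) position
log_probs = torch.randn(N, T, S, V).log_softmax(dim=-1)
emit_logits = torch.randn(N, T, S)

# flatten the target dimension as described in the docstring
target_mask = targets.ne(pad)                # (N, T)
loss = ssnt_loss_mem(
    log_probs[target_mask],                  # (T_flat, S, V)
    targets[target_mask],                    # (T_flat,)
    source_lengths,
    target_lengths,
    emit_logits=emit_logits[target_mask],    # (T_flat, S)
    reduction="sum",
)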

Note

ℹī¸ This is a WIP project. the implementation is still being tested.

  • This implementation is based on the parallelized cumsum and cumprod operations proposed for monotonic attention. Since the alignments in SSNT and monotonic attention are almost identical, we can infer that the forward variable alpha(i,j) of SSNT can be computed similarly (see the sketch after this list).
  • Run the tests with python test.py (requires pip install expecttest).
  • Feel free to contact me if there are bugs in the code.
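
For reference, below is a sketch of the parallelized recurrence used in monotonic attention (Raffel et al., 2017), written with cumsum and cumprod. It is not the code used in this repository; per the note above, the SSNT forward variable alpha(i,j) can be computed analogously.

import torch

def parallel_alpha_step(alpha_prev, p, eps=1e-10):
    # One target step of the monotonic-attention recurrence
    #   alpha[i, j] = p[i, j] * sum_{k <= j} alpha[i-1, k] * prod_{k <= l < j} (1 - p[i, l])
    # computed with cumsum/cumprod instead of a loop over the source axis.
    #   alpha_prev: (N, S) forward variable at target step i-1
    #   p:          (N, S) emission probabilities at target step i
    one = torch.ones_like(p[:, :1])
    # exclusive cumprod: c[:, j] = prod_{l < j} (1 - p[:, l]), with c[:, 0] = 1
    cprod = torch.cumprod(torch.cat([one, 1.0 - p[:, :-1]], dim=1), dim=1)
    return p * cprod * torch.cumsum(alpha_prev / cprod.clamp(min=eps), dim=1)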

Reference

  • Lei Yu, Jan Buys, Phil Blunsom. "Online Segment to Segment Neural Transduction." EMNLP 2016. https://arxiv.org/abs/1609.08194
  • Jiahui Yu et al. "FastEmit: Low-Latency Streaming ASR with Sequence-Level Emission Regularization." https://arxiv.org/abs/2010.11148
  • Colin Raffel et al. "Online and Linear-Time Attention by Enforcing Monotonic Alignments." ICML 2017.
