PyTorch Beam Search

This library implements fully vectorized Beam Search, Greedy Search and Sampling for sequence models written in PyTorch. This is especially useful for tasks in Natural Language Processing, but it can also be used for anything that requires generating a sequence from a sequence model.

Usage

A GPT-like character-level language model

from pytorch_beam_search import autoregressive

# Create vocabulary and examples
# Tokenize the way you need

corpus = list("abcdefghijklmnopqrstwxyz ")
# len(corpus) == 25
# An Index object represents a mapping from the vocabulary
# to integers (indices) to feed into the models
index = autoregressive.Index(corpus)
n_gram_size = 17  # 16 tokens of context plus 1 token to predict
n_grams = [corpus[i:i + n_gram_size] for i in range(len(corpus) - n_gram_size + 1)]

# Create tensor

X = index.text2tensor(n_grams)
# X.shape == (n_examples, len_examples) == (9, 17), since 25 - 17 + 1 == 9

# Create and train the model

model = autoregressive.TransformerEncoder(index)  # just a PyTorch model
model.fit(X)  # basic method included

# Generate new predictions

new_examples = [list("new first"), list("new second")]
X_new = index.text2tensor(new_examples)
loss, error_rate = model.evaluate(X_new)  # basic method included
predictions, log_probabilities = autoregressive.beam_search(model, X_new)
# every element in predictions is the list of candidates for each example
output = [index.tensor2text(p) for p in predictions]
output
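
Since beam_search returns several candidates per example, a common follow-up is to keep only the top-scoring one. A minimal sketch, assuming log_probabilities holds one score per candidate, aligned with the candidate lists in output (check the function's docstring for the exact shapes):

# pick the candidate with the highest log-probability for each example
best = [candidates[scores.argmax().item()]
        for candidates, scores in zip(output, log_probabilities)]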

A Transformer character sequence-to-sequence model

from pytorch_beam_search import seq2seq

# Create vocabularies
# Tokenize the way you need

source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
# An Index object represents a mapping from the vocabulary
# to integers (indices) to feed into the models
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)

# Create tensors

X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)
# X.shape == (n_source_examples, len_source_examples) == (2, 11)
# Y.shape == (n_target_examples, len_target_examples) == (2, 12)

# Create and train the model

model = seq2seq.Transformer(source_index, target_index)  # just a PyTorch model
model.fit(X, Y, epochs=100)  # basic method included

# Generate new predictions

new_source = [list("new first in"), list("new second in")]
new_target = [list("new first out"), list("new second out")]
X_new = source_index.text2tensor(new_source)
Y_new = target_index.text2tensor(new_target)
loss, error_rate = model.evaluate(X_new, Y_new)  # basic method included
predictions, log_probabilities = seq2seq.beam_search(model, X_new)
output = [target_index.tensor2text(p) for p in predictions]
output
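
Because every model in the library is a plain PyTorch module, the standard state_dict round-trip works for persistence. A minimal sketch ("transformer.pt" is an arbitrary file name):

import torch

torch.save(model.state_dict(), "transformer.pt")  # persist the trained weights
model.load_state_dict(torch.load("transformer.pt"))  # restore them later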

Features

Algorithms

  • The greedy_search function implements Greedy Search, which simply picks the most likely token at every step. This is the fastest and simplest algorithm, and it can work well if the model is properly trained.
  • The sample function implements sampling from a sequence model, drawing each output token from the learned next-token distribution. This is very useful for inspecting what the model has learned.
  • The beam_search function implements Beam Search, a form of pruned breadth-first search that expands a fixed number of the best candidates at every step. This is the slowest algorithm, but it usually outperforms Greedy Search. A sketch of calling all three functions follows this list.
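
A minimal sketch of calling the three functions on the autoregressive model trained in Usage. The (model, X_new) calling convention is taken from the beam_search example above; greedy_search and sample are assumed to accept the same arguments and return (predictions, log_probabilities) in the same way, so check the docstrings for the exact signatures:

from pytorch_beam_search import autoregressive

# assumed to mirror the beam_search call shown in Usage
greedy_predictions, greedy_log_probs = autoregressive.greedy_search(model, X_new)
sampled_predictions, sampled_log_probs = autoregressive.sample(model, X_new)
beam_predictions, beam_log_probs = autoregressive.beam_search(model, X_new)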

Models

  • The autoregressive module implements the search algorithms and some architectures for unsupervised models that learn to predict the next token in a sequence.
    • LSTM is a simple baseline/sanity check.
    • TransformerEncoder is a GPT-like model for state-of-the-art performance.
  • The seq2seq module implements the search algorithms and some architectures for supervised encoder-decoder models that learn how to map sequences to sequences.
    • LSTM is a sequence-to-sequence unidirectional LSTM model similar to the one in Cho et al., 2014, useful as a simple baseline/sanity check.
    • ReversingLSTM is a sequence-to-sequence unidirectional LSTM model that reverses the order of the tokens in the input, similar to the one in Sutskever et al., 2014. A bit more complex than LSTM, but it gives better performance.
    • Transformer is a standard Transformer model for state-of-the-art performance. A sketch of instantiating each architecture follows this list.
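
A sketch of instantiating each architecture, assuming every class follows the constructor pattern shown in Usage (autoregressive models take one Index, seq2seq models take a source and a target Index); check the class signatures for the available options:

from pytorch_beam_search import autoregressive, seq2seq

lm_baseline = autoregressive.LSTM(index)  # simple baseline language model
lm_gpt_like = autoregressive.TransformerEncoder(index)  # GPT-like model

s2s_baseline = seq2seq.LSTM(source_index, target_index)
s2s_reversing = seq2seq.ReversingLSTM(source_index, target_index)
s2s_transformer = seq2seq.Transformer(source_index, target_index)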

Installation

pip install pytorch_beam_search
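
After installing, a quick sanity check is to import the two modules the examples above use:

from pytorch_beam_search import autoregressive, seq2seq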

Contribute

License

The project is licensed under the MIT License.

pytorch_beam_search's People

Contributors

jarobyte91, risangbaskoro

pytorch_beam_search's Issues

torch version pinned to 1.8.1 is not optimal

The setup file pins the version of pytorch to 1.8.1.
I don't think any of the APIs used require a very specific pytorch release (would 2.0.1 not work equally well?).
Pinning to 1.8.1 is too specific to support wide use of this package in different repos (some of them may already be on torch ^2.0.0, for instance, and a 1.8.1 pin does not allow 2.0).

Would it be possible to release a version that is a bit less specific about the version of torch it can work with?

Such a release on PyPI would be great.
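
For illustration, a relaxed specifier of the kind this issue asks for could look like the following in setup.py; this is a hypothetical sketch, and the project's actual setup file may be laid out differently:

from setuptools import setup, find_packages

setup(
    name="pytorch_beam_search",
    version="1.2.3",  # hypothetical version number
    packages=find_packages(),
    install_requires=["torch>=1.8"],  # no upper bound, so torch 2.x is allowed
)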

Length Penalty with Beam Search

Hi, great work on the beam search implementation. I was going through your code to understand your implementation and noticed that you (a) don't terminate a beam when the EOS token is predicted by the language model, and (b) don't apply a length penalty to the probabilities.

Am I right in noticing these, or have I missed something? I would be glad if you could give me an idea of how I can implement these two features. Thank you.
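
For reference, a library-agnostic sketch of the two missing pieces (this is not the library's code). The length penalty follows the GNMT formula of Wu et al., 2016, which divides each beam's cumulative log-probability by ((5 + length) / 6) ** alpha before ranking; EOS handling freezes finished beams by forcing all probability mass onto EOS so their scores stop changing:

import torch

def length_penalty(length, alpha=0.6):
    # GNMT-style normalizer: longer sequences get a larger divisor
    return ((5.0 + length) / 6.0) ** alpha

def rank_beams(scores, lengths, alpha=0.6):
    # scores: (n_beams,) cumulative log-probabilities
    # lengths: list of generated-token counts, one per beam
    normalized = scores / torch.tensor([length_penalty(l, alpha) for l in lengths])
    return torch.argsort(normalized, descending=True)

def mask_finished(log_probs, finished, eos_index):
    # log_probs: (n_beams, vocab_size) next-token log-probabilities
    # finished: (n_beams,) boolean mask of beams that already emitted EOS
    log_probs = log_probs.clone()
    log_probs[finished] = float("-inf")  # forbid every continuation...
    log_probs[finished, eos_index] = 0.0  # ...except repeating EOS (log 1 == 0)
    return log_probs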

How to implement beam search from this library in the context of a chatbot model (link in the comments)

Hi there,

Here is the code for the chatbot: https://github.com/pytorch/tutorials/blob/master/beginner_source/chatbot_tutorial.py.

As you'll see there is a GreedySearchDecoder class. All I want to do is create a BeamSearchDecoder class using beam search from this library. Any ideas?

To be clear, the model is already trained, I'm just trying to evaluate it with a different search method, beam search.

Thank you very much.
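
For context, a minimal, library-agnostic beam search sketch in the spirit of the tutorial's GreedySearchDecoder. The step callable is hypothetical: it should map the tokens generated so far, shape (n_beams, length), to next-token log-probabilities, shape (n_beams, vocab_size), so you would wrap your trained chatbot decoder's forward pass in it:

import torch

def beam_search_decode(step, sos_index, max_length=10, beam_width=3):
    beams = torch.full((1, 1), sos_index, dtype=torch.long)  # start from SOS
    scores = torch.zeros(1)  # cumulative log-probability of each beam
    for _ in range(max_length):
        log_probs = step(beams)  # (n_beams, vocab_size)
        vocab_size = log_probs.size(-1)
        # add each beam's score to all of its continuations, then keep
        # the best beam_width (beam, token) pairs overall
        candidates = scores.unsqueeze(1) + log_probs
        scores, flat = candidates.flatten().topk(beam_width)
        beam_ids = flat // vocab_size
        token_ids = flat % vocab_size
        beams = torch.cat([beams[beam_ids], token_ids.unsqueeze(1)], dim=1)
    return beams, scores  # beam_width candidate replies and their scores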

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.