lmtool-fwp's Introduction

PyTorch Language Modeling Toolkit (for Fast Weight Programmers)

This repository contains the official code used for the language modeling experiments in the two papers listed in the BibTex section below.

More generally, this can be used as a language modeling toolkit in PyTorch to experiment with:

  • Standard Transformers

  • Transformer-XL

  • Fast Weight Programmers with different update rules and linear attention functions:

    • Update rules: "sum" and our "delta" rule (as proposed in our paper; Sec 4.2)
    • Linear attention functions: "ELU-based" linear attention, "FAVOR+", "deterministic parameter-free projection (DPFP)"

    e.g. some combinations result in well known models:

Fast Weight Implementations

This repository contains two implementations of fast weights: a custom CUDA kernel and a torch.autograd.Function version.

We used only the CUDA implementation for all our final experiments (it is faster and achieves much better GPU utilization), but the torch.autograd.Function version can be useful for quick prototyping of new extensions.
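For prototyping, the "sum" and "delta" update rules listed above can be written as a few lines of plain PyTorch. The following is a minimal single-head sketch, not the code in this repository: the function names `elu_feature_map` and `fast_weight_step` are hypothetical, the ELU-based feature map is taken as elu(x) + 1, and any normalization of the feature vectors or of the attention output is omitted for brevity.

```python
import torch
import torch.nn.functional as F


def elu_feature_map(x):
    """ELU-based feature map elu(x) + 1, which keeps every entry positive."""
    return F.elu(x) + 1.0


def fast_weight_step(W, k, v, q, beta, rule="delta"):
    """One time step of a single-head fast weight memory (illustrative only).

    W: (d_value, d_key) fast weight matrix
    k, q: (d_key,) key and query; v: (d_value,) value
    beta: write strength in [0, 1]; only used by the delta rule
    """
    phi_k = elu_feature_map(k)
    phi_q = elu_feature_map(q)
    if rule == "sum":
        # Purely additive write: add the outer product of value and key.
        W = W + v[:, None] * phi_k[None, :]
    elif rule == "delta":
        # Retrieve the value currently stored under this key, then move it
        # towards the new value by a fraction beta.
        v_old = W @ phi_k
        W = W + beta * (v - v_old)[:, None] * phi_k[None, :]
    else:
        raise ValueError(f"unknown update rule: {rule}")
    # Read from the updated memory with the query.
    y = W @ phi_q
    return W, y


if __name__ == "__main__":
    torch.manual_seed(0)
    d_key, d_value = 4, 3
    W = torch.zeros(d_value, d_key)
    for _ in range(5):
        k, q, v = torch.randn(d_key), torch.randn(d_key), torch.randn(d_value)
        beta = torch.sigmoid(torch.randn(()))  # stand-in for a learned gate
        W, y = fast_weight_step(W, k, v, q, beta, rule="delta")
    print(y.shape)  # torch.Size([3])
```

A step-by-step loop like this is slow compared with a fused kernel, which is the trade-off described above: it is easy to modify, while the CUDA implementation is the one to use for real training runs.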

Requirements

This toolkit requires PyTorch (torch) and Ninja (ninja), the latter to compile the CUDA kernels.

The experiments for the paper were conducted with Python 3.6 and PyTorch 1.4.0 (note on Aug 24, 2023: the code also works with Python 3.11 and PyTorch 2.0.1+cu117).

More recent versions of PyTorch are not yet well supported by this toolkit, which still uses torch.nn.DataParallel for multi-GPU training. If you really need a more recent version of PyTorch, check the PyTorch documentation on using torch.nn.parallel.DistributedDataParallel instead. We hope to fix this soon, but cannot say exactly when.
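For readers unfamiliar with torch.nn.DataParallel, the sketch below shows the single-process wrapping pattern it provides. The `nn.Linear` layer is a hypothetical stand-in for a real model and this is not the toolkit's training code; the snippet falls back to a single device when fewer than two GPUs are visible.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for a language model.
model = nn.Linear(8, 2)

if torch.cuda.device_count() > 1:
    # One process drives all visible GPUs; each batch is split along dim 0,
    # run on a replica per GPU, and the outputs are gathered on GPU 0.
    model = nn.DataParallel(model).cuda()
elif torch.cuda.is_available():
    model = model.cuda()

device = next(model.parameters()).device
x = torch.randn(16, 8, device=device)
out = model(x)
print(out.shape)  # torch.Size([16, 2])
```

DistributedDataParallel instead uses one process per GPU and needs a process group to be initialized first, so switching is more than a one-line change.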

The toolkit supports Weights & Biases for monitoring jobs. If you use it, also install wandb.
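Before launching a job, it can help to confirm that these dependencies are visible from the active Python environment. The script below is a small sketch using only the standard library; the function names `missing_modules` and `check_environment` are hypothetical and not part of this toolkit.

```python
import importlib.util
import shutil


def missing_modules(names):
    """Return the names that cannot be found as importable Python modules."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def check_environment():
    """Collect a small report; nothing is installed or modified."""
    return {
        "missing_required_modules": missing_modules(["torch"]),
        "missing_optional_modules": missing_modules(["wandb"]),
        # PyTorch's JIT extension builder invokes the `ninja` executable,
        # so look for it on PATH rather than as a Python module.
        "ninja_on_path": shutil.which("ninja") is not None,
    }


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

An empty list for the required modules and `ninja_on_path: True` indicate that the basic requirements are in place; wandb only matters if Weights & Biases monitoring is used.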

Acknowledgements

This repository contains many lines of code taken and adapted from the following sources:

  • This repository was originally forked from the official implementation of Transformer-XL (kimiyoung/transformer-xl). The code for Transformer-XL and standard Transformer models, as well as the basic functionality needed for language modeling (including adaptive input and output embeddings) and data preparation (WikiText-103, enwik8, ...), comes from that repository.
  • For Performers, helper functions from lucidrains/performer-pytorch are used.
  • For cuda implementations of our fast weight programmers with the delta rule:
    • Code from idiap/fast-transformers is used with minor changes for the sum update rule.
    • We modified it to implement our update rule. See comments in code for exact locations and modifications.

General Instructions

Please check files under example_scripts for general instructions and examples to train and evaluate models.

BibTex

@inproceedings{schlag2021linear,
      title={Linear Transformers Are Secretly Fast Weight Programmers}, 
      author={Imanol Schlag and Kazuki Irie and J\"urgen Schmidhuber},
      booktitle={Proc. Int. Conf. on Machine Learning (ICML)},
      address = {Virtual only},
      month = jul,
      year={2021}
}
@article{irie2021going,
      title={Going Beyond Linear Transformers with Recurrent Fast Weight Programmers}, 
      author={Kazuki Irie and Imanol Schlag and R\'obert Csord\'as and J\"urgen Schmidhuber},
      journal={Preprint arXiv:2106.06295},
      year={2021}
}

Contributors

ischlag, kazuki-irie

lmtool-fwp's Issues

Cannot run example with Python 3.10 and PyTorch 1.11 in a virtual environment

I cannot run the provided example. The code depends on a very old version of PyTorch (1.4), so I installed 1.11 instead and tried to run it inside a Python virtual environment. This, however, seems to prevent the import of the JIT-compiled module, and I get these errors instead:

     Loading extension module fast_weight_forward... 
Traceback (most recent call last):
  File "<...redacted>/lmtool-fwp/src/train.py", line 18, in <module>
    from model_main import MemTransformerLM
  File "<...redacted>/lmtool-fwp/src/model_main.py", line 19, in <module>
    from utils.cuda_fast_weight_layer import CudaFastWeightLinearTransformerLayer
  File "<...redacted>/lmtool-fwp/src/utils/cuda_fast_weight_layer.py", line 9, in <module>
    from utils.fast_fast_weight import fast_weight_delta
  File "<...redacted>/lmtool-fwp/src/utils/fast_fast_weight/__init__.py", line 17, in <module>
    mod_causal_dot_product_cuda = load(
  File "<...redacted>/lmtool-fwp/env/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1144, in load
    return _jit_compile(
  File "<...redacted>/lmtool-fwp/env/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1382, in _jit_compile
    return _import_module_from_library(name, build_directory, is_python_module)
  File "<...redacted>/lmtool-fwp/env/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1776, in _import_module_from_library
    assert isinstance(spec.loader, importlib.abc.Loader)
AttributeError: module 'importlib' has no attribute 'abc'. Did you mean: '_abc'?

Do you have any insight on how to proceed?
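One possible workaround, offered as an assumption rather than a fix confirmed by the maintainers: the last frame of the traceback accesses `importlib.abc`, but `import importlib` on its own does not guarantee that the `abc` submodule has been loaded. Importing the submodule explicitly before the extension is built makes the attribute available. The placement suggested in the comment (top of src/train.py) is likewise an assumption based on the traceback.

```python
# Place these lines before any import that triggers the JIT build,
# for example at the very top of src/train.py.
import importlib
import importlib.abc  # explicit submodule import, so importlib.abc resolves
import importlib.util

# After the explicit import, the attribute used in the failing assert exists.
assert hasattr(importlib, "abc")
print(importlib.abc.Loader)
```

Alternatively, the requirements note above states that the code also works with Python 3.11 and PyTorch 2.0.1+cu117, so upgrading PyTorch may avoid the problem altogether.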
