
long-range-arena's Introduction

Long-Range Arena (LRA: pronounced ELRA).

Long-range arena is an effort toward systematic evaluation of efficient transformer models. The project aims at establishing benchmark tasks/datasets with which we can evaluate transformer-based models in a systematic way, by assessing their generalization power, computational efficiency, memory footprint, etc.

Long-range arena also implements different variants of Transformer models in JAX, using Flax.

This initial release includes the benchmarks for the paper "Long Range Arena: A Benchmark for Efficient Transformers".

Currently we have released all the necessary code to get started and run our benchmarks on vanilla Transformers.

V2 release

Update: We have released the xformer models used in our experiments.

We are working on a 2nd update that will release more models and baselines for this benchmark suite. Stay tuned.

Please see below for more examples on how to get started.

Our experiments

Current leaderboard results of all xformer models on our benchmark (as of 8th November 2020):

Model            ListOps  Text   Retrieval  Image  Path   Path-X  Avg
Local Att        15.82    52.98  53.39      41.46  66.63  FAIL    46.06
Linear Trans.    16.13    65.90  53.09      42.34  75.30  FAIL    50.55
Reformer         37.27    56.10  53.40      38.07  68.50  FAIL    50.67
Sparse Trans.    17.07    63.58  59.59      44.24  71.71  FAIL    51.24
Sinkhorn Trans.  33.67    61.20  53.83      41.23  67.45  FAIL    51.29
Linformer        35.70    53.94  52.27      38.56  76.34  FAIL    51.36
Performer        18.01    65.40  53.82      42.77  77.05  FAIL    51.41
Synthesizer      36.99    61.68  54.67      41.61  69.45  FAIL    52.88
Longformer       35.63    62.85  56.89      42.22  69.71  FAIL    53.46
Transformer      36.37    64.27  57.46      42.44  71.40  FAIL    54.39
BigBird          36.05    64.02  59.29      40.83  74.87  FAIL    55.01

Public External Entries

We list the entries of other papers and submissions that used our LRA benchmark.

Model            ListOps  Text   Retrieval  Image  Path   Path-X  Avg
IGLOO            39.23    82     75.5       47.0   67.50  NA      62.25
TLB              37.05    81.88  76.91      57.51  79.06  FAIL    66.48

IGLOO Submissions (by Vsevolod Sourkov) - https://github.com/redna11/lra-igloo
TLB (Temporal Latent Bottleneck) - transformer_tlb
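As a quick arithmetic check, the Avg column in both tables appears to be the mean of the five tasks that have numeric scores, which means Path-X (FAIL/NA) is left out. The sketch below recomputes a few rows from the published numbers. The helper name `lra_average` is ours and is not part of the codebase.

```python
# A few rows copied from the leaderboard tables above:
# (ListOps, Text, Retrieval, Image, Path, Path-X)
ROWS = {
    "Local Att": (15.82, 52.98, 53.39, 41.46, 66.63, "FAIL"),
    "Transformer": (36.37, 64.27, 57.46, 42.44, 71.40, "FAIL"),
    "BigBird": (36.05, 64.02, 59.29, 40.83, 74.87, "FAIL"),
    "IGLOO": (39.23, 82.0, 75.5, 47.0, 67.50, "NA"),
    "TLB": (37.05, 81.88, 76.91, 57.51, 79.06, "FAIL"),
}


def lra_average(scores):
    """Mean of the numeric task scores, skipping non-numeric FAIL/NA entries."""
    numeric = [s for s in scores if isinstance(s, (int, float))]
    return round(sum(numeric) / len(numeric), 2)


for name, scores in ROWS.items():
    print(f"{name:<12} {lra_average(scores):.2f}")
```

Each printed value matches the Avg column of the corresponding row.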

Citation

If you find our work useful, please cite our paper:

@inproceedings{
tay2021long,
title={Long Range Arena: A Benchmark for Efficient Transformers},
author={Yi Tay and Mostafa Dehghani and Samira Abnar and Yikang Shen and Dara Bahri and Philip Pham and Jinfeng Rao and Liu Yang and Sebastian Ruder and Donald Metzler},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=qVyeW-grC2k}
}

**Note: Please also cite the original sources of these datasets!**

Adding results to the leaderboard.

Please send the link to your paper (arXiv or published) to Yi Tay or Mostafa Dehghani (emails in the paper) to have your new results included in the leaderboard. As above, we will add them to the external submissions part of the leaderboard. This is so that we do not encourage hill-climbing on the leaderboard, but rather meaningful side-by-side comparisons.

A note on evaluation and comparisons

Meaningful Comparisons

We intend for our benchmark to act as a tool and suite for inspecting model behaviour. As such, if you are running a new setup and have tuned hparams, consider running all the other models under that setup as well.

Apples-to-Apples Setting

This setting is for folks who want to compare with our published results directly.

Use the default hyperparameter setup (each benchmark should now have a config file). You are not allowed to change hyperparameters such as the embedding size, hidden dimensions, or number of layers of the new model.

The new model should be at most 10% larger, in terms of parameters, than the base Transformer model in the provided config file.
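One way to check the 10% rule is to count the parameters of both models from their parameter trees. The helper below is a hypothetical sketch. It assumes parameters are stored as nested mappings of arrays, as Flax parameter trees are, and the layer names and shapes in the example are made up.

```python
import math
from collections.abc import Mapping


def count_params(tree):
    """Count scalars in a nested mapping of parameters.

    Leaves can be arrays (anything with a `.shape`, such as NumPy or JAX
    arrays) or plain shape tuples like (128, 64).
    """
    if isinstance(tree, Mapping):
        return sum(count_params(v) for v in tree.values())
    shape = tree if isinstance(tree, tuple) else tree.shape
    return math.prod(shape)  # math.prod(()) == 1 for scalar parameters


def within_budget(new_params, base_params, max_increase=0.10):
    """True if the new model has at most `max_increase` more parameters."""
    return count_params(new_params) <= (1 + max_increase) * count_params(base_params)


# Hypothetical trees: the new model adds a small extra projection.
base = {"dense": {"kernel": (128, 512), "bias": (512,)}}
new = {"dense": {"kernel": (128, 512), "bias": (512,)},
       "gate": {"kernel": (128, 32)}}
print(count_params(base), count_params(new), within_budget(new, base))
```

Here the new model is about 6% larger than the base, so it passes. The check is deliberately one-sided, because only growth beyond the budget is disallowed.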

Free-for-all Setting

You are allowed to run any model size and change any hyperparameter of the model. However, you may not then compare against results from our leaderboard, because they are no longer comparable. You can instead rerun models from our library in a comparable setting.

Adding benchmarks or models to this suite

If you are developing new benchmarks, or could benefit from an extensive array of xformer baselines, please let us know. We welcome contributions of new or older models that are not covered in the existing suite.

What if I find a better config for an existing model?

In this paper, we did not prioritize hparam sweeps. If you find an implementation-related issue, or a better hparam setting that lets a model do better on a certain task, please send a PR (or a new config file) and we will rerun the model internally and report new results for the existing model.

I have a new Xyzformer. How do I add it to the benchmark?

The official results are only for code that has been verified and run in our codebase. We report all external submissions as external. Either submit a PR or send an email showing us how to run your model in our codebase, and we will update the results accordingly. (Note: due to bandwidth constraints, this process will take a substantial amount of time.)

Example Usage

To run a task, run the train.py file in the corresponding task directory. (For some tasks you first need to obtain the data; see Dataset Setup below.)

PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/listops/train.py \
      --config=lra_benchmarks/listops/configs/transformer_base.py \
      --model_dir=/tmp/listops \
      --task_name=basic \
      --data_dir=$HOME/lra_data/listops/

Dataset Setup

This section describes the methods to obtain the datasets and run the tasks in LRA.

To download the datasets, please download them from gs://long-range-arena/lra_release. If permissions fail, you may download the entire gzipped file at https://storage.googleapis.com/long-range-arena/lra_release.gz.
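A minimal sketch of the download-and-unpack step, using only the Python standard library. It assumes lra_release.gz is a gzip-compressed tar archive, so that tarfile can open it with compression auto-detected. The function names are ours, and the unpacked directory layout may not match the --data_dir paths used in the example command above.

```python
import tarfile
import urllib.request
from pathlib import Path

LRA_URL = "https://storage.googleapis.com/long-range-arena/lra_release.gz"


def download(url, dest):
    """Download `url` to `dest` unless the file already exists."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    if not dest.exists():
        urllib.request.urlretrieve(url, str(dest))
    return dest


def extract(archive, out_dir):
    """Unpack a tar archive (compression auto-detected) and return member names.

    Assumes lra_release.gz is a gzip-compressed *tar* archive; if it turns out
    to be a single gzipped file, use the gzip module instead.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, mode="r:*") as tar:
        names = tar.getnames()
        tar.extractall(out_dir)
    return names
```

For example, extract(download(LRA_URL, Path.home() / "lra_data" / "lra_release.gz"), Path.home() / "lra_data") downloads the archive once and unpacks it under ~/lra_data.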

ListOps

This task can be found at /listops. The datasets used in our experiments can be found in the Google Cloud bucket above and are in TSV format.

If you would like to use longer or shorter sequence lengths, you can also generate your own split by running the following command:

PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/data/listops.py \
  --output_dir=$HOME/lra_data/listops/
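To inspect a split outside the training pipeline, a small reader like the one below may help. It is a sketch under assumptions: each TSV has a header row followed by two tab-separated columns (the bracketed expression and its integer answer), and file names such as basic_train.tsv are only a guess based on --task_name=basic.

```python
import csv


def read_listops_tsv(path):
    """Yield (expression, answer) pairs from a ListOps TSV split.

    Assumes a header row followed by two tab-separated columns: the bracketed
    expression and its integer answer (0-9).
    """
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t")
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue  # tolerate blank lines
            source, target = row[0], row[1]
            yield source, int(target)
```

Reading lazily with a generator keeps memory flat even for the larger generated splits.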

Text Classification

This task can be found at /text_classification. No action is required because the data is already available in TensorFlow Datasets; the code should run as is.

Document Retrieval

Please download the dataset from http://aan.how/download/, and download the train/test/dev splits from our Google Cloud bucket. Unfortunately, we are not able to redistribute this dataset and are only releasing the IDs, in the format label paper1_id paper2_id. You may download the data from the original source and extract the textual data.
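Because only IDs are released, you need to join them with the paper texts yourself. The sketch below shows one way to do that. It assumes whitespace-separated lines with no header row, and it relies on a hypothetical `texts` dict (paper ID to extracted text) that you build from the original AAN release. Neither function comes from the LRA codebase.

```python
def parse_id_line(line):
    """Split one 'label paper1_id paper2_id' line into (label, id1, id2)."""
    label, id1, id2 = line.split()
    return int(float(label)), id1, id2  # accepts labels written as 1 or 1.0


def attach_text(id_lines, texts):
    """Yield (label, text1, text2) for every ID pair whose texts are available.

    `texts` is a hypothetical dict mapping paper ID -> document text that you
    extract yourself from the original AAN release. Pairs with a missing
    paper are skipped rather than raising.
    """
    for line in id_lines:
        if not line.strip():
            continue
        label, id1, id2 = parse_id_line(line)
        if id1 in texts and id2 in texts:
            yield label, texts[id1], texts[id2]
```

Skipping unmatched pairs silently is a convenience. When reproducing results, it is worth counting how many pairs were dropped.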

Pixel-level Image Classification

This task can be found at /image. No action is required because the data is already available in TensorFlow Datasets. It should work out of the box.

Pathfinder

Please see the ./data directory, where the TFDS builder for the Pathfinder dataset can be found. We generated different datasets for the Pathfinder task, with different levels of difficulty, using the script provided there. You can find the parameters used to generate the data in the TFDS builder code in ./data/pathfinder. We are preparing the exact data splits for release at the moment.

Disclaimer

This is not an official Google product.

long-range-arena's People

Contributors

cifkao, guyd1995, jnhwkim, mostafadehghani, niteshbharadwaj, ppham27, yazdanbakhsh


long-range-arena's Issues

Perceiver on LRA

Did anybody try out the Perceiver model on LRA?
It should be fairly easy to fit into this codebase, as both are written in jax
I'd be super curious to see the results it'll produce

Cannot reproduce the results for cifar10

The same question was raised in
https://github.com/google-research/long-range-arena/issues/16

but for the vanilla Transformer, using the hyperparameters in the reply, I cannot reproduce the results in the LRA paper (2 V100/32GB cards), nor can I with the default settings in the code.
Could you kindly share a ready-to-use set of hyperparameters?

  • my result, test in step: 34999, loss: 1.8261, acc: 0.3633
  • paper result: 42.44 (same with the latest V2 release result)

my hyperparameters from the issue 16:
batch_size: 256
checkpoint_freq: 17500
eval_frequency: 175
factors: constant * linear_warmup * cosine_decay
grad_clip_norm: null
learning_rate: 0.0005
model:
  attention_dropout_rate: 0.2
  classifier_pool: CLS
  dropout_rate: 0.3
  emb_dim: 128
  learn_pos_emb: true
  mlp_dim: 128
  num_heads: 1
  num_layers: 1
  qkv_dim: 64
model_type: transformer
num_eval_steps: 39
num_train_steps: 35000
random_seed: 0
restore_checkpoints: true
save_checkpoints: true
steps_per_cycle: 35000
trial: 0
warmup: 175
weight_decay: 0.0

Dataset for the matching task

Hi,

Is there an available link to download the Document Retrieval dataset? The link provided in the README seems broken.

Thanks

Is there a pytorch equivalent of this implementation?

I searched for a PyTorch implementation of this benchmark and came across several papers that used it as a benchmark for their methods. However, I was unable to find a standardized version. If you have any leads, I would appreciate it. Thank you!

Nested Structure in ListOps

Hi,

According to a previous issue, #6, the square brackets '[', ']' should be preserved during tokenization, resulting in the following set of unique tokens

( ']', '1', '4', '7', '9', '8', '2', '5', '3', '0', '6', '[MIN', '[SM', '[MAX', '[MED']).

However, based on your code (input_pipeline.py), the following set of tokens is used to encode the input

('SM', '2', 'MAX', '8', '5', '0', '4', '1', '7', '6', '3', 'MIN', '9', 'MED').

Evidently, if you run

encoder.decode(x)

where x is an encoded input sample, the result does not contain the square brackets '[', ']'. For instance, the following is the (truncated) result of decoding the first encoded sample in the training set:

MIN 1 MED 8 0 1 3 0 6 1 7 MED SM 4 MED SM 2 9 MED 9 9 2 MIN 8 3 5 4 8 5 6 2 6 0 5 8 2 6 8 9 8 3 4 SM 3 SM MIN MIN MED 5 8 7 9 1 7 1 8 8 MAX 2 5 7 1 1 0 MAX 6 2 5 0 MIN 2 2 4 8 7 1 MIN 1 SM 3 6 4 0 5 5 MED 6 3 8 0 5 6 4 3 7 8 0 6 8 8 3 MAX 5 1 3 5 MED 3 SM 3 2 6 MAX 2 MIN 8 MAX 8 2 5 8 7 1 5 9 0 0 9 MAX 8 1 MAX 5 9 4 1 1 3 2 3 2 2 MIN 1 9 0 2 9 3 8 MAX ....

So I was wondering why the square brackets '[', ']' are ignored. This way, the nested structure is no longer preserved.
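The difference between the two token sets quoted above comes down to whether the tokenizer splits on whitespace or keeps only runs of word characters. The decoded output quoted above is consistent with the latter. The sketch below contrasts the two in plain Python. It is a simplified stand-in, not the TFDS tokenizer used by input_pipeline.py, and the sample expression is made up.

```python
import re

# A made-up ListOps-style expression, for illustration only.
SAMPLE = "[MAX 2 9 [MIN 4 7 ] 0 ]"


def whitespace_tokens(text):
    """Split on whitespace: '[MIN' and ']' survive as tokens."""
    return text.split()


def word_tokens(text):
    """Keep only runs of word characters: every '[' and ']' is dropped."""
    return re.findall(r"\w+", text)


print(whitespace_tokens(SAMPLE))
print(word_tokens(SAMPLE))
```

With word-only tokens, "[MAX 2 9 [MIN 4 7 ] 0 ]" and "[MAX 2 9 [MIN 4 7 0 ] ]" both become MAX 2 9 MIN 4 7 0. Because the operators take a variable number of arguments, the tree structure can no longer be read off the token sequence.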

ValueError: Non-hashable static arguments are not supported

I am trying to train the Transformer on the listops task using the command from the readme, but I get the following error:

I1201 18:31:38.370242 23342722882432 input_pipeline.py:75] Finished processing vocab size=14
Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 288, in <module>
    app.run(main)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/listops/train.py", line 193, in main
    model_kwargs)
jax._src.traceback_util.FilteredStackTrace: ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'dict'>, {'vocab_size': 16, 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': 2000, 'classifier': True, 'num_classes': 10}. The error was:
TypeError: unhashable type: 'dict'

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 288, in <module>
    app.run(main)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/listops/train.py", line 193, in main
    model_kwargs)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 133, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/jax/api.py", line 371, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'dict'>, {'vocab_size': 16, 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': 2000, 'classifier': True, 'num_classes': 10}. The error was:
TypeError: unhashable type: 'dict'

I have jax==0.2.6 and jaxlib==0.1.57+cuda101.
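As the error says, the jitted function in this JAX version requires static arguments to be hashable, and model_kwargs is a dict. Pinning to the JAX release the repository was developed against may avoid the error, although we have not checked which release that is. Another generic workaround is to pass the kwargs as a hashable tuple and rebuild the dict inside the jitted function. The freeze/thaw helpers below are a hypothetical plain-Python sketch of that idea, not code from the LRA repository.

```python
def freeze(kwargs):
    """Convert a flat kwargs dict into a hashable tuple of (key, value) pairs.

    Sorting makes the result independent of insertion order, so equal dicts
    map to equal (and equally hashed) tuples. All values must be hashable.
    """
    return tuple(sorted(kwargs.items()))


def thaw(frozen):
    """Rebuild the dict inside the jitted function."""
    return dict(frozen)


model_kwargs = {"vocab_size": 16, "emb_dim": 512, "num_heads": 8,
                "num_layers": 6, "classifier": True, "num_classes": 10}

try:
    hash(model_kwargs)
except TypeError as err:
    print("dict:", err)  # unhashable type: 'dict'

static_arg = freeze(model_kwargs)
print("tuple is hashable:", isinstance(hash(static_arg), int))
```

The frozen tuple would be passed in place of model_kwargs at the position marked static, and thaw would be called inside the function before building the model.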

Is it possible to include instructions on how to run it on GPUs

This code seems to be using 4x4 TPUs, but since I don't have access to TPUs, I wonder if you could release instructions on how to replicate the results on GPUs. This would make the code more accessible for people without abundant computational resources.

Request about cuda version when using GPUs

Hi, Thanks for your code.

I am trying to run this code on GPUs, but I am struggling a lot with the environment setup.

Can you provide the versions of CUDA, jaxlib, jax, and flax you used when training on GPUs?

Linear Transformer code base

Hello!

Thank you for your work on long-range-arena it's impressive!
My name is Maksim Zubkov, and I am working on improving the Linear Transformer proposed in the paper Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention (Angelos Katharopoulos et al., 2020). I want to compare the results of the vanilla Linear Transformer with those achieved by the model I proposed. In this regard, it would be very useful for my research to get access to the code base of the Linear Transformer you used in LRA. Hence, I have the following questions:

  1. How soon do you plan to publish the code of the Linear Transformer used in the experiments?
  2. Have you changed the mechanism proposed in the original repo in any way, or did you basically change attention_fn in nn.SelfAttention?

Best regards, Maksim
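For reference, the mechanism from Katharopoulos et al. (2020) replaces softmax attention with phi(Q)(phi(K)^T V), normalized by phi(Q)(phi(K)^T 1), where phi(x) = elu(x) + 1. Computing phi(K)^T V first avoids materializing the n-by-n attention matrix. The plain-Python sketch below illustrates that reordering for a single non-causal head and checks it against the quadratic form. It is a textbook illustration, not the LRA implementation.

```python
import math


def phi(vec):
    """Feature map elu(x) + 1, applied elementwise (always positive)."""
    return [x + 1.0 if x > 0 else math.exp(x) for x in vec]


def linear_attention(queries, keys, values):
    """Non-causal linear attention in O(n * d * d_v) time.

    out_i = phi(q_i) . S / (phi(q_i) . z), where
    S = sum_j outer(phi(k_j), v_j) and z = sum_j phi(k_j).
    """
    d, d_v = len(keys[0]), len(values[0])
    S = [[0.0] * d_v for _ in range(d)]
    z = [0.0] * d
    for k, v in zip(keys, values):
        fk = phi(k)
        for a in range(d):
            z[a] += fk[a]
            for b in range(d_v):
                S[a][b] += fk[a] * v[b]
    out = []
    for q in queries:
        fq = phi(q)
        denom = sum(fq[a] * z[a] for a in range(d))
        out.append([sum(fq[a] * S[a][b] for a in range(d)) / denom
                    for b in range(d_v)])
    return out


def quadratic_attention(queries, keys, values):
    """The same kernelized attention computed the O(n^2) way, for comparison."""
    d_v = len(values[0])
    out = []
    for q in queries:
        fq = phi(q)
        weights = [sum(x * y for x, y in zip(fq, phi(k))) for k in keys]
        total = sum(weights)
        out.append([sum(w * v[b] for w, v in zip(weights, values)) / total
                    for b in range(d_v)])
    return out
```

Both functions compute the same outputs up to floating-point error. Only the order of the matrix products differs, which is where the linear cost in sequence length comes from.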

Problem training listops on GPU

Hello, I'm running into a strange issue training models (performer, bigbird, longformer) on listops.
The model works fine on CPU, but on GPU (either on one or multiple v100 16GB) it crashes with a strange error.
This doesn't happen with any other dataset I've tried in the benchmark. The error is:

E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:226] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:113] Check failed: stream->parent()->GetBlasGemmAlgorithms(&algorithms)

The dataset was created with the included script:
PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/data/listops.py --output_dir=lra_data/listops
Any idea what may be causing this?

Computing required attention span

Hi,

I just have a general clarification question: for the required attention span mentioned in the Long Range Arena paper, do you calculate the distance from the attended keys to only the last query token in a sequence? In other words, the maximum possible distance is always 1K (2K, or 4K, respectively, depending on the tasks). Also, I am wondering how you deal with query tokens that are in the middle of a sequence.

Thank you.
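As we understand it, the paper describes the required attention span as the mean distance between the query token and the attended tokens, scaled by attention weights. One possible reading, sketched below, averages that weighted distance over every query position of a single attention matrix. How the paper aggregates across heads, layers, and examples is not stated here, so that part is left open. The function name is ours.

```python
def attention_span(attn):
    """Attention-weighted mean |query - key| distance for one attention matrix.

    attn[i][j] is the weight query position i puts on key position j (each
    row sums to 1). Every query position contributes, not just the last one.
    """
    n = len(attn)
    total = 0.0
    for i, row in enumerate(attn):
        total += sum(w * abs(i - j) for j, w in enumerate(row))
    return total / n
```

Under this reading, a query in the middle of the sequence can be at most about half the sequence length away from any key. The span is therefore not bounded by the distance to the last token alone.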

Error when run document retrival

Hi, thanks for the great code,
I am having some issues when trying to run the document retrieval task.
I got the following error when trying to run matching/train.py using the base Transformer network:

Traceback (most recent call last):
  File "lra_benchmarks/matching/train.py", line 320, in <module>
    app.run(main)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/matching/train.py", line 197, in main
    init_rng, input_shape)
  File "/mnt/lustre/sunweixuan/long-range-arena/lra_benchmarks/utils/train_utils.py", line 52, in get_model
    *create_model_args)
TypeError: create_model() missing 1 required positional argument: 'input2_shape'

It seems that create_model was changed at some point to take two input shapes, but the call site was not updated? See commit 093bfc6.

When I then pass two input shapes (input_shape1 and input_shape2), the above issue is solved, but I get a new error:

Traceback (most recent call last):
  File "lra_benchmarks/matching/train.py", line 321, in <module>
    app.run(main)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/matching/train.py", line 198, in main
    init_rng, input_shape, input_shape)
  File "/mnt/lustre/sunweixuan/long-range-arena/lra_benchmarks/utils/train_utils.py", line 52, in get_model
    *create_model_args)
  File "lra_benchmarks/matching/train.py", line 71, in create_model
    return _create_model(key)
  File "lra_benchmarks/matching/train.py", line 67, in _create_model
    (input2_shape, jnp.float32)])
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 536, in init_by_shape
    return jax_utils.partial_eval_by_shape(lazy_init, input_specs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/jax_utils.py", line 116, in partial_eval_by_shape
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/jax_utils.py", line 110, in <lambda>
    f = lambda *inputs: fn(*inputs, *args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 533, in lazy_init
    return init_fn()
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 527, in init_fn
    return cls.init(_rng, *(inputs + args), name=name, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 489, in init
    y = instance.apply(*args, **kwargs)
TypeError: apply() got multiple values for argument 'vocab_size'

Could you let me know what's the possible solution?

Text Classification Data

Couldn't find the text classification data in the gz file. Is it missing, or does it go under a different name?

Thanks,
Guy

Configs for Image Classification (cifar10)

Thanks for the great work!
I have a question regarding the hyperparameters for training on cifar10. I used the setting in this repo and replaced several hyperparameters (e.g. n_layers, n_heads) with the ones reported in the paper, but the best test accuracy I got was 0.36:

import ml_collections

NUM_EPOCHS = 200
TRAIN_EXAMPLES = 45000
VALID_EXAMPLES = 10000


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()
  config.batch_size = 256
  config.eval_frequency = TRAIN_EXAMPLES // config.batch_size
  config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS
  config.num_eval_steps = VALID_EXAMPLES // config.batch_size
  config.weight_decay = 0.
  config.grad_clip_norm = None

  config.save_checkpoints = True
  config.restore_checkpoints = False
  config.checkpoint_freq = (TRAIN_EXAMPLES //
                            config.batch_size) * NUM_EPOCHS // 2
  config.random_seed = 0

  # With 0.01 (the value from the paper), the loss does not go down.
  config.learning_rate = .0005
  config.factors = 'constant * linear_warmup * cosine_decay'
  config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1
  config.steps_per_cycle = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS

  # model params
  config.model = ml_collections.ConfigDict()
  config.model.emb_dim = 32
  config.model.num_heads = 4
  config.model.num_layers = 3
  config.model.qkv_dim = 64
  config.model.mlp_dim = 128
  config.model.dropout_rate = 0.3
  config.model.attention_dropout_rate = 0.2
  config.model.classifier_pool = 'CLS'
  config.model.learn_pos_emb = True

  config.trial = 0  # dummy for repeated runs.
  return config

Could you point out which params I could adjust to match the accuracy (this is for full attention).

Different hyper-parameters used for different models in image task.

Hi,

I found that different hyper-parameters (number of layers, dimension, etc.) are used for different models.
Can you clarify how the baselines are compared?

For example,
https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/longformer_base.py

config.model_type = "longformer"
config.model.num_layers = 4
config.model.emb_dim = 128
config.model.qkv_dim = 64
config.model.mlp_dim = 128
config.model.num_heads = 4
config.model.classifier_pool = "MEAN"

https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/performer_base.py

config.model_type = "performer"
config.model.num_layers = 1
config.model.emb_dim = 128
config.model.qkv_dim = 64
config.model.mlp_dim = 128
config.model.num_heads = 8
config.model.classifier_pool = "CLS"

https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/reformer_base.py

config.model_type = "reformer"
config.model.num_layers = 4
config.model.emb_dim = 64
config.model.qkv_dim = 32
config.model.mlp_dim = 64
config.model.num_heads = 8
config.model.classifier_pool = "CLS"

bracket in listops task?

Hello, thank you for sharing a great benchmark!
I'm focusing on the 'listops' benchmark, with the provided code and hyperparameters.

The paper says the maximum sequence length in this task is 2K, but it seems the code excludes all of the brackets in the sequence.

If the brackets ('(', ')', '[', ']') are taken into account, the maximum length becomes 6K, which is much longer than the length stated in the paper.

Don't we have to consider such brackets as the input in this task?

Thank you.

Tied weights for ListOps Transformer

Hi,

According to the configuration file the transformer model (with vanilla softmax attention) used in the ListOps task is weight tied (for both self-attention module and the feedforward network). Does that also apply to the other transformer variants listed in the comparison? For example, is Linear Transformer also weight-tied? I couldn't find anything relevant in the paper.

Many thanks in advance!

AAN dataset crashing when loading .tsv file

Did anyone else have issues loading the AAN dataset into memory? In particular, when I load the .tsv file into memory, it crashes :/ I tried several different instances on Google Cloud, with varying amounts of memory (up to 170G and 24 CPUs), but it still crashed. I feel like I am missing something. Here's the snippet of code that crashes the instance every time:

from datasets import DatasetDict, Value, load_dataset
...

        dataset = load_dataset(
            "csv",
            data_files={
                "train": str(self.data_dir / "new_aan_pairs.train.tsv"),  # 8G file
                "val": str(self.data_dir / "new_aan_pairs.eval.tsv"),
                "test": str(self.data_dir / "new_aan_pairs.test.tsv"),
            },
            delimiter="\t",
            column_names=["label", "input1_id", "input2_id", "text1", "text2"],
            keep_in_memory=True,
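With keep_in_memory=True, the datasets library is asked to hold the whole table in RAM. For an 8G TSV plus per-row overhead, that may be enough to exhaust memory. load_dataset also accepts streaming=True, which reads the files lazily. As a plain-Python fallback, the sketch below streams the TSV row by row with the standard csv module. The function name and the no-header assumption are ours, and the column names are taken from the snippet above.

```python
import csv

COLUMNS = ["label", "input1_id", "input2_id", "text1", "text2"]

# Document texts can be longer than csv's default per-field limit (128 KiB).
csv.field_size_limit(2**31 - 1)


def iter_aan_pairs(path, limit=None):
    """Yield one row at a time as a dict, so memory use stays roughly constant.

    Assumes no header row, matching the explicit column_names in the snippet.
    """
    with open(path, newline="", encoding="utf-8") as f:
        # QUOTE_NONE: quote characters inside paper texts are kept verbatim.
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for n, row in enumerate(reader):
            if limit is not None and n >= limit:
                break
            yield dict(zip(COLUMNS, row))
```

Passing a small limit, for example limit=1000, is a cheap way to check the file's format before committing to a full pass.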

Confusion Regarding Hyperparameters

As I outlined in this GitHub issue, the hyperparameters reported in your article, in GitHub issues, and in the config files all differ. Do you plan to fix these inconsistencies? If not, could you please add a disclaimer to your article clarifying that hyperparameters are not shared?

My concern is that reviewers in the future will require experimental results on LRA. But if the hyperparameters are not shared, researchers may end up writing long appendix disclaimers describing how they dealt with missing/inconsistent hyperparameters.

Please note that this ALREADY happened in the FNet paper.


I am concerned these hyperparameter inconsistencies may diminish the positive impact LRA could have on Transformer research, potentially causing more harm than it does good.

Pretrained models

Have the pretrained models that were used to report the accuracy been released? If so, could someone please direct me to where I can find them?

Q's on Performer & Text Classification

Thanks for the great work. I had a few questions when trying to reproduce the Performer results on byte-level text classification:

  1. What Kernel Function are you using? (Softmax approximation or Relu?)
  2. I found the training to be very unstable. Do you take the final model after 20K steps, or the best checkpoint?
  3. With the learning rate scheduler you use, isn't the learning rate 0 at step 0? Shouldn't you instead start your training loop with for step in range(1, X) at https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/text_classification/train.py

Looking forward to the implementations of the other models, thanks!

Are you interested in publishing to huggingface/datasets?

It's a bit hard for PyTorch users to evaluate their models on the benchmark.

Would you be willing to add your datasets to huggingface/datasets?
There are detailed instructions on how to add a dataset (https://huggingface.co/docs/datasets/add_dataset.html), and it shouldn't be hard, since you can refer to the processing scripts of other datasets.

If this benchmark were added to huggingface/datasets, which can then be used from NumPy/pandas/PyTorch/TensorFlow/JAX, I believe it would become more accessible and more widely used.

Publish number of parameters for each task

Hello,

You mention: "The new model should be within at best 10% larger in terms of parameters compared to the base Transformer model in the provided config file"

Do you publish the baseline parameter counts for each task?

Thanks
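Until per-task baseline counts are published, one way to check the 10% rule yourself is to count the leaves of the parameter tree. This is a minimal sketch assuming parameters are a nested mapping of arrays, as Flax parameter trees typically are; count_params and within_budget are hypothetical helpers, and the toy tree is not an LRA model.

```python
from collections.abc import Mapping

import numpy as np


def count_params(tree) -> int:
    """Total number of scalar parameters in a nested mapping of arrays (hypothetical helper)."""
    if isinstance(tree, Mapping):
        return sum(count_params(subtree) for subtree in tree.values())
    return int(np.size(tree))  # leaf: an array (or scalar)


def within_budget(candidate: int, baseline: int, slack: float = 0.10) -> bool:
    """True if the candidate is at most `slack` (10% by default) larger than the baseline."""
    return candidate <= baseline * (1.0 + slack)


toy_params = {"dense": {"kernel": np.zeros((4, 8)), "bias": np.zeros(8)}}
print(count_params(toy_params))  # 40
```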

Linear transformer performance

Hello again!

Thank you for your work and for open-sourcing the codebase! I tried to run a ListOps experiment with the Linear Transformer and got test results that do not match the results reported in the paper:

{"accuracy": 0.27000001072883606, "loss": 2.579371690750122, "perplexity": 13.188848495483398}

The model config was identical to the one used for the Transformer, and the only thing I changed in lra_benchmarks/listops/train.py was the following lines:

  if model_type == 'transformer':
    model = create_model(init_rng, transformer.TransformerEncoder, input_shape,
                         model_kwargs)
  elif model_type == 'linear_transformer':
    model = create_model(init_rng, linear_transformer.LinearTransformerEncoder,
                         input_shape, model_kwargs)
  else:
    raise ValueError('Model type not supported')

The experiments were run inside a Docker container (nvidia/cuda:11.0-cudnn8-devel-ubuntu18.04) on 4 Tesla T4 GPUs with the following requirements:

jax>=0.2.4
flax>=0.2.2
ml-collections>=0.1.0
tensorboard>=2.3.0
tensorflow>=2.3.1
tensorflow-datasets>=4.0.1

and

pip install --upgrade jaxlib==0.1.65+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I can share the full Dockerfile if needed.

typo in readme code block

PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/listops/train.py \
      --config=lra_benchmarks/listops/configs/transformer_base.py \
      --model_dir=/tmp/listops \
      --task_name=basic \
      --data_dir=$HOME/lra_data/listops/

PYTHON_PATH should be changed to PYTHONPATH

bug in Pathfinder-128 dataset

Hi,

I trained a baseline CNN model on the 4 provided Pathfinder datasets (32x32, 64x64, 128x128, 256x256). It achieved good results on 32x32, 64x64, and 256x256, but only chance-level accuracy on 128x128. This suggests that the 128x128 dataset has a bug. In more detail:

I extracted the provided .gz archive, which is organized like this:

lra_release/
  listops-1000/
  pathfinder128/
  lra_release/
      listops-1000/
      tsv_data/
      pathfinder32/
      pathfinder64/
      pathfinder128/
      pathfinder256/

I created a dataset for each of the 4 pathfinderXX folders using the same processing. More specifically, for each dataset I used the 200k "curve length 14" examples with a 160k/20k/20k split.
I trained a basic ResNet18 model on each of these datasets; the pre-processing transforms and model were exactly the same in all cases.
This ResNet achieved 80+% validation/test accuracy on each of the 32x32, 64x64, and 256x256 datasets. On the 128x128 dataset, after training for many epochs it achieved over 95% train accuracy, but never achieved more than random guessing on validation/test.
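For reference, a 160k/20k/20k split like the one described above can be made deterministic with a seeded permutation, so that the same procedure is applied identically to every resolution. This is a sketch only; split_indices and the seed are hypothetical choices, not the reporter's exact code.

```python
import numpy as np


def split_indices(num_examples: int, sizes=(160_000, 20_000, 20_000), seed: int = 0):
    """Hypothetical helper: shuffle indices once, then cut consecutive train/val/test chunks."""
    if sum(sizes) > num_examples:
        raise ValueError("requested split is larger than the dataset")
    perm = np.random.default_rng(seed).permutation(num_examples)
    bounds = np.cumsum((0,) + tuple(sizes))  # e.g. [0, 160000, 180000, 200000]
    return [perm[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]


train_idx, val_idx, test_idx = split_indices(200_000)
print(len(train_idx), len(val_idx), len(test_idx))  # 160000 20000 20000
```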

I don't know what is causing the issue, but it seems the Pathfinder128 data must differ somehow. Any leads would be greatly appreciated.

Text Classification Configuration: Paper vs Code

Hi, for the byte-level document classification task there seems to be a discrepancy between the paper (see Appendix 1.2) and the config file in the repository.

Paper

6 layers, 8 heads, 512 hidden dimensions, d=2048 for positional FFN

Code

config.emb_dim = 256
config.num_heads = 4
config.num_layers = 4
config.qkv_dim = 256
config.mlp_dim = 1024

Could you please resolve this?

This is also the case for other tasks, e.g. image classification.
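A quick way to list such discrepancies is to diff the two settings side by side. The sketch below is illustrative: the values are transcribed from the excerpt above, and mapping "512 hidden dimensions" to emb_dim and "d=2048 for positional FFN" to mlp_dim is an assumption about how the paper's terms correspond to the config fields; config_mismatches is a hypothetical helper.

```python
# Values transcribed from the issue above; the paper-to-config name mapping is assumed.
PAPER = {"num_layers": 6, "num_heads": 8, "emb_dim": 512, "mlp_dim": 2048}
CODE = {"num_layers": 4, "num_heads": 4, "emb_dim": 256, "qkv_dim": 256, "mlp_dim": 1024}


def config_mismatches(paper: dict, code: dict) -> dict:
    """Return {field: (paper_value, code_value)} for shared fields whose values differ."""
    return {
        key: (paper[key], code[key])
        for key in sorted(paper.keys() & code.keys())
        if paper[key] != code[key]
    }


for name, (paper_value, code_value) in config_mismatches(PAPER, CODE).items():
    print(f"{name}: paper={paper_value} code={code_value}")
```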

JAX reports "No GPU/TPU found, falling back to CPU"

requirements.txt requires jax>=0.2.4. When I ran into the problem in the title, the JAX GitHub README said that, to support both GPU and CPU, one needs to install:

pip install --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
The jaxlib version must correspond to the version of the existing CUDA installation you want to use

However, that jaxlib release index does not contain any version >= 0.2.4; the newest version there is 0.1.67.

How did you set up the environment?
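When this warning appears, it helps to confirm which backend JAX actually picked up. The check below uses only the public jax.devices() API and each device's platform attribute; accelerator_report is a hypothetical helper name. Note that at the time, jax and jaxlib were versioned independently (jaxlib releases were numbered 0.1.x while jax was at 0.2.x), so jax>=0.2.4 does not by itself call for a jaxlib of 0.2.4 or later.

```python
import jax


def accelerator_report():
    """Hypothetical helper: sorted platform names JAX sees, and whether any is a GPU or TPU."""
    platforms = sorted({device.platform for device in jax.devices()})
    has_accelerator = any(p in ("gpu", "tpu") for p in platforms)
    return platforms, has_accelerator


platforms, has_accelerator = accelerator_report()
print("platforms:", platforms, "| accelerator found:", has_accelerator)
```

If only "cpu" is listed, the installed jaxlib wheel most likely lacks CUDA support or does not match the local CUDA version.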

validation on IMDB

Hi, could you tell me how you select the best model trained on the IMDB dataset, given that there is no validation set? Do you use the last checkpoint for testing? Thanks!

The best checkpoint of Transformer

Hi,

Thanks for sharing this wonderful work! Could you please also share the fine-tuned checkpoints of the vanilla Transformer for all the tasks? Thank you very much!

Quadratic Longformer suspicion

Hi!
I am checking the Longformer implementation, and it seems that nn.attention.dot_product_attention() (with the attention pattern passed through the bias parameter) does all the heavy lifting.
However, in nn.attention.dot_product_attention(), as well as in the newer linen.attention.dot_product_attention(), the query-key multiplication is performed first and the mask is applied only afterwards, so the computation is still quadratic.
Could you please explain how you avoid the quadratic computation?
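The concern can be illustrated with a small NumPy sketch (not the repository's code): if a sliding-window pattern is expressed as an additive bias on top of dense attention, the full n-by-n score matrix is still computed, so time and memory remain O(n^2) even though most attention weights end up as zero. local_window_bias and masked_dense_attention are hypothetical names.

```python
import numpy as np


def local_window_bias(n: int, window: int) -> np.ndarray:
    """Additive bias: 0 where |i - j| <= window, a large negative number elsewhere."""
    idx = np.arange(n)
    allowed = np.abs(idx[:, None] - idx[None, :]) <= window
    return np.where(allowed, 0.0, -1e9)


def masked_dense_attention(q, k, v, bias):
    """Dense attention with the mask applied as a bias AFTER the full QK^T product."""
    scores = q @ k.T / np.sqrt(q.shape[-1])  # full (n, n) matrix: the quadratic step
    scores = scores + bias  # masking happens only here, after the work is done
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
n, d = 8, 4
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
out, weights = masked_dense_attention(q, k, v, local_window_bias(n, window=1))
print(out.shape, weights.shape)  # (8, 4) (8, 8)
```

A genuinely sub-quadratic local attention would instead compute scores only for the blocks or diagonals inside the window.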

Serious bugs in the ListOps task

I have discovered two serious bugs in the ListOps task, which unfortunately mean that the task is completely broken as far as I can tell.

  1. The input pipeline uses the tfds.deprecated.text.Tokenizer(), which ignores non-alphanumeric characters by default. This means an input like [MAX 4 3 [MIN 2 3 ] 1 0 [MEDIAN 1 5 8 9 2]] will actually get encoded as MAX 4 3 MIN 2 3 1 0 MEDIAN 1 5 8 9 2, making the task impossible to solve.
  2. The token counting in the data generation script is incorrect. With --max-length=2000, the generated sequences are clearly much longer than 2000 tokens (if encoded correctly), and will only get truncated to this length during training, leading again to an impossible task.
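Both points can be reproduced without TensorFlow. The sketch below uses regular expressions as a stand-in for the two behaviours described (it is not the tfds tokenizer itself): an alphanumeric-only split loses every bracket, while a bracket-aware split keeps them, and a length check should count tokens with the same tokenization used in training. All function names are hypothetical.

```python
import re

EXAMPLE = "[MAX 4 3 [MIN 2 3 ] 1 0 [MEDIAN 1 5 8 9 2]]"


def alnum_only_tokens(text: str) -> list:
    """Keep only alphanumeric runs, mimicking a tokenizer that drops punctuation."""
    return re.findall(r"[A-Za-z0-9]+", text)


def bracket_aware_tokens(text: str) -> list:
    """Treat '[' and ']' as tokens of their own alongside operators and digits."""
    return re.findall(r"\[|\]|[A-Za-z0-9]+", text)


def fits_max_length(text: str, max_length: int) -> bool:
    """Hypothetical length check using the bracket-aware tokenization."""
    return len(bracket_aware_tokens(text)) <= max_length


print(" ".join(alnum_only_tokens(EXAMPLE)))
# MAX 4 3 MIN 2 3 1 0 MEDIAN 1 5 8 9 2
print(len(bracket_aware_tokens(EXAMPLE)))  # 20
```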

ModuleNotFoundError: No module named 'flax.deprecated'

Hello! Thanks for offering this benchmark!
However, when I run the following command:

PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/listops/train.py \
      --config=lra_benchmarks/listops/configs/transformer_base.py \
      --model_dir=/tmp/listops \
      --task_name=basic \
      --data_dir=$HOME/lra_data/listops/

I got this error:

Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 28, in <module>
    from flax.deprecated import nn
ModuleNotFoundError: No module named 'flax.deprecated'

Is there a solution?

Hyperparameters of each task to reproduce Table 1 in the paper

Hi,

I am looking for the hyperparameters of each task to reproduce Table 1.
The hyperparameters in base_*_config.py do not match those reported in the Appendix, but when I tried the hyperparameters reported in the paper, the results did not match the scores in Table 1.
Could you provide the hyperparameters you used to produce Table 1?

Thanks!

Pathfinder task

Could you please specify which pathfinder task is used in the paper? I'm assuming it's pathfinder32, but which difficulty?

Also, the task is broken. There is no way to specify the path to the data, and the pipeline code tries to reference _PATHFINER_TFDS_PATH (note the typo), which is never defined (even without the typo).

ListOps performance

When running the ListOps task as-is from the repo, I got validation performance similar to that reported in the paper, but the test performance in results.json is very low:

{"accuracy": 0.17500001192092896, "loss": 3.032956123352051, "perplexity": 20.758506774902344}

I saw that the code saves the model from the last checkpoint rather than the model with the best validation performance. Could you detail the evaluation setup used in the paper, i.e., do you evaluate the model from the last checkpoint or from the best validation checkpoint?

Thank you very much! :-)
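For comparison, selecting the checkpoint with the best validation accuracy (instead of the last one) only requires tracking the best score seen during evaluation. This is a minimal sketch; BestCheckpointTracker is a hypothetical helper, and params stands in for whatever state a training loop would save.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class BestCheckpointTracker:
    """Hypothetical helper: keep the state from the eval step with the best validation accuracy."""

    best_accuracy: float = float("-inf")
    best_step: Optional[int] = None
    best_params: Any = field(default=None, repr=False)

    def update(self, step: int, val_accuracy: float, params: Any) -> bool:
        """Record this evaluation; return True if it is a new best."""
        if val_accuracy > self.best_accuracy:
            self.best_accuracy = val_accuracy
            self.best_step = step
            self.best_params = params  # in practice, write a checkpoint to disk here
            return True
        return False


tracker = BestCheckpointTracker()
for step, acc in [(1000, 0.30), (2000, 0.37), (3000, 0.35)]:
    tracker.update(step, acc, params={"step": step})
print(tracker.best_step, tracker.best_accuracy)  # 2000 0.37
```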

Script for computing memory consumption

Hello again!

I managed to run your code on my devices, but now I am struggling to compute memory consumption. Could you please open-source the script you used to compute it?
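This is not the authors' script, but as a rough cross-check, the size of one layer's attention score tensor can be estimated analytically as batch x heads x n x n entries times the bytes per element. The function name and example sizes below are assumptions chosen for illustration; real peak memory also includes activations, parameters, and optimizer state.

```python
def attention_scores_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes to materialize one (batch, heads, n, n) attention score tensor (float32 by default)."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


gib = attention_scores_bytes(batch=32, heads=8, seq_len=4096) / 2**30
print(f"{gib:.1f} GiB")  # 16.0 GiB
```

Because the term is quadratic in seq_len, doubling the sequence length quadruples this estimate.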

Pathfinder task cannot converge.

I ran pathfinder32 on this dataset 5 times. In 3 of the 5 runs training did not converge and the loss stayed at 0.6933 until the end, while the other 2 converged normally and reached a final accuracy of 75% (BigBird). The outcome looks fairly random. I then tried a different model (Performer), and it never converged. However, the CIFAR-10 task, which uses the same training code as pathfinder32, converges every time. Is this a problem with the dataset?

I0831 15:32:43.578928 140327465420608 train.py:276] eval in step: 16224, loss: 0.6931, acc: 0.5017
I0831 15:33:00.912956 140327465420608 train.py:242] train in step: 16536, loss: 0.6932, acc: 0.5011
I0831 15:33:02.813938 140327465420608 train.py:276] eval in step: 16536, loss: 0.6932, acc: 0.4983
I0831 15:33:21.293757 140327465420608 train.py:242] train in step: 16848, loss: 0.6931, acc: 0.5018
I0831 15:33:23.183998 140327465420608 train.py:276] eval in step: 16848, loss: 0.6932, acc: 0.4983
I0831 15:33:41.210031 140327465420608 train.py:242] train in step: 17160, loss: 0.6932, acc: 0.4997
I0831 15:33:43.294295 140327465420608 train.py:276] eval in step: 17160, loss: 0.6931, acc: 0.4983
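The logged loss of about 0.6931 is ln 2, the cross-entropy of a binary classifier that always predicts 50/50, so these runs are sitting at chance level. A small, hypothetical helper for flagging such runs early (the tolerance and window are arbitrary choices):

```python
import math


def chance_level_loss(num_classes: int) -> float:
    """Cross-entropy of a model that predicts the uniform distribution over classes."""
    return math.log(num_classes)


def looks_stuck_at_chance(losses, num_classes: int = 2, tol: float = 0.002, window: int = 3) -> bool:
    """Hypothetical check: True if the last `window` losses all lie within `tol` of chance."""
    recent = list(losses)[-window:]
    if len(recent) < window:
        return False
    target = chance_level_loss(num_classes)
    return all(abs(loss - target) <= tol for loss in recent)


print(round(chance_level_loss(2), 4))  # 0.6931
print(looks_stuck_at_chance([0.6931, 0.6932, 0.6931, 0.6932]))  # True
```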

How to run test

This might be a dumb question, but it seems there's only a train.py, which trains and prints validation stats. How do I test the model to get numbers comparable to the accuracies in the table?
