
long-range-arena's Issues

Confusion Regarding Hyperparameters

As I outlined in this GitHub issue, the hyperparameters reported in your article, GitHub issues, and config files all differ. Do you plan to fix these inconsistencies? If not, could you please add a disclaimer to your article clarifying that hyperparameters are not shared?

My concern is that reviewers will increasingly require experimental results on LRA. But if the hyperparameters are not shared, researchers may end up writing long appendix disclaimers describing how they dealt with missing or inconsistent hyperparameters.

Please note that this ALREADY happened in the FNet paper.


I am concerned that these hyperparameter inconsistencies may diminish the positive impact LRA could have on Transformer research, potentially causing more harm than good.

Pathfinder task

Could you please specify which pathfinder task is used in the paper? I'm assuming it's pathfinder32, but which difficulty?

Also, the task is broken. There is no way to specify the path to the data, and the pipeline code tries to reference _PATHFINER_TFDS_PATH (note the typo), which is never defined (even without the typo).
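
As a stopgap, this is the kind of local patch I mean for the pathfinder input pipeline (a sketch under my own assumptions, not an official fix; the path is a placeholder):

import os

# Sketch of a local workaround: define the constant the pipeline references, keeping the
# existing misspelling, and point it at the extracted data directory.
_PATHFINER_TFDS_PATH = os.environ.get(
    'PATHFINDER_TFDS_PATH', '/path/to/lra_release/lra_release/')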

Text Classification Configuration: Paper vs Code

Hi, for the byte-level document classification task there seems to be a discrepancy between the paper (see Appendix 1.2) and the config file in the repository.

Paper

6 layers, 8 heads, 512 hidden dimensions, d=2048 for positional FFN

Code

config.emb_dim = 256
config.num_heads = 4
config.num_layers = 4
config.qkv_dim = 256
config.mlp_dim = 1024
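
For comparison, the Appendix 1.2 values transcribed into the same config format would presumably be the following (my transcription, not from the repo; qkv_dim in particular is an assumption):

config.emb_dim = 512
config.num_heads = 8
config.num_layers = 6
config.qkv_dim = 512  # assuming the QKV dimension equals the hidden dimension
config.mlp_dim = 2048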

Could you please resolve this?

This is also the case for other tasks, e.g. image classification.

Linear Transformer code base

Hello!

Thank you for your work on long-range-arena, it's impressive!
My name is Maksim Zubkov, and I am working on improving the Linear Transformer proposed in the paper Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention (Angelos Katharopoulos et al., 2020). I want to compare the results of the vanilla Linear Transformer with those achieved by the model I propose, so it would be very useful for my research to have access to the Linear Transformer code you used in LRA. I have the following questions:

  1. How soon do you plan to publish the code of the Linear Transformer used in the experiments?
  2. Have you changed the mechanism proposed in the original repo in any way, or did you basically just swap attention_fn in nn.SelfAttention (roughly as in the sketch below)?
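
For reference, this is the kind of swap I mean: a minimal non-causal linear-attention sketch under my own assumptions (elu+1 feature map, flax linen API rather than the deprecated nn used in this repo), not necessarily what the LRA code does:

import jax
import jax.numpy as jnp
from flax import linen as nn

def linear_attention(query, key, value, **unused_kwargs):
  # query/key/value: [batch, length, num_heads, head_dim]; mask/bias ignored for simplicity.
  phi_q = jax.nn.elu(query) + 1.0  # positive feature map from Katharopoulos et al.
  phi_k = jax.nn.elu(key) + 1.0
  kv = jnp.einsum('blhd,blhm->bhdm', phi_k, value)                    # sum_k phi(k) v^T
  z = 1.0 / (jnp.einsum('blhd,bhd->blh', phi_q, phi_k.sum(axis=1)) + 1e-6)
  return jnp.einsum('blhd,bhdm,blh->blhm', phi_q, kv, z)              # normalized linear attention

# Hypothetical usage: plug it into the stock attention module in place of dot-product attention.
attn = nn.MultiHeadDotProductAttention(num_heads=8, attention_fn=linear_attention)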

Best regards, Maksim

Serious bugs in the ListOps task

I have discovered two serious bugs in the ListOps task, which unfortunately mean that the task is completely broken as far as I can tell.

  1. The input pipeline uses tfds.deprecated.text.Tokenizer(), which ignores non-alphanumeric characters by default. This means an input like [MAX 4 3 [MIN 2 3 ] 1 0 [MEDIAN 1 5 8 9 2]] will actually get encoded as MAX 4 3 MIN 2 3 1 0 MEDIAN 1 5 8 9 2, making the task impossible to solve. (A possible workaround is sketched after this list.)
  2. The token counting in the data generation script is incorrect. With --max-length=2000, the generated sequences are clearly much longer than 2000 tokens (if encoded correctly), and will only get truncated to this length during training, leading again to an impossible task.
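
For the first bug, one possible workaround (my own sketch, and the vocabulary/encoder would need to be rebuilt afterwards) is to tell the tokenizer to keep non-alphanumeric characters and to reserve the operator tokens:

import tensorflow_datasets as tfds

# Keep '[' and ']' instead of silently dropping them, and treat the bracketed operators as
# single reserved tokens so the nesting structure survives encoding.
tokenizer = tfds.deprecated.text.Tokenizer(
    alphanum_only=False,
    reserved_tokens=['[MIN', '[MAX', '[MED', '[SM', ']'])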

ValueError: Non-hashable static arguments are not supported

I am trying to train the Transformer on the listops task using the command from the readme, but I get the following error:

I1201 18:31:38.370242 23342722882432 input_pipeline.py:75] Finished processing vocab size=14                                                                                                                
Traceback (most recent call last):                                                                                                                                                                          
  File "lra_benchmarks/listops/train.py", line 288, in <module>                                                                                                                                             
    app.run(main)                                                                                                                                                                                           
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 300, in run                                                                                             
    _run_main(main, args)                                                                                                                                                                                   
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main                                                                                       
    sys.exit(main(argv))                                                                                                                                                                                    
  File "lra_benchmarks/listops/train.py", line 193, in main                                                                                                                                                 
    model_kwargs)                                                                                                                                                                                           
jax._src.traceback_util.FilteredStackTrace: ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'dict'>, {'vocab_size': 16, 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': 2000, 'classifier': True, 'num_classes': 10}. The error was:
TypeError: unhashable type: 'dict'                                                                                                                                                                          
                                                                                                                                                                                                            
The stack trace above excludes JAX-internal frames.                                                                                                                                                         
The following is the original exception that occurred, unmodified.                                                                                                                                          
                                                                                                                                                                                                            
--------------------                                                                                                                                                                                        
                                                                                                                                                                                                            
The above exception was the direct cause of the following exception:                                                                                                                                        
                                                                                                                                                                                                            
Traceback (most recent call last):                                                                                                                                                                          
  File "lra_benchmarks/listops/train.py", line 288, in <module>                                                                                                                                             
    app.run(main)                                                                                                                                                                                           
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 300, in run                                                                                             
    _run_main(main, args)                                                                                                                                                                                   
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main                                                                                       
    sys.exit(main(argv))                                                                                                                                                                                    
  File "lra_benchmarks/listops/train.py", line 193, in main                                                                                                                                                 
    model_kwargs)                                                                                                                                                                                           
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 133, in reraise_with_filtered_traceback                                                  
    return fun(*args, **kwargs)                                                                                                                                                                             
  File "/ids-cluster-storage/storage/ocifka/venv/lra/lib/python3.7/site-packages/jax/api.py", line 371, in f_jitted                                                                                         
    return cpp_jitted_f(*args, **kwargs)                                                                                                                                                                    
ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'dict'>, {'vocab_size': 16, 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': 2000, 'classifier': True, 'num_classes': 10}. The error was:
TypeError: unhashable type: 'dict'

I have jax==0.2.6 and jaxlib==0.1.57+cuda101.
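
For anyone else hitting this: the dict of model kwargs is being passed to a jitted function as a static argument, and plain dicts are not hashable. One possible workaround (just a sketch, untested against this exact code path) is to make the kwargs hashable before they reach jax.jit:

class HashableDict(dict):
  """Plain dict plus a hash, so it can be used as a static jit argument."""

  def __hash__(self):
    return hash(tuple(sorted(self.items())))

model_kwargs = HashableDict(
    vocab_size=16, emb_dim=512, num_heads=8, num_layers=6, qkv_dim=512,
    mlp_dim=2048, max_len=2000, classifier=True, num_classes=10)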

ModuleNotFoundError: No module named 'flax.deprecated'

Hello! Thanks for offering this benchmark!
However, when I run the following command

PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/listops/train.py \
      --config=lra_benchmarks/listops/configs/transformer_base.py \
      --model_dir=/tmp/listops \
      --task_name=basic \
      --data_dir=$HOME/lra_data/listops/

I get this error message:

Traceback (most recent call last):
  File "lra_benchmarks/listops/train.py", line 28, in <module>
    from flax.deprecated import nn
ModuleNotFoundError: No module named 'flax.deprecated'

Is there any solution?
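
For reference, the failing import only exists in flax releases that ship the old pre-linen API under flax.deprecated, so either pinning such a release or a defensive import along these lines may help (my own sketch; the exact version boundaries are an assumption):

try:
  from flax.deprecated import nn  # flax releases that moved the old API under flax.deprecated
except ModuleNotFoundError:
  from flax import nn  # older flax releases where the old API still lives at flax.nn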

Are you interested in publishing to huggingface/datasets ?

It is a little hard for PyTorch users to evaluate their models on this benchmark.

Would you be willing to publish your datasets to huggingface/datasets?
There are detailed steps on how to add a dataset (https://huggingface.co/docs/datasets/add_dataset.html), and it shouldn't be hard since you can refer to the processing scripts of other datasets.

If this benchmark were available in huggingface/datasets, which supports NumPy/Pandas/PyTorch/TensorFlow/JAX, I believe it would become more accessible and more widely used.

Nested Structure in ListOps

Hi,

According to a previous issue (#6), the brackets '[' and ']' should be preserved during tokenization, resulting in the following set of unique tokens:

(']', '1', '4', '7', '9', '8', '2', '5', '3', '0', '6', '[MIN', '[SM', '[MAX', '[MED').

However, based on your code (input_pipeline.py), the following set of tokens is used to encode the input

('SM', '2', 'MAX', '8', '5', '0', '4', '1', '7', '6', '3', 'MIN', '9', 'MED').

Evidently, if you run

encoder.decode(x)

where x is an encoded input sample, the result does not contain the brackets '[' and ']'. For instance, the following is the (truncated) result of decoding the first encoded sample in the training set:

MIN 1 MED 8 0 1 3 0 6 1 7 MED SM 4 MED SM 2 9 MED 9 9 2 MIN 8 3 5 4 8 5 6 2 6 0 5 8 2 6 8 9 8 3 4 SM 3 SM MIN MIN MED 5 8 7 9 1 7 1 8 8 MAX 2 5 7 1 1 0 MAX 6 2 5 0 MIN 2 2 4 8 7 1 MIN 1 SM 3 6 4 0 5 5 MED 6 3 8 0 5 6 4 3 7 8 0 6 8 8 3 MAX 5 1 3 5 MED 3 SM 3 2 6 MAX 2 MIN 8 MAX 8 2 5 8 7 1 5 9 0 0 9 MAX 8 1 MAX 5 9 4 1 1 3 2 3 2 2 MIN 1 9 0 2 9 3 8 MAX ....

So I was wondering: why are the brackets '[' and ']' ignored? This way the nested structure is no longer preserved.

Hyperparameters of each task to reproduce Table 1 in the paper

Hi,

I am looking for the hyperparameters of each task to reproduce Table 1.
The hyperparameters in base_*_config.py do not match what is reported in the appendix, and when I tried the hyperparameters reported in the paper, the results do not match the scores in Table 1 either.
Could you provide the hyperparameters you used to produce Table 1?

Thanks!

bug in Pathfinder-128 dataset

Hi,

I trained a baseline CNN model on the 4 provided pathfinder datasets (32x32, 64x64, 128x128, 256x256) and it achieved good results on 32x32, 64x64, and 256x256, but only random guessing on 128x128. This seems to indicate that the 128x128 dataset has a bug. In more detail:

I extracted the provided .gz which is organized like this:

lra_release/
  listops-1000/
  pathfinder128/
  lra_release/
      listops-1000/
      tsv_data/
      pathfinder32/
      pathfinder64/
      pathfinder128/
      pathfinder256/

I created a dataset for each of the 4 pathfinderXX folders using the same processing. More specifically, for each dataset I used the 200k "curve length 14" examples with a 160k/20k/20k split.
I trained a basic ResNet18 model on each of these datasets; the pre-processing transforms and model were exactly the same in all cases.
This ResNet achieved 80+% validation/test accuracy on each of the 32x32, 64x64, and 256x256 datasets. On the 128x128 dataset, after training for many epochs it achieved over 95% train accuracy, but never achieved more than random guessing on validation/test.
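
For concreteness, a minimal sketch of the kind of split I mean (the loading of 'examples' and the seed are placeholders, not my exact pipeline):

import numpy as np

# examples: list of (image, label) pairs loaded from one pathfinderXX "curve length 14" folder
rng = np.random.default_rng(0)
perm = rng.permutation(len(examples))
train = [examples[i] for i in perm[:160_000]]
val = [examples[i] for i in perm[160_000:180_000]]
test = [examples[i] for i in perm[180_000:200_000]]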

I have no idea what is causing the issue, but it seems like the data for pathfinder128 must be different somehow. Any leads would be greatly appreciated.

Computing required attention span

Hi,

I just have a general clarification question: for the required attention span mentioned in the Long Range Arena paper, do you calculate the distance from the attended keys to only the last query token in a sequence? In other words, is the maximum possible distance always 1K (2K, or 4K, respectively, depending on the task)? Also, I am wondering how you deal with query tokens that are in the middle of a sequence (a sketch of the interpretation I have in mind is below).
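
To make the question concrete, this is the interpretation I would naively use, averaging the attention-weighted distance over all query positions rather than just the last token (purely my guess at the metric, not taken from the paper or code):

import jax.numpy as jnp

def mean_attention_span(attn_weights):
  # attn_weights: [num_heads, q_len, k_len] attention probabilities for one example
  q_pos = jnp.arange(attn_weights.shape[1])
  k_pos = jnp.arange(attn_weights.shape[2])
  dist = jnp.abs(q_pos[:, None] - k_pos[None, :])        # |i - j| for every query/key pair
  per_query_span = (attn_weights * dist[None]).sum(-1)   # expected distance per head and query
  return per_query_span.mean()                           # average over heads and query positions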

Thank you.

Error when running document retrieval

Hi, thanks for the great code.
I am having some issues when trying to run the document retrieval task.
I get the following error when running matching/train.py with the base transformer network:

Traceback (most recent call last):
  File "lra_benchmarks/matching/train.py", line 320, in <module>
    app.run(main)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/matching/train.py", line 197, in main
    init_rng, input_shape)
  File "/mnt/lustre/sunweixuan/long-range-arena/lra_benchmarks/utils/train_utils.py", line 52, in get_model
    *create_model_args)
TypeError: create_model() missing 1 required positional argument: 'input2_shape'

It seems that create_model used to take two input shapes but was later modified; see commit
093bfc6.

Then, when I pass both input_shape1 and input_shape2, the above issue is solved, but I get a new error:

Traceback (most recent call last):
  File "lra_benchmarks/matching/train.py", line 321, in <module>
    app.run(main)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/mnt/lustre/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "lra_benchmarks/matching/train.py", line 198, in main
    init_rng, input_shape, input_shape)
  File "/mnt/lustre/sunweixuan/long-range-arena/lra_benchmarks/utils/train_utils.py", line 52, in get_model
    *create_model_args)
  File "lra_benchmarks/matching/train.py", line 71, in create_model
    return _create_model(key)
  File "lra_benchmarks/matching/train.py", line 67, in _create_model
    (input2_shape, jnp.float32)])
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 536, in init_by_shape
    return jax_utils.partial_eval_by_shape(lazy_init, input_specs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/jax_utils.py", line 116, in partial_eval_by_shape
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/jax_utils.py", line 110, in <lambda>
    f = lambda *inputs: fn(*inputs, *args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 533, in lazy_init
    return init_fn()
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 527, in init_fn
    return cls.init(_rng, *(inputs + args), name=name, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 238, in wrapper
    return super_fn(*args, **kwargs)
  File "/mnt/lustre/sunweixuan/.local/lib/python3.6/site-packages/flax/nn/base.py", line 489, in init
    y = instance.apply(*args, **kwargs)
TypeError: apply() got multiple values for argument 'vocab_size'

Could you let me know what the possible solution is?

ListOps performance

On running the ListOps task as-is from the repo, I got a validation performance similar to that reported in the paper but the test performance on results.json is very low:

{"accuracy": 0.17500001192092896, "loss": 3.032956123352051, "perplexity": 20.758506774902344}

I saw that the code saves the model from the last checkpoint rather than the model with the best validation performance. Could you detail the evaluation setup used in the paper, i.e. do you evaluate the model from the last checkpoint or from the best validation checkpoint?

Thank you very much! :-)

typo in readme code block

PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/listops/train.py \
      --config=lra_benchmarks/listops/configs/transformer_base.py \
      --model_dir=/tmp/listops \
      --task_name=basic \
      --data_dir=$HOME/lra_data/listops/

PYTHON_PATH should be changed to PYTHONPATH

Different hyper-parameters used for different models in image task.

Hi,

I found that different hyper-parameters (number of layers, dimension, etc.) are used for different models.
Can you clarify how the baselines are compared?

For example,
https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/longformer_base.py

config.model_type = "longformer"
config.model.num_layers = 4
config.model.emb_dim = 128
config.model.qkv_dim = 64
config.model.mlp_dim = 128
config.model.num_heads = 4
config.model.classifier_pool = "MEAN"

https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/performer_base.py

config.model_type = "performer"
config.model.num_layers = 1
config.model.emb_dim = 128
config.model.qkv_dim = 64
config.model.mlp_dim = 128
config.model.num_heads = 8
config.model.classifier_pool = "CLS"

https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/image/configs/cifar10/reformer_base.py

config.model_type = "reformer"
config.model.num_layers = 4
config.model.emb_dim = 64
config.model.qkv_dim = 32
config.model.mlp_dim = 64
config.model.num_heads = 8
config.model.classifier_pool = "CLS"

Perceiver on LRA

Did anybody try out the Perceiver model on LRA?
It should be fairly easy to fit into this codebase, as both are written in JAX.
I'd be super curious to see the results it would produce.

Request for CUDA version when using GPUs

Hi, thanks for your code.

I am trying to run this code on GPUs, but I am struggling a lot with the environment setup.

Can you provide the versions of CUDA, jaxlib, jax, and flax used when training on GPUs?

Text Classification Data

I couldn't find the text classification data in the .gz file. Is it missing, or does it go under a different name?

Thanks,
Guy

Is it possible to include instructions on how to run it on GPUs

This code seems to be using 4x4 TPUs, but since I don't have access to TPUs, I wonder if you could release instructions on how to replicate the results on GPUs, which would make this code more accessible for people without abundant compute resources.

Pathfinder task cannot converge.

I tried to run pathfinder32 on this dataset 5 times; 3 out of 5 runs did not converge, with the loss stuck at 0.6933 until the end, while the other 2 converged normally and reached a final accuracy of 75% (BigBird). It seems pretty random. I then tried a different model (Performer), and it never converged. However, the cifar10 task, which uses the same training code as pathfinder32, converges every time. Is this a problem with the dataset?

I0831 15:32:43.578928 140327465420608 train.py:276] eval in step: 16224, loss: 0.6931, acc: 0.5017
I0831 15:33:00.912956 140327465420608 train.py:242] train in step: 16536, loss: 0.6932, acc: 0.5011
I0831 15:33:02.813938 140327465420608 train.py:276] eval in step: 16536, loss: 0.6932, acc: 0.4983
I0831 15:33:21.293757 140327465420608 train.py:242] train in step: 16848, loss: 0.6931, acc: 0.5018
I0831 15:33:23.183998 140327465420608 train.py:276] eval in step: 16848, loss: 0.6932, acc: 0.4983
I0831 15:33:41.210031 140327465420608 train.py:242] train in step: 17160, loss: 0.6932, acc: 0.4997
I0831 15:33:43.294295 140327465420608 train.py:276] eval in step: 17160, loss: 0.6931, acc: 0.4983

Pretrained models

Have the pretrained models that were used to report the accuracy been released? If so, could someone please direct me to where I can find them?

Linear transformer performance

Hello again!

Thank you for your work and for open-sourcing the codebase! I ran a ListOps experiment with the Linear Transformer and got test results that do not match those reported in the paper:

{"accuracy": 0.27000001072883606, "loss": 2.579371690750122, "perplexity": 13.188848495483398}

The model config was identical to the one used for the vanilla Transformer, and the only thing I changed in lra_benchmarks/listops/train.py was the following lines:

  if model_type == 'transformer':
    model = create_model(init_rng, transformer.TransformerEncoder, input_shape,
                         model_kwargs)
  elif model_type == 'linear_transformer':
    model = create_model(init_rng, linear_transformer.LinearTransformerEncoder,
                         input_shape, model_kwargs)
  else:
    raise ValueError('Model type not supported')

The experiments were run inside a Docker image (nvidia/cuda:11.0-cudnn8-devel-ubuntu18.04) on 4 Tesla T4 GPUs with the following requirements:

jax>=0.2.4
flax>=0.2.2
ml-collections>=0.1.0
tensorboard>=2.3.0
tensorflow>=2.3.1
tensorflow-datasets>=4.0.1

and

pip install --upgrade jaxlib==0.1.65+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I can share the full Dockerfile if needed.

AAN dataset crashing when loading .tsv file

Did anyone else have issues loading the AAN dataset into memory? In particular, when I load the .tsv file into memory, it crashes :/ I used several different instances on Google Cloud with varying amounts of memory (up to 170 GB, 24 CPUs), but it still crashed. I feel like I am missing something. Here is the snippet of code that crashes the instance every time:

from datasets import DatasetDict, Value, load_dataset
...

        dataset = load_dataset(
            "csv",
            data_files={
                "train": str(self.data_dir / "new_aan_pairs.train.tsv"),  # 8G file
                "val": str(self.data_dir / "new_aan_pairs.eval.tsv"),
                "test": str(self.data_dir / "new_aan_pairs.test.tsv"),
            },
            delimiter="\t",
            column_names=["label", "input1_id", "input2_id", "text1", "text2"],
            keep_in_memory=True,
        )
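
For what it's worth, one thing worth trying (just a guess, not a confirmed fix) is dropping keep_in_memory=True so that datasets memory-maps its Arrow cache instead of materializing the 8 GB TSV in RAM:

        dataset = load_dataset(
            "csv",
            data_files=data_files,  # same files dict as above (hypothetical shorthand)
            delimiter="\t",
            column_names=["label", "input1_id", "input2_id", "text1", "text2"],
            keep_in_memory=False,  # memory-map from disk instead of holding everything in RAM
        )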

Quadratic Longformer suspicion

Hi!
I am looking at the Longformer implementation, and it seems that nn.attention.dot_product_attention() (with the attention pattern passed through the bias parameter) does all the heavy lifting.
But in nn.attention.dot_product_attention(), as well as in the newer linen.attention.dot_product_attention(), the matrix multiplication happens first and the mask is only applied afterwards, so you still have the quadratic computation.
Can you please explain how you bypass the quadratic computation?
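
For context, this is roughly what the dot-product attention path does (a paraphrase for illustration, not the library source), which is why passing a sparse pattern as bias does not remove the O(L^2) cost:

import jax
import jax.numpy as jnp

def dot_product_attention_sketch(query, key, value, bias):
  # query/key/value: [batch, length, num_heads, head_dim]
  logits = jnp.einsum('bqhd,bkhd->bhqk', query, key)  # full L x L score matrix is materialized here
  logits = logits + bias                              # the Longformer pattern only masks these scores
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum('bhqk,bkhd->bqhd', weights, value)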

jax reports "No GPU/TPU found, falling back to CPU"

The requirements.txt requires jax>=0.2.4. When I hit the problem in the title, the jax GitHub homepage says that for GPU support (as opposed to CPU-only), one needs to install:

pip install --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

The jaxlib version must correspond to the version of the existing CUDA installation you want to use.

But that jaxlib release page does not contain any version >= 0.2.4;
actually, the newest version there is 0.1.67.

How do you install the environment?

Cannot reproduce the results for cifar10

The same question was raised in
https://github.com/google-research/long-range-arena/issues/16

but for the vanilla Transformer, using the hyperparameters in the reply, I cannot reproduce the results in the LRA paper (on 2 V100/32GB cards),
nor do the default settings in the code.
Could you kindly share a ready-to-use set of hyperparameters?

  • my result: test in step: 34999, loss: 1.8261, acc: 0.3633
  • paper result: 42.44 (same as the latest V2 release result)

My hyperparameters from issue 16:
batch_size: 256
checkpoint_freq: 17500
eval_frequency: 175
factors: constant * linear_warmup * cosine_decay
grad_clip_norm: null
learning_rate: 0.0005
model:
  attention_dropout_rate: 0.2
  classifier_pool: CLS
  dropout_rate: 0.3
  emb_dim: 128
  learn_pos_emb: true
  mlp_dim: 128
  num_heads: 1
  num_layers: 1
  qkv_dim: 64
model_type: transformer
num_eval_steps: 39
num_train_steps: 35000
random_seed: 0
restore_checkpoints: true
save_checkpoints: true
steps_per_cycle: 35000
trial: 0
warmup: 175
weight_decay: 0.0

validation on IMDB

Hi, could you tell me how you select the best model trained on the IMDB dataset, since there is no validation set? Do you use the last checkpoint for testing? Thanks!

Is there a pytorch equivalent of this implementation?

I searched for a PyTorch implementation of this benchmark and came across several papers that used it as a benchmark for their methods. However, I was unable to find a standardized version. If you have any leads, I would appreciate it. Thank you!

Q's on Performer & Text Classification

Thanks for the great work. I had a couple of questions when trying to reproduce the Performer results on byte-level text classification:

  1. Which kernel function are you using (softmax approximation or ReLU)?
  2. I found the training to be very unstable. Do you take the final model after 20K steps or the best checkpoint?
  3. With the learning rate schedule you use, the learning rate is 0 if the first step is 0, isn't it? Shouldn't the training loop instead start with for step in range(1, X) in https://github.com/google-research/long-range-arena/blob/main/lra_benchmarks/text_classification/train.py? (See the sketch below.)
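
To illustrate point 3, here is a paraphrase of how a 'constant * linear_warmup' factor schedule behaves at step 0 (my simplification; the base rate and warmup length are placeholders, not the repo's exact values):

def learning_rate(step, base_lr=0.05, warmup_steps=8000):
  # constant * linear_warmup: the warmup factor is step / warmup_steps, capped at 1.0
  return base_lr * min(1.0, step / warmup_steps)

learning_rate(0)  # -> 0.0, so an update taken at step 0 is effectively a no-op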

Looking forward to the implementations of the other models, thanks!

Tied weights for ListOps Transformer

Hi,

According to the configuration file, the Transformer model (with vanilla softmax attention) used in the ListOps task is weight-tied (for both the self-attention module and the feed-forward network). Does that also apply to the other Transformer variants listed in the comparison? For example, is the Linear Transformer also weight-tied? I couldn't find anything relevant in the paper.

Many thanks in advance!

Configs for Image Classification (cifar10)

Thanks for the great work!
I have a question regarding the hyperparameters for training cifar10. I used the settings in this repo and replaced several hyperparameters (e.g. num_layers, num_heads, etc.) with the ones reported in the paper, but the best test accuracy I got was 0.36:

import ml_collections

NUM_EPOCHS = 200
TRAIN_EXAMPLES = 45000
VALID_EXAMPLES = 10000


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()
  config.batch_size = 256
  config.eval_frequency = TRAIN_EXAMPLES // config.batch_size
  config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS
  config.num_eval_steps = VALID_EXAMPLES // config.batch_size
  config.weight_decay = 0.
  config.grad_clip_norm = None

  config.save_checkpoints = True
  config.restore_checkpoints = False
  config.checkpoint_freq = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS // 2
  config.random_seed = 0

  config.learning_rate = .0005  # if using 0.01 from the paper, the loss is not going down
  config.factors = 'constant * linear_warmup * cosine_decay'
  config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1
  config.steps_per_cycle = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS

  # model params
  config.model = ml_collections.ConfigDict()
  config.model.emb_dim = 32
  config.model.num_heads = 4
  config.model.num_layers = 3
  config.model.qkv_dim = 64
  config.model.mlp_dim = 128
  config.model.dropout_rate = 0.3
  config.model.attention_dropout_rate = 0.2
  config.model.classifier_pool = 'CLS'
  config.model.learn_pos_emb = True

  config.trial = 0  # dummy for repeated runs.
  return config

Could you point out which parameters I should adjust to match the reported accuracy (this is for full attention)?

Script for computing memory consumption

Hello again!

I succeeded in running your code on my devices, but now I am struggling with computing memory consumption. Could you please open-source the script for computing it?
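
In the meantime, this is the rough substitute I have in mind (a sketch; memory_stats() is only available on some backends and newer jax/jaxlib versions, so treat its availability and keys as assumptions):

import jax

# After a training/eval step has run, query per-device allocator statistics (GPU backends).
stats = jax.devices()[0].memory_stats()  # may return None on backends without support
if stats is not None:
  print('peak memory (GB):', stats.get('peak_bytes_in_use', 0) / 1e9)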

How to run test

This might be a dumb question, but it seems there is only a train.py, which trains and prints validation stats. How do I test the model to get numbers comparable to the accuracy numbers in the table?

bracket in listops task?

Hello, thank you for sharing a great benchmark!
I'm focusing on the 'listops' benchmark, using the provided code and hyperparameters.

The paper says the maximum sequence length in this task is 2K, but it seems the code excludes all of the brackets from the sequences.

If the brackets ('(', ')', '[', ']') are included, the maximum length becomes 6K, which is much longer than the length mentioned in the paper.

Shouldn't these brackets be considered part of the input in this task?

Thank you.

The best checkpoint of Transformer

Hi,

Thanks for sharing this wonderful work! Could you please also share the fine-tuned checkpoints of the vanilla Transformer for all the tasks? Thank you very much!

Dataset for the matching task

Hi,

Is there an available link to download the document retrieval dataset? The link provided in the README seems to be broken.

Thanks

Problem training listops on GPU

Hello, I'm running into a strange issue training models (Performer, BigBird, Longformer) on ListOps.
The models work fine on CPU, but on GPU (either one or multiple V100 16GB) they crash with a strange error.
This doesn't happen with any other dataset I've tried in the benchmark. The error is:

E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:226] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:113] Check failed: stream->parent()->GetBlasGemmAlgorithms(&algorithms)

The dataset was created with the included script:
PYTHONPATH="$(pwd)":"$PYTHON_PATH" python lra_benchmarks/data/listops.py --output_dir=lra_data/listops
Any idea what may be causing this?
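
One thing I am considering (my guess, not a confirmed diagnosis): CUBLAS_STATUS_NOT_INITIALIZED frequently means cuBLAS could not allocate its workspace because XLA had already reserved most of the GPU memory. A possible workaround is to cap XLA's preallocation before jax is imported:

import os

# These XLA client flags must be set before jax/jaxlib initialize the GPU backend.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.7'   # reserve only 70% of GPU memory up front
# or: os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax  # imported after the environment variables are set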

Publish number of parameters for each task

Hello,

You mention: "The new model should be within at best 10% larger in terms of parameters compared to the base Transformer model in the provided config file".

Do you publish the baseline parameter counts for each task?
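
(For reference, this is the one-liner I would use to count parameters of a flax model, just a sketch to make the comparison concrete; 'params' stands for the model's parameter pytree.)

import jax

# params: the parameter pytree returned by model initialization
num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f'{num_params:,} parameters')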

Thanks
