
Comments (12)

MostafaDehghani commented on May 22, 2024

No problem at all!
As far as I remember, for almost all models the metrics we care about stop improving after a certain number of training steps, even if the loss is still changing (most of the time just fluctuating). So we chose to fix the number of epochs at 200. I had runs with 1000 epochs, but they showed no significant improvement in accuracy.
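Mostafa's point is that accuracy saturates well before the loss stops moving. If you want to check this on your own eval logs, a minimal sketch is to compare the best accuracy in the most recent evaluations with the best seen before them. `has_plateaued`, `patience` and `min_delta` are hypothetical names for illustration, not part of the LRA codebase.

```python
from typing import Sequence


def has_plateaued(eval_accuracies: Sequence[float], patience: int = 10, min_delta: float = 1e-3) -> bool:
    """Hypothetical helper: True if the last `patience` evaluations failed to
    beat the best earlier accuracy by more than `min_delta`."""
    if len(eval_accuracies) <= patience:
        return False  # too little history to judge
    best_before = max(eval_accuracies[:-patience])
    best_recent = max(eval_accuracies[-patience:])
    return best_recent - best_before <= min_delta


if __name__ == "__main__":
    # Accuracy reaches 0.70 and then stays flat for ten evaluations.
    print(has_plateaued([0.55, 0.65, 0.70] + [0.70] * 10))  # True
```

This only looks at the metric, not the loss, which matches the observation that the loss can keep fluctuating after accuracy has stopped improving.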

from long-range-arena.

MostafaDehghani commented on May 22, 2024

@liuyang148 I think by "coverage" you mean "converge" (please correct me if I'm wrong)?
In that case, Pathfinder is a difficult task for transformers (and for any other architecture that lacks recurrence or an inductive bias for modeling transitivity). So what you're observing is simply these models struggling to pick up the task. That's actually one of the main reasons we included Pathfinder in LRA.
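To make the "transitivity" argument concrete: the Pathfinder label is essentially whether two marked points are joined by a path, and connectivity is transitive (if A reaches B and B reaches C, then A reaches C). An iterative procedure such as breadth-first search settles this by propagating reachability one step at a time, roughly the kind of repeated computation a shallow, non-recurrent model has to approximate in a fixed number of layers. The toy sketch below assumes a binary grid and 8-connected pixels; it illustrates the label only, not how the LRA dataset is generated or how any model solves the task.

```python
from collections import deque
from typing import List, Tuple


def connected(grid: List[List[int]], start: Tuple[int, int], goal: Tuple[int, int]) -> bool:
    """Return True if `goal` is reachable from `start` through 8-connected 1-pixels."""
    rows, cols = len(grid), len(grid[0])
    if not (grid[start[0]][start[1]] and grid[goal[0]][goal[1]]):
        return False  # both endpoints must lie on a path pixel
    seen = {start}
    queue = deque([start])
    while queue:
        r, c = queue.popleft()
        if (r, c) == goal:
            return True
        # Expand to all 8 neighbours; each hop extends reachability by one step.
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                nr, nc = r + dr, c + dc
                if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] and (nr, nc) not in seen:
                    seen.add((nr, nc))
                    queue.append((nr, nc))
    return False


if __name__ == "__main__":
    grid = [
        [1, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 1],
        [0, 0, 0, 0],
    ]
    print(connected(grid, (0, 0), (2, 3)))  # True, via (0,1), (1,1), (2,2)
```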


liuyang148 commented on May 22, 2024

Yes, I meant 'converge'; forgive my bad English.
Then which results did the paper report? Only the runs that converged, ignoring the ones that did not?


liuyang148 commented on May 22, 2024

OK, I got it. Thanks for your help.


jnhwkim commented on May 22, 2024

@MostafaDehghani I understand the task is hard to learn and that models often fail to converge on it. I tried three times with different config.random_seed values for Performer, but it keeps failing to converge and test accuracy stays around 50%. How can I reproduce the number in the paper, i.e., 77.05 (the best score in Table 1)?
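Since Pathfinder is a binary classification task, test accuracy around 50% is what a model stuck at chance would get (assuming the two classes are roughly balanced). When comparing several seeds, a small hypothetical helper like `summarize_seeds` can report the spread and flag seeds that never left chance level; the `chance` and `tol` defaults are assumptions, not values from the paper or the LRA code.

```python
import statistics
from typing import Dict, List, Union


def summarize_seeds(
    test_acc_by_seed: Dict[int, float], chance: float = 0.5, tol: float = 0.02
) -> Dict[str, Union[float, List[int]]]:
    """Hypothetical helper: aggregate per-seed test accuracies and list seeds
    whose accuracy is within `tol` of chance level."""
    accs = list(test_acc_by_seed.values())
    return {
        "mean": statistics.mean(accs),
        "stdev": statistics.stdev(accs) if len(accs) > 1 else 0.0,
        "best": max(accs),
        "at_chance_seeds": sorted(
            seed for seed, acc in test_acc_by_seed.items() if abs(acc - chance) <= tol
        ),
    }
```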


yinzhangyue commented on May 22, 2024

@jnhwkim I encountered the same situation as you.


MostafaDehghani commented on May 22, 2024

@jnhwkim @yinzhangyue
Can you point me to the exact config file you're using in the LRA codebase?


yinzhangyue commented on May 22, 2024

I didn't change the config file. Here is base_pathfinder32_config.py:

# Copyright 2021 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base Configuration."""

import ml_collections

NUM_EPOCHS = 200
TRAIN_EXAMPLES = 160000
VALID_EXAMPLES = 20000


def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()
  config.batch_size = 512
  config.eval_frequency = TRAIN_EXAMPLES // config.batch_size
  config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS
  config.num_eval_steps = VALID_EXAMPLES // config.batch_size
  config.weight_decay = 0.
  config.grad_clip_norm = None

  config.save_checkpoints = True
  config.restore_checkpoints = True
  config.checkpoint_freq = (TRAIN_EXAMPLES //
                            config.batch_size) * NUM_EPOCHS // 2
  config.random_seed = 0

  config.learning_rate = .001
  config.factors = 'constant * linear_warmup * cosine_decay'
  config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1
  config.steps_per_cycle = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS

  # model params
  config.model = ml_collections.ConfigDict()
  config.model.num_layers = 1
  config.model.num_heads = 2
  config.model.emb_dim = 32
  config.model.dropout_rate = 0.1

  config.model.qkv_dim = config.model.emb_dim // 2
  config.model.mlp_dim = config.model.qkv_dim * 2
  config.model.attention_dropout_rate = 0.1
  config.model.classifier_pool = 'MEAN'
  config.model.learn_pos_emb = False

  config.trial = 0  # dummy for repeated runs.
  return config
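For reference, here is the arithmetic this base config implies, redone in plain Python so it runs without `ml_collections`. The `learning_rate` function is a sketch of how a `'constant * linear_warmup * cosine_decay'` factor string is typically interpreted in Flax-style schedulers (a linear ramp over `warmup` steps, then a cosine over `steps_per_cycle`); treat its exact form as an assumption rather than a copy of the LRA implementation.

```python
import math

# Constants copied from base_pathfinder32_config.py.
NUM_EPOCHS = 200
TRAIN_EXAMPLES = 160000
VALID_EXAMPLES = 20000
BATCH_SIZE = 512
BASE_LR = 0.001

steps_per_epoch = TRAIN_EXAMPLES // BATCH_SIZE       # 312 (also eval_frequency)
num_train_steps = steps_per_epoch * NUM_EPOCHS        # 62400
num_eval_steps = VALID_EXAMPLES // BATCH_SIZE         # 39
checkpoint_freq = steps_per_epoch * NUM_EPOCHS // 2   # 31200
warmup = steps_per_epoch * 1                          # 312
steps_per_cycle = steps_per_epoch * NUM_EPOCHS        # 62400


def learning_rate(step: int) -> float:
    """Assumed semantics of 'constant * linear_warmup * cosine_decay'."""
    lr = BASE_LR                                      # constant
    lr *= min(1.0, step / warmup)                     # linear_warmup: 0 -> 1 over `warmup` steps
    progress = max(0.0, (step - warmup) / steps_per_cycle)
    lr *= max(0.0, 0.5 * (1.0 + math.cos(math.pi * (progress % 1.0))))  # cosine_decay
    return lr


if __name__ == "__main__":
    for step in (0, warmup, num_train_steps // 2, num_train_steps):
        print(step, f"{learning_rate(step):.2e}")
```

Note that integer division makes one "epoch" 312 steps of 512 examples, i.e. 159,744 of the 160,000 training examples.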


yinzhangyue commented on May 22, 2024

My run script:

PYTHONPATH="$(pwd)":"$PYTHONPATH" python lra_benchmarks/image/train.py \
      --config=lra_benchmarks/image/configs/pathfinder32/performer_base.py \
      --model_dir=./tmp/pathfinder_F \
      --task_name=pathfinder32_hard


MostafaDehghani commented on May 22, 2024

I just checked, and it seems the configs in the repo are not synced with the internal configs we used to get the results in the paper. Not sure what went wrong, but sorry about that. I'll work on updating the repo; in the meantime, here is the config you should use in the Performer config file to get the reported score:

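# Assumed module path: the same directory as performer_base.py in the run script above.
from lra_benchmarks.image.configs.pathfinder32 import base_pathfinder32_config


def get_config():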
  """Get the default hyperparameter configuration."""
  config = base_pathfinder32_config.get_config()
  config.model_type = "performer"

  config.model.num_layers = 1
  config.model.num_heads = 8
  config.model.emb_dim = 128
  config.model.dropout_rate = 0.2
  config.model.qkv_dim = 64
  config.model.mlp_dim = 128

  return config
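To see what this override changes relative to the base config posted earlier in the thread, the two model sections can be compared as plain dicts (transcribed by hand from the two snippets; this comparison script is hypothetical). Fields the override does not touch, such as `attention_dropout_rate`, `classifier_pool` and `learn_pos_emb`, are inherited from the base config. The `head_dim` helper assumes `qkv_dim` is split evenly across heads, as in Flax-style multi-head attention.

```python
# Hand-transcribed model sections of the two configs.
BASE_EMB_DIM = 32
base_model = {
    "num_layers": 1,
    "num_heads": 2,
    "emb_dim": BASE_EMB_DIM,
    "dropout_rate": 0.1,
    "qkv_dim": BASE_EMB_DIM // 2,        # qkv_dim = emb_dim // 2
    "mlp_dim": (BASE_EMB_DIM // 2) * 2,  # mlp_dim = qkv_dim * 2
}
performer_model = {
    "num_layers": 1,
    "num_heads": 8,
    "emb_dim": 128,
    "dropout_rate": 0.2,
    "qkv_dim": 64,
    "mlp_dim": 128,
}


def head_dim(model: dict) -> int:
    """Per-head query/key/value width, assuming an even split across heads."""
    return model["qkv_dim"] // model["num_heads"]


# Fields whose value differs, as (base, performer) pairs.
changed = {
    key: (base_model[key], performer_model[key])
    for key in base_model
    if base_model[key] != performer_model[key]
}

if __name__ == "__main__":
    for key, (old, new) in changed.items():
        print(f"{key}: {old} -> {new}")
    print("head_dim:", head_dim(base_model), "->", head_dim(performer_model))
```

By this count the working config is 4x larger in `emb_dim`, `qkv_dim`, `mlp_dim` and `num_heads` and doubles `dropout_rate`, while the per-head dimension stays at 8.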


yinzhangyue commented on May 22, 2024

Thank you! I will try it immediately.


yinzhangyue commented on May 22, 2024

It works! Thank you very much! ^o^
