
group_dro's Issues

not being able to reproduce your results on CUB and CelebA

With the commands in the repo, I am not able to reproduce your results on CUB and CelebA.
Could you send me the commands that reproduce your best results on these two datasets, especially the configuration with large weight decay and early stopping?
Thanks a lot.

Codalab worksheet link is broken

The link to the Codalab worksheet is broken. The page shows "Not found: '/worksheets/0x621811fe446b49bb818293bae2ef88c0'." Could you please update it? Thank you!

generating MNLI glue files

Hi,
Could you provide the command to regenerate your cached_mnli_files? Also, is it possible to run the code from the raw MNLI text data? Thanks.
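
In case it helps while waiting for an answer, here is a minimal sketch of building BERT features for MNLI from the raw text using the Hugging Face datasets and transformers libraries. The exact format expected by this repository's cached_mnli_files is an assumption I have not verified, so the output would still need to be adapted to it:

from datasets import load_dataset
from transformers import BertTokenizerFast

# Load the raw MNLI premise/hypothesis/label triples and tokenize them for BERT.
mnli = load_dataset("glue", "mnli")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

def encode(batch):
    return tokenizer(batch["premise"], batch["hypothesis"],
                     padding="max_length", truncation=True, max_length=128)

encoded = mnli.map(encode, batched=True)
# 'encoded' now holds input_ids / token_type_ids / attention_mask per example;
# these would have to be serialized in whatever format cached_mnli_files expects.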

Where is the random group chosen?

First, thanks for the nice work!
In your paper you show the following:
[screenshot: the group-sampling step from the paper's algorithm]

I am having difficulty finding the part of your code that corresponds to randomly picking a group, g ~ Uniform(1, ..., m).
Could you please tell me where I can find it?
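
For reference, a minimal sketch of what that sampling step could look like if implemented literally; the per-group minibatch iterators, the weight vector q, and the step size eta_q below are illustrative assumptions, not a pointer to where it happens in this repository:

import random
import torch
import torch.nn.functional as F

def online_group_dro_step(model, group_iters, q, optimizer, eta_q=0.01):
    """One step of the paper-style online algorithm: pick a group uniformly at
    random, draw a minibatch from it, update that group's adversarial weight by
    exponentiated gradient ascent, then take a reweighted model step.
    q is a float tensor of shape (m,) that sums to 1."""
    m = len(group_iters)
    g = random.randrange(m)              # g ~ Uniform(1, ..., m)
    x, y = next(group_iters[g])          # minibatch drawn from group g only
    loss = F.cross_entropy(model(x), y)

    q[g] = q[g] * torch.exp(eta_q * loss.detach())
    q = q / q.sum()                      # renormalize the group weights

    optimizer.zero_grad()
    (q[g] * loss).backward()             # gradient weighted by the group's weight
    optimizer.step()
    return q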

Questions about the L2-penalties

It seems that there is no L2-penalty implementation in this code. Should I implement it myself to reproduce the results in the paper?
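
While waiting for a reply: in practice an L2 penalty can be added either as weight decay in the optimizer or as an explicit term in the loss. A generic sketch, not the authors' exact setup (the coefficient l2_coef and the placeholder model are assumptions):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model

# Option 1: weight decay in the optimizer. For plain SGD this adds wd * theta to
# each gradient, i.e. an L2 penalty of (wd / 2) * ||theta||^2 on the loss.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Option 2: an explicit penalty term added to the training loss.
def l2_penalty(model):
    return sum((p ** 2).sum() for p in model.parameters())

l2_coef = 1e-4
# loss = task_loss + l2_coef * l2_penalty(model)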

ERM should also save best model based on worst-group accuracy

Hi @kohpangwei and @ssagawa,

Your paper mentions that "All benchmark models were evaluated at the best early stopping epoch (as measured by robust validation accuracy)." However, your code

group_DRO/train.py

Lines 205 to 209 in ca58872

if args.save_best:
    if args.robust or args.reweight_groups:
        curr_val_acc = min(val_loss_computer.avg_group_acc)
    else:
        curr_val_acc = val_loss_computer.avg_acc

indicates that (i) for ERM, the best model is selected by average validation accuracy, while (ii) for reweighting and GroupDRO, the best model is selected by worst-group validation accuracy. I think that, for a fair comparison, the selection rule in (i) should be changed to match (ii). What do you think?

Did you use selection rule (i) for ERM in your paper's experiments (e.g., Table 3)? I'm trying to reproduce your results, but I'm not sure whether the reported ERM numbers come from (i) or not.
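
For concreteness, the change being suggested would amount to something like the following in the snippet above (a sketch of the proposed fix, not the authors' released code):

if args.save_best:
    # use worst-group validation accuracy for every method, including ERM,
    # so that model selection is identical across all baselines
    curr_val_acc = min(val_loss_computer.avg_group_acc)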

Algorithm in paper and random groups per batch

Following up on the previous question (#7):
Could you clarify whether we need to sample only one group at each iteration, or whether it is OK to have multiple groups in a batch? Your algorithm seems to say that only one group should be sampled per iteration, but this does not appear to be the case in the code.

[screenshot: the algorithm from the paper]

Additionally, could you comment on the following remark from this paper: https://arxiv.org/abs/2010.12230?

[screenshot: the quoted remark from the cited paper]

Is this remark justified?
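
For context, here is a minimal sketch of a per-batch group DRO update in which a batch may contain several groups: per-group mean losses are computed, the adversarial group weights q are updated by exponentiated gradient ascent, and the robust loss is their q-weighted sum. The function name and the step size eta_q are illustrative assumptions, not the repository's exact implementation:

import torch
import torch.nn.functional as F

def group_dro_batch_loss(logits, labels, group_idx, q, n_groups, eta_q=0.01):
    """Group DRO loss for a batch containing examples from multiple groups.

    logits:    (B, C) model outputs
    labels:    (B,)   class labels
    group_idx: (B,)   group id of each example, in [0, n_groups)
    q:         (n_groups,) current adversarial group weights (sums to 1)
    """
    per_sample_loss = F.cross_entropy(logits, labels, reduction='none')

    # mean loss of each group that appears in the batch (0 for absent groups)
    group_loss = torch.zeros(n_groups, device=logits.device)
    for g in range(n_groups):
        mask = group_idx == g
        if mask.any():
            group_loss[g] = per_sample_loss[mask].mean()

    # exponentiated-gradient ascent on the group weights
    q = q * torch.exp(eta_q * group_loss.detach())
    q = q / q.sum()

    # robust loss: weighted sum of per-group mean losses
    robust_loss = (q * group_loss).sum()
    return robust_loss, q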

Why is reweight_groups flag set for DRO algorithm? Possible unfair comparison to ERM?

Dear authors, thank you for sharing a well-polished codebase!

For the results in Table 1, I have noticed that the DRO method is always run with the "reweight_groups" flag set to "True", whereas the same flag is "False" for the ERM algorithm [1]. As per the code, the "reweight_groups" flag performs weighted random sampling so that each group appears with roughly equal frequency in a batch (a minimal sketch of such a sampler is given after the references below). The ERM runs, by contrast, see far fewer minority-group examples per batch because there is no weighted random sampling. Such a difference in implementation between ERM and GroupDRO suggests an unfair comparison between the two methods.

Admittedly, as pointed out in the comment [2], the loss function could be considered unaffected by the "reweight_groups" flag, since the DRO method uses the mean of per-group losses. However, the empirical estimate of these means in a given batch would be highly noisy when the minority group contributes very few samples. This makes me wonder (and I hope it is okay to ask) whether the gains reported in the paper should be attributed to the weighted random sampling procedure rather than to the DRO update rule. Could you please clarify?

Do you have any comparisons of the DRO algorithm with the "reweight_groups" flag set to "False"? How does ERM with reweight_groups=True compare to ERM with reweight_groups=False?

Thank you

[1] https://worksheets.codalab.org/worksheets/0x621811fe446b49bb818293bae2ef88c0
[2] the code comment: # since the minibatch is only used for mean gradient estimation for each group separately
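
Regarding the weighted random sampling discussed above, here is a minimal sketch of how such group-balanced batches can be produced with PyTorch's WeightedRandomSampler. This illustrates the general technique rather than quoting the repository's code, and group_array is an assumed per-example group-id array:

import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def group_balanced_loader(dataset, group_array, batch_size):
    """Sample each example with weight inversely proportional to its group size,
    so every group appears with roughly equal frequency in each batch."""
    group_array = np.asarray(group_array)
    group_counts = np.bincount(group_array)
    weights = 1.0 / group_counts[group_array]          # per-example sampling weights
    sampler = WeightedRandomSampler(
        torch.as_tensor(weights, dtype=torch.double),
        num_samples=len(dataset),
        replacement=True,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)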

not being able to reproduce your results on MNLI

Hi,
With the command in the repo for MNLI, I am not able to reproduce your results. Could you send me the command that reproduces your best results on MNLI? It seems that some arguments are missing from that command. Thanks.

Setting of ERM baseline

Thanks for the good work!
Could you please tell me how to obtain the ERM baselines on the three datasets, e.g., the exact commands?
