
drbayes's People

Contributors

andrewgordonwilson, dependabot[bot], izmailovpavel, wjmaddox


drbayes's Issues

Dependencies

Hello!

I was trying to run the visualisation jupyter notebook and had a couple of issues (I installed the repo as a package as mentioned):

  • In setup.py, one of the requirements is scikit_learn>=0.20.2. However, version 0.24 currently gets installed, which throws some errors (I think some things have been refactored in the newest versions). I had to downgrade to 0.20.2 to get it to work, so I believe this should be changed to scikit_learn==0.20.2. (I just noticed that this is specified correctly in requirements.txt, just not in setup.py.)
  • It would be helpful to specify the exact Python version required. setup.py says "3.6", but 3.6.0 and 3.6.1 don't actually work (sorry, I can't remember exactly what the issues were, but I ended up using 3.6.9 to get it to work).
  • Additionally, I had to manually install two packages that are not listed in setup.py: seaborn and pyro-ppl.
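
For illustration, the three points above could be captured in setup.py roughly as follows. This is only a sketch: the names INSTALL_REQUIRES, PYTHON_REQUIRES and SETUP_KWARGS are hypothetical, the existing entries in the real setup.py are not shown, and the Python lower bound is an assumption based on 3.6.9 being the version reported to work (the true minimum may be lower).

```python
# Hypothetical fragment for setup.py. Only the scikit_learn pin and the two
# extra package names come from the report above; everything else is assumed.
INSTALL_REQUIRES = [
    "scikit_learn==0.20.2",  # pinned to match requirements.txt
    "seaborn",               # reported as missing from setup.py
    "pyro-ppl",              # reported as missing from setup.py
]

# Assumption: 3.6.9 is the version reported to work; 3.6.0 and 3.6.1 failed.
# Versions between 3.6.2 and 3.6.8 were not tested in the report.
PYTHON_REQUIRES = ">=3.6.9"

# These would be merged into the existing call, for example:
#   setuptools.setup(name=..., **SETUP_KWARGS)
SETUP_KWARGS = {
    "install_requires": INSTALL_REQUIRES,
    "python_requires": PYTHON_REQUIRES,
}
```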

And one more thing, in the curve subspaces section of the notebook it says "Note that for this to work you need to add the repo https://github.com/timgaripov/dnn-mode-connectivity to your Python path". However, I think the version of the curves.py file in that repo might not be the same as that used in the notebook. In fact, the notebook calls

model = curves.CurveNet(curve, architecture, 3, fix_end=True, fix_start=True, 
                        architecture_kwargs=model_cfg.kwargs)

but in the original repo that function is written as:

class CurveNet(Module):
    def __init__(self, num_classes, curve, architecture, num_bends, fix_start=True, fix_end=True,
                 architecture_kwargs={}):

so I believe the num_classes argument is missing from the constructor call in the notebook.
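
To make the mismatch concrete, here is a small self-contained sketch. The CurveNet class below is a stand-in that copies only the constructor signature quoted above (the real class builds a network), and curve, architecture, arch_kwargs and num_classes = 10 are placeholder values I made up for illustration.

```python
class CurveNet:
    """Stand-in that mirrors only the constructor signature quoted above."""

    def __init__(self, num_classes, curve, architecture, num_bends,
                 fix_start=True, fix_end=True, architecture_kwargs={}):
        self.num_classes = num_classes
        self.num_bends = num_bends


# Placeholder values (hypothetical), standing in for the notebook's objects.
curve, architecture, arch_kwargs = "curve", "architecture", {}
num_classes = 10  # assumption: depends on the dataset actually used

# The notebook's call: every positional argument shifts by one slot, so
# num_bends is never supplied and Python raises a TypeError.
try:
    CurveNet(curve, architecture, 3, fix_end=True, fix_start=True,
             architecture_kwargs=arch_kwargs)
    notebook_call_error = None
except TypeError as err:
    notebook_call_error = str(err)

# The call with num_classes passed first matches the quoted signature.
model = CurveNet(num_classes, curve, architecture, 3, fix_end=True,
                 fix_start=True, architecture_kwargs=arch_kwargs)
```

Against the quoted signature, the notebook's call fails before any model is built, while the second call binds 3 to num_bends as intended.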

Thanks!

Failed to reproduce the results presented in the paper

Hello,

The approach and results presented in the paper "Subspace Inference for Bayesian Deep Learning" impress me a lot and I would like to reproduce the experiments, specifically the UCI experiments, using this repository.

I followed the instructions and process described in the README and managed to get things running. However, I failed to obtain similar numbers for certain datasets.

To be more precise:

  • versions of libraries match.

  • the datasets are downloaded from Google Drive as suggested in the README under drbayes/experiments/uci_exps.

  • all experiments on small UCI datasets are launched by the script run_ucismall.sh under drbayes/experiments/uci_exps/bayesian_benchmarks/tasks.

  • the reproduced unnormalized test likelihoods for the yacht dataset are significantly different from those reported in the paper as shown below.

Method          Paper result      Reproduced result
PCA+ESS (SI)    -0.225 ± 0.400    -2.493 ± 0.067
SWAG            -0.404 ± 0.418    -2.545 ± 0.053

I would be extremely grateful if anyone could suggest what I could've done wrong in my reproduction experiments.

Thank you in advance!

Bug in SWAG implementation?

Hello @wjmaddox,

Thanks for sharing your code. I found your paper very inspiring.

I wonder if the sampling of SWAG is correct. In particular, it appears that the standard Gaussian sample is multiplied by the vector of variances for the SWAG-Diagonal part:

z += variance * torch.randn_like(variance)

Whereas in the original SWAG paper and its implementation the standard Gaussian is multiplied by the vector of standard deviations to sample from N(0, diag(variances)):

https://github.com/wjmaddox/swa_gaussian/blob/b172d93278fdb92522c8fccb7c6bfdd6f710e4f0/swag/posteriors/swag.py#L121

Am I missing something? Thanks for your time.
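
To illustrate the concern, here is a quick standard-library check of the two scalings. It is a sketch and not the repository's code: the value variance = 0.01 and the sample count are arbitrary choices of mine, and plain Python lists stand in for tensors.

```python
import math
import random
import statistics

random.seed(0)

variance = 0.01   # hypothetical per-parameter variance
n = 100_000       # number of standard normal draws (arbitrary)
noise = [random.gauss(0.0, 1.0) for _ in range(n)]

# Scaling by the variance, as in the line quoted from this repository.
scaled_by_variance = [variance * eps for eps in noise]

# Scaling by the standard deviation, as the issue describes for original SWAG.
scaled_by_std = [math.sqrt(variance) * eps for eps in noise]

# Empirical variances of the two sets of samples.
var_a = statistics.pvariance(scaled_by_variance)  # close to variance ** 2
var_b = statistics.pvariance(scaled_by_std)       # close to variance
```

If the goal is to sample from N(0, diag(variances)), only the second scaling gives samples whose variance matches; the first yields the square of the intended variance.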
