
enn's People

Contributors

dvikranth, hawkinsp, iosband, mehdijj, mibrahimi, mohammadasghari, rchen152, romanngg, saran-t, superbobry, xlu0, yilei


enn's Issues

No loss called: L2LossWithBootstrap

When trying to build my own ENN, I got the following error:

AttributeError                            Traceback (most recent call last)
Cell In[5], line 25
     17 enn = networks.MLPEnsembleMatchedPrior(
     18     output_sizes=[50, 50, 1],
     19     num_ensemble=10,
     20     dummy_input=np.zeros(50)
     21 )
     23 # Loss
     24 loss_fn = losses.average_single_index_loss(
---> 25     single_loss=losses.L2LossWithBootstrap(),
     26     num_index_samples=10
     27 )
     29 # Optimizer
     30 optimizer = optax.adam(1e-3)

AttributeError: module 'enn.losses' has no attribute 'L2LossWithBootstrap'

Here is the code that I created for the network:

from enn.loggers import TerminalLogger

from enn import losses
from enn import networks
from enn import supervised
from enn.supervised import regression_data
import optax
import numpy as np

# A small dummy dataset
dataset = regression_data.make_dataset()

# Logger
logger = TerminalLogger('supervised_regression')

# ENN
enn = networks.MLPEnsembleMatchedPrior(
    output_sizes=[50, 50, 1],
    num_ensemble=10,
    dummy_input=np.zeros(50)
)

# Loss
loss_fn = losses.average_single_index_loss(
    single_loss=losses.L2LossWithBootstrap(),
    num_index_samples=10
)

# Optimizer
optimizer = optax.adam(1e-3)

# Train the experiment
experiment = supervised.Experiment(
    enn, loss_fn, optimizer, dataset, seed=0, logger=logger)
experiment.train(FLAGS.num_batch)
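The loss in this snippet wraps a per-index loss with `losses.average_single_index_loss(..., num_index_samples=10)`. The sketch below only illustrates the averaging idea suggested by that call. It is written in plain JAX and is not enn's API: `l2_single_index_loss`, `average_over_sampled_indices` and the toy `toy_apply` ensemble are hypothetical names, and the bootstrap weighting implied by the name `L2LossWithBootstrap` is deliberately left out.

```python
import jax
import jax.numpy as jnp


def l2_single_index_loss(apply_fn, params, x, y, index):
    """Mean squared error of the network evaluated at one ensemble index."""
    preds = apply_fn(params, x, index)
    return jnp.mean(jnp.square(preds - y))


def average_over_sampled_indices(single_loss, num_index_samples, num_ensemble):
    """Wraps a single-index loss so it is averaged over randomly sampled indices."""

    def loss_fn(apply_fn, params, x, y, key):
        # Draw num_index_samples ensemble indices uniformly from [0, num_ensemble).
        indices = jax.random.randint(key, (num_index_samples,), 0, num_ensemble)
        # Evaluate the single-index loss at every sampled index in parallel.
        per_index = jax.vmap(
            lambda i: single_loss(apply_fn, params, x, y, i))(indices)
        return jnp.mean(per_index)

    return loss_fn


# Hypothetical toy ensemble: member i predicts params[i] * x[:, 0], shape (batch, 1).
def toy_apply(params, x, index):
    return (params[index] * x[:, 0])[:, None]


loss_fn = average_over_sampled_indices(
    l2_single_index_loss, num_index_samples=10, num_ensemble=10)
x = jnp.ones((4, 50))
y = jnp.ones((4, 1))
loss = loss_fn(toy_apply, jnp.zeros(10), x, y, jax.random.PRNGKey(0))
```

With all-zero parameters every member predicts 0 against targets of 1, so each sampled index contributes a squared error of 1 and the average is 1.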

Also note that to get the example to work, I had to add the line

    dummy_input=np.zeros(50)

Otherwise, I got an error saying that dummy_input was a required positional argument.
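The Colab traceback in a later issue on this page shows that `MLPEnsembleMatchedPrior.__init__(self, output_sizes, dummy_input, num_ensemble, ...)` takes `dummy_input` as its second positional parameter and forwards it to `make_mlp_ensemble_prior_fns`. Presumably it is there because the prior networks must be initialized up front, and an MLP's first-layer weights cannot be created without knowing the input width. The pure-JAX sketch below illustrates that reasoning under this assumption. `init_mlp_params` and `make_prior_ensemble` are hypothetical names, and this is not enn's code.

```python
import jax
import jax.numpy as jnp


def init_mlp_params(key, output_sizes, dummy_input):
    """Initializes (weight, bias) pairs for an MLP (hypothetical helper).

    The first layer's weight shape depends on the input width, which is read
    from dummy_input. Without an example input, that shape is unknown.
    """
    sizes = [dummy_input.shape[-1]] + list(output_sizes)
    params = []
    for in_size, out_size in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(subkey, (in_size, out_size)) / jnp.sqrt(in_size)
        params.append((w, jnp.zeros(out_size)))
    return params


def make_prior_ensemble(seed, output_sizes, dummy_input, num_ensemble):
    """Builds one independently initialized MLP per ensemble member (hypothetical)."""
    keys = jax.random.split(jax.random.PRNGKey(seed), num_ensemble)
    return [init_mlp_params(k, output_sizes, dummy_input) for k in keys]


priors = make_prior_ensemble(
    seed=0, output_sizes=[50, 50, 1], dummy_input=jnp.zeros(50), num_ensemble=10)
```

Only the shape of `dummy_input` matters here, so in this sketch it should have the same feature width as the real inputs.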

ENN Task generalizability?

I just came across this super interesting work! It seems that ENNs are a general module that can be adapted to any network, but I was wondering whether they only work for specific tasks like ImageNet classification, or whether they can be generalized to other tasks such as segmentation, object detection, pose estimation and tracking. If so, an example would be amazing. Thanks!

Problem occurred in enn_demo.ipynb

Hi!

I was trying enn_demo.ipynb on Google Colab. Everything seemed fine until I ran this block of code:

# Train the experiment
experiment.train(FLAGS.num_batch)

Then this error appeared. Is there something wrong with the JAX version?

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/enn/networks/ensembles.py in apply(params, states, inputs, index)
     82       sub_states = jax.tree_map(particle_selector, states)
     83       out, new_sub_states = model.apply(sub_params, sub_states, inputs)
---> 84       new_states = jax.tree_multimap(
     85           lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)
     86       return out, new_states

AttributeError: module 'jax' has no attribute 'tree_multimap'

Thanks,
Adam
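`jax.tree_multimap` has been removed from recent JAX releases. `jax.tree_util.tree_map` accepts several pytrees with the same structure and covers the same use. Below is a minimal sketch of the update at line 84 of `ensembles.py` rewritten that way. The toy `states` and `new_sub_states` dictionaries are assumed stand-ins with made-up shapes, used only for illustration.

```python
import jax
import jax.numpy as jnp

# Toy stand-ins: stacked per-member states (leading axis = ensemble member)
# and freshly computed states for a single member.
num_ensemble = 3
states = {"batch_norm": {"mean": jnp.zeros((num_ensemble, 4))}}
new_sub_states = {"batch_norm": {"mean": jnp.ones(4)}}
index = 1

# jax.tree_util.tree_map maps over several pytrees with matching structure,
# which is what jax.tree_multimap used to do. Only row `index` is overwritten.
new_states = jax.tree_util.tree_map(
    lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)
```

Possible workarounds are editing the installed `ensembles.py` along these lines or pinning an older JAX release that still provides `tree_multimap`.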

Colab tutorial link

Hi!

The Google Colab tutorial link does not work. Clicking it gives a "notebook not found" error.

Colab error with haiku/jax dependency

The following error appears when executing the Colab notebook:

AttributeError                            Traceback (most recent call last)
<ipython-input-7-4863fdbb2553> in <module>()
     13     num_ensemble=FLAGS.index_dim,
     14     prior_scale=FLAGS.prior_scale,
---> 15     seed=FLAGS.seed,
     16 )
     17 

/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in __init__(self, output_sizes, dummy_input, num_ensemble, prior_scale, seed, w_init, b_init)
    137     """Ensemble of MLPs with matched prior functions."""
    138     mlp_priors = make_mlp_ensemble_prior_fns(
--> 139         output_sizes, dummy_input, num_ensemble, seed)
    140     enn = priors.EnnWithAdditivePrior(
    141         enn=MLPEnsembleEnn(

/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in make_mlp_ensemble_prior_fns(output_sizes, dummy_input, num_ensemble, seed, w_init, b_init)
     90     return hk.Sequential(layers)(x)
     91 
---> 92   transformed = hk.without_apply_rng(hk.transform(net_fn))
     93 
     94   prior_fns = []

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform(f, apply_rng)
    301         "Replace hk.transform(..., apply_rng=True) with hk.transform(...).")
    302 
--> 303   return without_state(transform_with_state(f))
    304 
    305 

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform_with_state(f)
    359   """
    360   analytics.log_once("transform_with_state")
--> 361   check_not_jax_transformed(f)
    362 
    363   unexpected_tracer_hint = (

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)
    306 def check_not_jax_transformed(f):
    307   # TODO(tomhennigan): Consider `CompiledFunction = type(jax.jit(lambda: 0))`.
--> 308   if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)):  # pytype: disable=name-error
    309     raise ValueError("A common error with Haiku is to pass an already jit "
    310                      "(or pmap) decorated function into hk.transform (e.g. "

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'
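The failing check lives in Haiku's `transform.py` but looks up `PmapFunction` in jaxlib. This suggests that the installed `dm-haiku` and `jax`/`jaxlib` versions are out of step with each other. When reporting or debugging this kind of error, it helps to record the exact versions. The sketch below uses the standard library's `importlib.metadata`, which requires Python 3.8+. On Python 3.7, as in this traceback, the `importlib_metadata` backport offers the same API. The distribution names listed are assumptions about how these packages were installed.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "dm-haiku", "enn")):
    """Returns {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```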
