jaxDiversity

This is an updated implementation of the paper "Neural networks embrace diversity".

Authors

Anshul Choudhary, Anil Radhakrishnan, John F. Lindner, Sudeshna Sinha, and William L. Ditto

Link to paper

Key Results

  • We construct neural networks with learnable activation functions and see that they quickly diversify from each other under training.
  • These activations subsequently outperform their pure counterparts on classification tasks.
  • The neuronal sub-networks instantiate the neurons, and meta-learning adjusts their weights and biases to find efficient spanning sets of nonlinear activations (a toy sketch of such a learnable activation follows this list).
  • These improved neural networks provide quantitative examples of the emergence of diversity and insight into its advantages.
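
As a rough illustration of this idea, here is a toy sketch in plain JAX (all names here are illustrative, not the package's API) of an activation function realized as a tiny trainable sub-network applied pointwise:

import jax
import jax.numpy as jnp

def act_apply(params, x):
    # a 1-in/1-out two-layer network used pointwise as an activation
    w1, b1, w2, b2 = params
    h = jnp.tanh(x[..., None] * w1 + b1)  # hidden features per scalar input
    return (h @ w2 + b2).squeeze(-1)      # scalar output per scalar input

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (
    jax.random.normal(k1, (8,)),    # w1: 1 -> 8
    jnp.zeros(8),                   # b1
    jax.random.normal(k2, (8, 1)),  # w2: 8 -> 1
    jnp.zeros(1),                   # b2
)
y = act_apply(params, jnp.linspace(-1.0, 1.0, 5))  # evaluate at a few points

Because the activation's parameters are ordinary arrays, meta-learning can differentiate through training and adjust them like any other weights.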

Install

pip install jaxDiversity

How to use

The codebase has 4 main components:

  • dataloading: Contains tools for loading the datasets mentioned in the manuscript. We use PyTorch dataloaders with a custom numpy collate function so the data can be consumed in JAX (sketched just after this list).

  • losses: Our loss implementations handle both traditional MLPs and Hamiltonian neural networks with minimal changes.

  • mlp: Contains a custom MLP that takes in multiple activations and uses them intralayer to create a diverse network (also sketched after this list). Also contains the activation neural networks.

  • loops: Contains the inner and outer loops of meta-learning, which optimize the activation functions in tandem with the supervised learning task.
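
For instance, the numpy collate pattern used in dataloading can be sketched as follows. The helper below is illustrative (the package ships its own NumpyLoader), but the idea is the same:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def numpy_collate(batch):
    # recursively stack samples into numpy arrays instead of torch tensors
    if isinstance(batch[0], (tuple, list)):
        return type(batch[0])(numpy_collate(samples) for samples in zip(*batch))
    return np.asarray(batch)

dataset = TensorDataset(torch.randn(100, 2), torch.randn(100, 2))
loader = DataLoader(dataset, batch_size=32, collate_fn=numpy_collate)
xb, yb = next(iter(loader))  # numpy arrays, ready to feed into JAX code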
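
Similarly, the intralayer use of multiple activations in mlp can be pictured with this sketch (illustrative only, not the actual MultiActMLP internals): the units of each hidden layer are split evenly among the available activations:

import jax.numpy as jnp

def multi_act(h, afuncs):
    # apply each activation to an equal slice of the layer's pre-activations
    slices = jnp.split(h, len(afuncs), axis=-1)  # last dim must divide evenly
    return jnp.concatenate([f(s) for f, s in zip(afuncs, slices)], axis=-1)

h = jnp.arange(6.0)  # pretend pre-activations of one hidden layer
out = multi_act(h, [lambda x: x**2, lambda x: x])  # the same two activations used in the example below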

Minimum example

import jax
import optax

from jaxDiversity.utilclasses import (
    InnerConfig,
    OuterConfig,
)  # simple utility classes for configuration consistency
from jaxDiversity.dataloading import NumpyLoader, DummyDataset
from jaxDiversity.mlp import (
    mlp_afunc,
    MultiActMLP,
    init_linear_weight,
    xavier_normal_init,
    save,
)
from jaxDiversity.baseline import compute_loss as compute_loss_baseline
from jaxDiversity.hnn import compute_loss as compute_loss_hnn
from jaxDiversity.loops import inner_opt, outer_opt

Inner optimization, i.e. the standard training loop with the baseline activations:

dev_inner_config = InnerConfig(
    test_train_split=0.8,
    input_dim=2,
    output_dim=2,
    hidden_layer_sizes=[18],
    batch_size=64,
    epochs=2,
    lr=1e-3,
    mu=0.9,
    n_fns=2,
    l2_reg=1e-1,
    seed=42,
)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)
afuncs = [lambda x: x**2, lambda x: x]  # baseline activation set: quadratic and identity
train_dataset = DummyDataset(
    1000, dev_inner_config.input_dim, dev_inner_config.output_dim
)
test_dataset = DummyDataset(
    1000, dev_inner_config.input_dim, dev_inner_config.output_dim
)
train_dataloader = NumpyLoader(
    train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
    test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True
)

# inner-loop optimizer; optax.rmsprop's `decay` is the squared-gradient decay
opt = optax.rmsprop(
    learning_rate=dev_inner_config.lr,
    momentum=dev_inner_config.mu,
    decay=dev_inner_config.l2_reg,
)
model = MultiActMLP(
    dev_inner_config.input_dim,
    dev_inner_config.output_dim,
    dev_inner_config.hidden_layer_sizes,
    model_key,
    bias=False,
)
# train with the fixed baseline activations; returns the model, optimizer state, and metrics
baselineNN, opt_state, inner_results = inner_opt(
    model=model,
    train_data=train_dataloader,
    test_data=test_dataloader,
    afuncs=afuncs,
    opt=opt,
    loss_fn=compute_loss_baseline,
    config=dev_inner_config,
    training=True,
    verbose=True,
)

Meta-learning with Hamiltonian Neural Networks:

inner_config = InnerConfig(
    test_train_split=0.8,
    input_dim=2,
    output_dim=1,
    hidden_layer_sizes=[32],
    batch_size=64,
    epochs=5,
    lr=1e-3,
    mu=0.9,
    n_fns=2,
    l2_reg=1e-1,
    seed=42,
)
outer_config = OuterConfig(
    input_dim=1,
    output_dim=1,
    hidden_layer_sizes=[18],
    batch_size=1,
    steps=2,
    print_every=1,
    lr=1e-3,
    mu=0.9,
    seed=24,
)
train_dataset = DummyDataset(1000, inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, inner_config.input_dim, 2)
train_dataloader = NumpyLoader(
    train_dataset, batch_size=inner_config.batch_size, shuffle=True
)
test_dataloader = NumpyLoader(
    test_dataset, batch_size=inner_config.batch_size, shuffle=True
)

opt = optax.rmsprop(
    learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg
)
# outer-loop (meta) optimizer that updates the activation networks
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)

HNN_acts, HNN_stats = outer_opt(
    train_dataloader,
    test_dataloader,
    compute_loss_hnn,
    inner_config,
    outer_config,
    opt,
    meta_opt,
    save_path=None,
)

Link to the older PyTorch codebase with the classification problem: DiversityNN
