Homura

document

Homura is a support tool for research experiments.

Requirements

minimal requirements

Python>=3.7
PyTorch>=1.0
torchvision>=0.2.1
tqdm # automatically installed

optional

tensorboardX
miniargs
colorlog
optuna

To enable distributed training with synchronized batch normalization (SyncBN) and FP16, install apex.

git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext

test

pytest .

install

pip install git+https://github.com/AiMizuno/homura

or

git clone https://github.com/AiMizuno/homura
cd homura; pip install -e .

APIs

basics

  • Device Agnostic
  • Useful features
from homura import optim, lr_scheduler
from homura import trainers, callbacks, reporters
from torchvision.models import resnet50
from torch.nn import functional as F

# model will be registered in the trainer
resnet = resnet50()
# optimizer and scheduler will be registered in the trainer, too
optimizer = optim.SGD(lr=0.1, momentum=0.9)
scheduler = lr_scheduler.MultiStepLR(milestones=[30,80], gamma=0.1)

# list of callbacks or reporters can be registered in the trainer
with reporters.TensorboardReporter([callbacks.AccuracyCallback(), 
                                    callbacks.LossCallback()]) as reporter:
    trainer = trainers.SupervisedTrainer(resnet, optimizer, loss_f=F.cross_entropy, 
                                         callbacks=reporter, scheduler=scheduler)
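In the snippet above, the optimizer and scheduler are created without model parameters, and the trainer supplies them later. One way such deferred construction can work is partial application. The sketch below is only an illustration of that idea: `ToySGD` and `make_sgd` are hypothetical names, not homura's API, and homura's actual mechanism may differ.

```python
from functools import partial


class ToySGD:
    """Toy optimizer standing in for a real one (hypothetical, for illustration)."""

    def __init__(self, params, lr, momentum=0.0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum


def make_sgd(lr, momentum=0.0):
    # Bind the hyperparameters now; the parameters are supplied later.
    return partial(ToySGD, lr=lr, momentum=momentum)


# User side: no parameters are needed yet.
optimizer_factory = make_sgd(lr=0.1, momentum=0.9)

# Trainer side: once the model is known, finish constructing the optimizer.
model_params = [1.0, 2.0, 3.0]  # placeholder for model.parameters()
optimizer = optimizer_factory(model_params)
print(optimizer.lr, optimizer.momentum, len(optimizer.params))  # 0.1 0.9 3
```

The benefit of this design is that hyperparameters can be declared before the model exists, so the same configuration can be reused across models.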

The trainer's iteration method can then be replaced as follows:

from typing import Mapping, Tuple

import torch
from homura import trainers
from homura.utils.containers import Map

def iteration(trainer: trainers.SupervisedTrainer,
              data: Tuple[torch.Tensor, torch.Tensor]) -> Mapping[str, torch.Tensor]:
    input, target = data
    output = trainer.model(input)
    loss = trainer.loss_f(output, target)
    results = Map(loss=loss, output=output)
    if trainer.is_train:
        trainer.optimizer.zero_grad()
        loss.backward()
        trainer.optimizer.step()
    # values registered on results can be read in callbacks;
    # the predicted labels used here are only an illustrative example
    results.user_value = output.argmax(dim=1)
    return results

trainers.SupervisedTrainer.iteration = iteration
# or
trainer.update_iteration(iteration)
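The `Map` container used above is built from keyword arguments and accepts new entries through attribute assignment. A minimal sketch of a dictionary with that behavior is shown below, assuming nothing beyond what the snippet above demonstrates. `AttrMap` is a hypothetical stand-in, and homura's own `Map` may be implemented differently.

```python
class AttrMap(dict):
    """Hypothetical stand-in for homura's Map: a dict with attribute access."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails,
        # so built-in dict methods such as .items() keep working.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Store attributes as ordinary dictionary entries.
        self[name] = value


results = AttrMap(loss=0.25, output=[0.1, 0.9])
results.user_value = "anything a callback may need"
print(sorted(results))  # ['loss', 'output', 'user_value']
```

Because the values live in a plain dictionary, a callback can iterate over every registered entry without knowing the names in advance.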

Dictionaries of models, optimizers, and loss functions are also supported.

trainer = CustomTrainer({"generator": generator, "discriminator": discriminator},
                        {"generator": gen_opt, "discriminator": dis_opt},
                        {"reconstruction": recon_loss, "generator": gen_loss},
                        **kwargs)

reproducibility

from homura.reproductivity import set_deterministic
set_deterministic(1)
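For intuition, making a run repeatable usually means seeding every random number generator in use and disabling nondeterministic cuDNN autotuning. The sketch below shows that general recipe. `seed_everything` is a hypothetical helper written for this example, not homura's implementation, and `set_deterministic` may do more or less than this.

```python
import random


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed every RNG that is available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is optional for this sketch
    try:
        import torch
        torch.manual_seed(seed)
        # Prefer deterministic cuDNN kernels over autotuned ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass  # PyTorch is optional for this sketch


seed_everything(1)
first = [random.random() for _ in range(3)]
seed_everything(1)
second = [random.random() for _ in range(3)]
print(first == second)  # True
```

Deterministic kernels can be slower than autotuned ones, so this kind of setting is typically enabled for debugging and for reporting results rather than for every run.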

debugger

>>> import torch
>>> from torch import nn
>>> from homura import debug
>>> debug.module_debugger(nn.Sequential(nn.Linear(10, 5),
...                                     nn.Linear(5, 1)),
...                       torch.randn(4, 10))
[homura.debug|2019-02-25 17:57:06|DEBUG] Start forward calculation
[homura.debug|2019-02-25 17:57:06|DEBUG] forward> name=Sequential(1)
[homura.debug|2019-02-25 17:57:06|DEBUG] forward>   name=Linear(2)
[homura.debug|2019-02-25 17:57:06|DEBUG] forward>   name=Linear(3)
[homura.debug|2019-02-25 17:57:06|DEBUG] Start backward calculation
[homura.debug|2019-02-25 17:57:06|DEBUG] backward>   name=Linear(3)
[homura.debug|2019-02-25 17:57:06|DEBUG] backward> name=Sequential(1)
[homura.debug|2019-02-25 17:57:06|DEBUG] backward>   name=Linear(2)
[homura.debug|2019-02-25 17:57:06|INFO] Finish debugging mode

Examples

See examples.

  • cifar10.py: training ResNet-20 or WideResNet-28-10 with random crop on CIFAR10
  • imagenet.py: training a CNN on ImageNet with multiple GPUs (single- and multi-process)
  • gap.py: better implementation of generative adversarial perturbation
