
TorchUtils

TorchUtils is a PyTorch library with several useful tools and some state-of-the-art training methods and tricks (work in progress).

  • Rewrite the repo using PyTorch 1.6, since many of these tool functions and tricks are now natively supported there.

Import

import torch_utils as tu

Seed All

SEED = 42
tu.tools.seed_everything(SEED)
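`tu.tools.seed_everything` is called here but not documented. The function below, `seed_everything_sketch`, is hypothetical and is not the library's code. It assumes such a helper usually covers Python's `random`, NumPy and PyTorch, and that it switches cuDNN to deterministic mode.

```python
import os
import random

import numpy as np
import torch


def seed_everything_sketch(seed: int) -> None:
    """Hypothetical helper: seed every RNG a typical PyTorch training script touches."""
    random.seed(seed)                         # Python's built-in RNG
    os.environ["PYTHONHASHSEED"] = str(seed)  # only affects subprocesses started after this
    np.random.seed(seed)                      # NumPy's global RNG
    torch.manual_seed(seed)                   # PyTorch RNG for the CPU (and CUDA in recent versions)
    torch.cuda.manual_seed_all(seed)          # all GPUs; silently ignored without CUDA
    # Trade some speed for reproducible cuDNN convolution algorithms
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

Full bitwise reproducibility on GPU can still need extra settings, for example seeding `DataLoader` worker processes.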

Data Augmentation

TODO:

Model

recommended pretrained models:

from GitHub repos:

quickly build models with torch_utils:

import timm

model = timm.create_model('tresnet_m', pretrained=True)
model.global_pool = tu.layers.FastGlobalConcatPool2d(flatten=True)
model.head = tu.layers.get_attention_fc(2048*2, 1) 
model.cuda()

from pytorchcv.model_provider import get_model as ptcv_get_model

model = ptcv_get_model('seresnext50_32x4d', pretrained=True)
model.features.final_pool = tu.layers.GeM() 
model.output = tu.layers.get_simple_fc(2048, 1)   
model.cuda()
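The head widths above depend on the pooling layer:

- `FastGlobalConcatPool2d` presumably concatenates global average pooling and global max pooling. That doubles the 2048 backbone channels, hence `2048*2`.
- GeM keeps one value per channel, hence `2048`. It computes `(mean(x^p))^(1/p)` per channel with a learnable `p`. With `p = 1` it is average pooling, and large `p` approaches max pooling.

The classes below are hypothetical stand-ins written under those assumptions, not `tu.layers`' code. The starting value `p = 3` is a common choice, assumed here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConcatPool2dSketch(nn.Module):
    """Hypothetical stand-in: concat global avg- and max-pooling -> 2*C features."""

    def __init__(self, flatten: bool = True):
        super().__init__()
        self.flatten = flatten

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg = F.adaptive_avg_pool2d(x, 1)   # (N, C, 1, 1)
        mx = F.adaptive_max_pool2d(x, 1)    # (N, C, 1, 1)
        out = torch.cat([avg, mx], dim=1)   # (N, 2C, 1, 1)
        return out.flatten(1) if self.flatten else out


class GeMSketch(nn.Module):
    """Hypothetical stand-in for generalized-mean (GeM) pooling with learnable p."""

    def __init__(self, p: float = 3.0, eps: float = 1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clamp so the fractional power is well defined for non-negative (post-ReLU) features
        x = x.clamp(min=self.eps).pow(self.p)
        return F.adaptive_avg_pool2d(x, 1).pow(1.0 / self.p)  # (N, C, 1, 1)
```

Because `p` is an `nn.Parameter`, the optimizer learns how close to max pooling the layer should be.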

model utils:

# model summary
tu.models.summary(model, (3,224,224))

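# 3-channel pretrained weights to 1 channel: sum over the input-channel axis
import torch.nn as nn
old_conv = model.conv1
weight_grey = old_conv.weight.detach().sum(dim=1, keepdim=True)  # (out, 3, k, k) -> (out, 1, k, k)
# keep the original layer's hyperparameters; only the number of input channels changes
model.conv1 = nn.Conv2d(1, old_conv.out_channels, kernel_size=old_conv.kernel_size,
                        stride=old_conv.stride, padding=old_conv.padding, bias=False)
model.conv1.weight = nn.Parameter(weight_grey)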

# convert 2D models to 3D models using ACSConv (advanced)
# using the code in this repo: https://github.com/M3DV/ACSConv
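Summing the RGB conv weights into one channel relies on linearity. Take a greyscale image replicated into three identical channels and convolve it with the original kernel. The result equals convolving the single channel with the kernel summed over its input channels.

A small check with `torch.nn.functional.conv2d` follows. The shapes, stride and padding mimic a ResNet stem and are illustrative choices only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight_rgb = torch.randn(64, 3, 7, 7)              # (out, in=3, k, k), e.g. a ResNet stem
weight_grey = weight_rgb.sum(dim=1, keepdim=True)  # (out, 1, k, k)

grey = torch.randn(2, 1, 32, 32)                   # batch of single-channel images
rgb = grey.expand(-1, 3, -1, -1)                   # same image replicated into 3 channels

out_rgb = F.conv2d(rgb, weight_rgb, stride=2, padding=3)
out_grey = F.conv2d(grey, weight_grey, stride=2, padding=3)
print(torch.allclose(out_rgb, out_grey, atol=1e-4))  # True (equal up to float rounding)
```

The match is exact only for replicated-grey inputs. Per-channel input normalisation from the RGB pretraining may still need adjusting.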

Optimizer

optimizer = tu.Ranger(model.parameters(), lr=LR)

# optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=2e-4)
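`tu.Ranger` shares its name with the Ranger optimizer, which combines RAdam with Lookahead. The library's implementation is not shown here.

The wrapper below is a hypothetical minimal Lookahead called `LookaheadSketch`. Its defaults `k=6` and `alpha=0.5` are assumptions. Every `k` inner steps it moves a copy of "slow" weights a fraction `alpha` toward the current "fast" weights, then resets the fast weights to the slow ones.

```python
import torch


class LookaheadSketch:
    """Hypothetical minimal Lookahead wrapper (not tu.Ranger's actual code)."""

    def __init__(self, base_optimizer: torch.optim.Optimizer, k: int = 6, alpha: float = 0.5):
        self.base = base_optimizer
        self.k = k
        self.alpha = alpha
        self.step_count = 0
        # Slow weights start as a copy of the fast (model) weights
        self.slow = [
            [p.detach().clone() for p in group["params"]]
            for group in self.base.param_groups
        ]

    def zero_grad(self) -> None:
        self.base.zero_grad()

    @torch.no_grad()
    def step(self) -> None:
        self.base.step()  # inner "fast" update, e.g. RAdam
        self.step_count += 1
        if self.step_count % self.k != 0:
            return
        # Every k steps: slow <- slow + alpha * (fast - slow), then fast <- slow
        for group, slow_group in zip(self.base.param_groups, self.slow):
            for p, slow in zip(group["params"], slow_group):
                slow.add_(p - slow, alpha=self.alpha)
                p.copy_(slow)
```

To approximate Ranger, wrap `torch.optim.RAdam(model.parameters(), lr=LR)`, which was added in PyTorch 1.10. A complete implementation would also save and restore the slow weights in `state_dict`.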

Criterion

TODO:

  • Criteria

Find LR

lr_finder = tu.LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(train_loader, end_lr=10, num_iter=100)
lr_finder.plot() # to inspect the loss-learning rate graph
lr_finder.reset() # to reset the model and optimizer to their initial state
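`tu.LRFinder.range_test` appears to run an LR range test:

1. Raise the learning rate exponentially up to `end_lr` over `num_iter` mini-batches.
2. Record a smoothed loss at each step.
3. Stop early once the loss diverges.
4. Restore the initial state afterwards, which is what `reset()` does.

Below is a self-contained sketch under those assumptions. `lr_range_test_sketch`, `start_lr`, `smooth` and `diverge_th` are hypothetical names and illustrative defaults, not necessarily `tu.LRFinder`'s internals.

```python
import copy

import torch


def lr_range_test_sketch(model, optimizer, criterion, loader, start_lr=1e-7,
                         end_lr=10.0, num_iter=100, smooth=0.05, diverge_th=5.0):
    """Exponential LR sweep; returns (lrs, smoothed_losses) and restores the initial state."""
    # Snapshot weights and optimizer state so the sweep leaves no trace
    model_state = copy.deepcopy(model.state_dict())
    optim_state = copy.deepcopy(optimizer.state_dict())
    gamma = (end_lr / start_lr) ** (1.0 / (num_iter - 1))  # per-step multiplier; num_iter >= 2
    lrs, losses = [], []
    best, avg = float("inf"), None
    batches = iter(loader)
    model.train()
    for i in range(num_iter):
        lr = start_lr * gamma ** i
        for group in optimizer.param_groups:
            group["lr"] = lr
        try:
            x, y = next(batches)
        except StopIteration:  # cycle the loader if it is shorter than num_iter
            batches = iter(loader)
            x, y = next(batches)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        # Exponential moving average of the loss keeps the curve readable
        avg = loss.item() if avg is None else smooth * loss.item() + (1 - smooth) * avg
        lrs.append(lr)
        losses.append(avg)
        best = min(best, avg)
        if avg > diverge_th * best:  # stop once the loss has clearly blown up
            break
    model.load_state_dict(model_state)
    optimizer.load_state_dict(optim_state)
    return lrs, losses
```

Plot `losses` against `lrs` on a logarithmic x-axis. A common heuristic is to pick a learning rate somewhat below the point where the loss falls fastest.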

LR Scheduler

scheduler = tu.CosineAnnealingWarmUpRestarts(optimizer, T_0=T, T_mult=1, eta_max=LR, T_up=0, gamma=0.05)

# torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1)

# torch.optim.lr_scheduler.OneCycleLR
# tu.OneCycleScheduler
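`tu.CosineAnnealingWarmUpRestarts` adds `T_up` and `gamma` to the arguments of PyTorch's built-in `CosineAnnealingWarmRestarts`. The arguments are not documented here, so the following reading from their names is an assumption:

- Each cycle lasts `T_0` epochs.
- `T_up` epochs of linear warm-up open each cycle.
- The rest of the cycle is a cosine decay from `eta_max`.
- The peak is multiplied by `gamma` after every restart.

The hypothetical function below computes that schedule for `T_mult=1`. Its `eta_min` parameter stands in for the LR floor; in similar implementations the optimizer's initial `lr` plays that role.

```python
import math


def cosine_warmup_restarts_lr(epoch: int, T_0: int, eta_max: float,
                              eta_min: float = 0.0, T_up: int = 0,
                              gamma: float = 1.0) -> float:
    """Assumed schedule: fixed-length cycles (T_mult=1), linear warm-up, cosine decay."""
    if not 0 <= T_up < T_0:
        raise ValueError("need 0 <= T_up < T_0")
    cycle, t = divmod(epoch, T_0)         # which restart we are in, position inside it
    peak = eta_max * gamma ** cycle       # peak LR shrinks by `gamma` after every restart
    if t < T_up:                          # linear warm-up from eta_min to the peak
        return eta_min + (peak - eta_min) * t / T_up
    progress = (t - T_up) / (T_0 - T_up)  # goes from 0 to 1 over the decay phase
    return eta_min + (peak - eta_min) * (1 + math.cos(math.pi * progress)) / 2
```

With `gamma=0.05`, as in the call above, each restart peaks at 5% of the previous peak. To drive a real optimizer, feed the values through `torch.optim.lr_scheduler.LambdaLR`. Its lambda must return a multiplier of the optimizer's initial `lr`, so divide the computed value by that initial `lr`.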

TTA

TODO:
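Test-time augmentation (TTA) is still a TODO in the library. One common form averages a classifier's outputs on each image and on its horizontal flip. `predict_flip_tta` is a hypothetical helper, not part of `tu`, and it assumes that flipping preserves the label for your data.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_flip_tta(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Average model outputs over the original NCHW batch and its horizontal mirror."""
    model.eval()
    preds = model(images)
    preds_flipped = model(torch.flip(images, dims=[3]))  # dim 3 is width in NCHW
    return (preds + preds_flipped) / 2
```

This averages raw outputs. Apply `softmax` to both predictions before averaging if you prefer to average probabilities.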

AMP

TODO: use the native AMP support added in PyTorch 1.6: https://pytorch.org/docs/master/notes/amp_examples.html
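Native automatic mixed precision arrived in PyTorch 1.6 as `torch.cuda.amp`. The sketch below shows one training step with its `autocast` and `GradScaler` classes. `train_step_amp` is a hypothetical helper name. The scaler is meant to be disabled when CUDA is unavailable, so the same code then runs in full precision.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast  # native AMP, PyTorch >= 1.6


def train_step_amp(model: nn.Module, optimizer: torch.optim.Optimizer,
                   criterion: nn.Module, x: torch.Tensor, y: torch.Tensor,
                   scaler: GradScaler) -> float:
    """One training step with mixed precision (plain FP32 when the scaler is disabled)."""
    optimizer.zero_grad()
    with autocast(enabled=scaler.is_enabled()):  # run the forward pass in float16 where safe
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()  # scale the loss to avoid float16 gradient underflow
    scaler.step(optimizer)         # unscales grads; skips the step if they contain inf/NaN
    scaler.update()                # adjust the scale factor for the next iteration
    return loss.item()
```

To use it:

1. Create the scaler once with `GradScaler(enabled=torch.cuda.is_available())`.
2. Move the model and batches to the GPU.
3. Call `train_step_amp` for every batch.

Newer releases prefer the equivalent `torch.amp.autocast("cuda")` and `torch.amp.GradScaler("cuda")`.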

TODO

  1. Clean up the code using PyTorch 1.6.0
  2. CutMix: https://github.com/ildoonet/cutmix
  3. RandAugment: https://github.com/ildoonet/pytorch-randaugment
  4. Fast AutoAugment: https://github.com/kakaobrain/fast-autoaugment
  5. SupContrast: https://github.com/HobbitLong/SupContrast
  6. Metric learning: https://github.com/KevinMusgrave/pytorch-metric-learning
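CutMix, item 2 in the TODO list, pastes a random rectangle from another image in the batch into each image. It then mixes the two labels in proportion to the pasted area.

The sketch follows the usual recipe:

1. Draw `lam` from Beta(alpha, alpha).
2. Choose a box covering about `1 - lam` of the image.
3. Clip the box to the image.
4. Recompute `lam` from the clipped area.

`cutmix_sketch` and its arguments are hypothetical. They are not an API of `tu` or of the linked repository.

```python
import numpy as np
import torch


def cutmix_sketch(images: torch.Tensor, targets: torch.Tensor, alpha: float = 1.0, rng=None):
    """Return mixed images, both target sets, and the area-corrected mixing weight lam."""
    rng = rng if rng is not None else np.random.default_rng()
    lam = rng.beta(alpha, alpha)
    n, _, h, w = images.shape
    perm = torch.randperm(n)  # partner image for each sample
    # Box side lengths so that its area is roughly (1 - lam) of the image
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = int(rng.integers(h)), int(rng.integers(w))  # random box centre
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Recompute lam from the clipped box so the label weights match the pasted area
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)
    return mixed, targets, targets[perm], lam
```

In the training loop, combine both targets with the returned weight: `loss = lam * criterion(out, y_a) + (1 - lam) * criterion(out, y_b)`.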

Contributors

seefun
