torchtest

A tiny test suite for PyTorch-based machine learning models, inspired by mltest. Chase Roberts lists four basic tests in his Medium post about mltest. torchtest is mostly a PyTorch port of mltest (which was written for TensorFlow).

Installation

pip install --upgrade torchtest

Tests

# imports for examples
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # deprecated since PyTorch 0.4; Variable(t) simply returns a tensor

Variables Change

from torchtest import assert_vars_change

inputs = Variable(torch.randn(20, 20))
targets = Variable(torch.randint(0, 2, (20,))).long()
batch = [inputs, targets]
model = nn.Linear(20, 2)

# what are the variables?
print('Our list of parameters', [name for name, _ in model.named_parameters()])

# do they change after a training step?
#  let's run a train step and see
assert_vars_change(
    model=model,
    loss_fn=F.cross_entropy,
    optim=torch.optim.Adam(model.parameters()),
    batch=batch)
""" FAILURE """
# let's try to break this, so the test fails
params_to_train = [param for name, param in model.named_parameters() if name != 'bias']
# run test now
assert_vars_change(
    model=model,
    loss_fn=F.cross_entropy,
    optim=torch.optim.Adam(params_to_train),
    batch=batch)

# YES! the test fails as intended, because bias did not change
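
To make the idea behind this test concrete, here is a minimal sketch in plain PyTorch of what "did the variables change?" amounts to: snapshot the parameters, run one training step, and compare. The helper name `changed_params` is hypothetical and is not part of torchtest; the library's actual implementation may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def changed_params(model, loss_fn, optim, batch):
    """Run one training step and report, per parameter name, whether it changed.

    Hypothetical helper for illustration only; not torchtest's API.
    """
    inputs, targets = batch
    # Snapshot every parameter before the step
    before = {name: p.detach().clone() for name, p in model.named_parameters()}

    model.train()
    optim.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optim.step()

    # A parameter "changed" if any element differs from its snapshot
    return {name: not torch.equal(before[name], p.detach())
            for name, p in model.named_parameters()}


torch.manual_seed(0)
inputs = torch.randn(20, 20)
targets = torch.randint(0, 2, (20,))
model = nn.Linear(20, 2)

# The optimizer only sees the weight, so the bias is never updated
optim = torch.optim.Adam([model.weight])
result = changed_params(model, F.cross_entropy, optim, [inputs, targets])
print(result)
```

With the bias left out of the optimizer, the report should flag `weight` as changed and `bias` as unchanged, which is the situation the failing example above sets up.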

Variables Don’t Change

from torchtest import assert_vars_same

# What if bias is not supposed to change, by design?
#  test to see if bias remains the same after training
assert_vars_same(
    model=model,
    loss_fn=F.cross_entropy,
    optim=torch.optim.Adam(params_to_train),
    batch=batch,
    params=[('bias', model.bias)]
    )
# it does? good. let's move on
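
The example above keeps the bias fixed by leaving it out of the optimizer. Another common way to freeze a parameter in PyTorch is to turn off gradient tracking for it with `requires_grad_(False)`. The sketch below assumes that approach and checks the result by hand; it is an illustration, not something torchtest requires.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(20, 2)

# Freeze the bias: no gradient is computed for it from now on
model.bias.requires_grad_(False)

# Hand only the trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optim = torch.optim.Adam(trainable)

bias_before = model.bias.detach().clone()

# One ordinary training step
inputs = torch.randn(20, 20)
targets = torch.randint(0, 2, (20,))
optim.zero_grad()
loss = F.cross_entropy(model(inputs), targets)
loss.backward()
optim.step()

# The frozen bias is identical to its snapshot
bias_unchanged = torch.equal(bias_before, model.bias.detach())
print('bias unchanged:', bias_unchanged)
```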

Output Range

from torchtest import test_suite

# NOTE : bias is fixed (not trainable)
optim = torch.optim.Adam(params_to_train)
loss_fn = F.cross_entropy

test_suite(model, loss_fn, optim, batch,
    output_range=(-2, 2),
    test_output_range=True
    )

# seems to work
""" FAILURE """
#  let's tweak the model to fail the test
model.bias = nn.Parameter(2 + torch.randn(2, ))

test_suite(
    model,
    loss_fn, optim, batch,
    output_range=(-1, 1),
    test_output_range=True
    )

# as expected, it fails; yay!
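
As a rough illustration of what an output range check involves, the sketch below runs a forward pass and tests that every output element lies within given bounds. The helper `outputs_in_range` is hypothetical, and treating the bounds as inclusive is an assumption; torchtest may compare differently.

```python
import torch
import torch.nn as nn


def outputs_in_range(model, inputs, low, high):
    """Return True if every output element lies in [low, high].

    Hypothetical helper; inclusive bounds are an assumption.
    """
    model.eval()
    with torch.no_grad():
        out = model(inputs)
    return bool(((out >= low) & (out <= high)).all())


torch.manual_seed(0)
inputs = torch.randn(20, 20)

# Tanh squashes every output into [-1, 1]
bounded = nn.Sequential(nn.Linear(20, 2), nn.Tanh())
bounded_ok = outputs_in_range(bounded, inputs, -1, 1)

# A linear layer with zero weights and a bias of 5 always outputs 5
shifted = nn.Linear(20, 2)
with torch.no_grad():
    shifted.weight.zero_()
    shifted.bias.fill_(5.0)
shifted_ok = outputs_in_range(shifted, inputs, -1, 1)

print(bounded_ok, shifted_ok)
```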

NaN Tensors

""" FAILURE """
model.bias = nn.Parameter(float('NaN') * torch.randn(2, ))

test_suite(
    model,
    loss_fn, optim, batch,
    test_nan_vals=True
    )

Inf Tensors

""" FAILURE """
model.bias = nn.Parameter(float('Inf') * torch.randn(2, ))

test_suite(
    model,
    loss_fn, optim, batch,
    test_inf_vals=True
    )
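
For reference, NaN and Inf values in a model's output can be detected directly with `torch.isnan` and `torch.isinf`. The sketch below uses two hypothetical helpers, `has_nan` and `has_inf`, to show the kind of check these two tests perform; it is not torchtest's own code.

```python
import torch
import torch.nn as nn


def has_nan(model, inputs):
    """Return True if the forward pass produces any NaN (hypothetical helper)."""
    with torch.no_grad():
        return bool(torch.isnan(model(inputs)).any())


def has_inf(model, inputs):
    """Return True if the forward pass produces any +/-Inf (hypothetical helper)."""
    with torch.no_grad():
        return bool(torch.isinf(model(inputs)).any())


torch.manual_seed(0)
inputs = torch.randn(20, 20)
model = nn.Linear(20, 2)

# A healthy model produces neither
clean = (has_nan(model, inputs), has_inf(model, inputs))

# Poison the bias with NaN: every output becomes NaN
with torch.no_grad():
    model.bias.fill_(float('nan'))
nan_case = (has_nan(model, inputs), has_inf(model, inputs))

# Poison the bias with Inf: every output becomes Inf
with torch.no_grad():
    model.bias.fill_(float('inf'))
inf_case = (has_nan(model, inputs), has_inf(model, inputs))

print(clean, nan_case, inf_case)
```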

Contributors

suriyadeepan
