torchdiffeq's People

Contributors

bamos, brettkoonce, cfinlay, emecas, haowggit, jambo6, jamesallingham, lxuechen, maricelam, patrick-kidger, rafaelvalle, rajatvd, rtqichen, shivak, simitii, slishak, star-cold, stefoe, talesa, timudk

torchdiffeq's Issues

The function _select_initial_step() may output nan

Line 131 of "misc.py":
h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1))
When testing dopri5 on my own network, I found that this line can output nan when d1_ = 0 (0/0 gives nan, and a nonzero d0_ divided by 0 gives inf). So I think it should perhaps be changed to:
h0 = 0.01 * max(d0_ / (d1_ + 1e-5) for d0_, d1_ in zip(d0, d1))
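An alternative to adding 1e-5 to every denominator is to guard only the degenerate case, so h0 is unchanged whenever d1_ is not tiny. The sketch below is a hypothetical stand-alone helper (select_h0 is not a torchdiffeq function) that takes the per-tensor norms d0 and d1 as plain floats. When either norm is tiny it falls back to a small fixed step, in the spirit of common starting-step heuristics; the thresholds 1e-5 and 1e-6 are assumptions, not values taken from torchdiffeq.

```python
from typing import Sequence


def select_h0(d0: Sequence[float], d1: Sequence[float],
              tiny: float = 1e-5, fallback: float = 1e-6) -> float:
    """Initial step guess h0 = 0.01 * max(d0_ / d1_), guarded against d1_ == 0.

    d0: norms of the scaled initial state, one per tensor in the state tuple.
    d1: norms of the scaled initial derivative f(t0, y0), in the same order.
    The `tiny` and `fallback` thresholds are assumptions for illustration.
    """
    candidates = []
    for d0_, d1_ in zip(d0, d1):
        if d0_ < tiny or d1_ < tiny:
            # Degenerate component: 0/0 would give nan and x/0 would give inf.
            candidates.append(fallback)
        else:
            candidates.append(0.01 * d0_ / d1_)
    return max(candidates)


# d1_ == 0 now yields the fallback step instead of nan or inf.
print(select_h0([1.0], [0.0]))
print(select_h0([1.0, 3.0], [2.0, 1.0]))
```

Unlike the epsilon in the proposed fix, this guard does not perturb h0 for well-conditioned inputs.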

Crash running dopri5

Hello! I'm having issues running the dopri5 method on a time sequence, in a setup very close to the latent ODE example.
I am running odeint via odeint_adjoint on a z0 batch of shape (batchsize, time_sequence, 1).
Whenever I run the following code, I get the error below:

print(times)
print(z0)
pred_z = odeint(func, z0, times, method="dopri5").permute(1, 0, 2) # odeint_adjoint
tensor([  0.,  60., 120., 180.], device='cuda:0')
tensor([[ 3.4961e-01, -1.3341e+00, -1.0196e+00, -5.6237e-01, -5.7106e-01,
         -2.0538e-01, -2.0877e-01,  1.6032e-01, -1.0014e-01,  8.7283e-01,
          2.0110e+00, -6.1008e-01,  2.5714e-01,  2.3386e+00,  2.7314e+00,
          2.4449e+00],
        [ 2.3964e-01, -3.9234e+00, -5.1578e-01, -1.1946e+00, -1.5457e+00,
         -7.2809e-01,  8.4074e-01,  2.4824e+00,  3.0488e-01,  3.8429e-01,
          3.0424e-01, -3.8435e-01,  1.7489e+00,  1.0455e+00, -2.3369e-01,
         -1.3098e+00],
        [-3.7673e-01, -6.8062e-01, -7.5301e-02, -4.2621e-01,  1.0845e+00,
         -4.3786e-01,  4.9334e-01,  1.7223e+00,  9.4618e-01,  9.6530e-01,
          2.8994e+00,  1.0563e+00,  8.6989e-01,  1.9997e+00, -1.1819e+00,
          2.2736e-01],
        [ 6.3911e-01, -2.0203e-01,  1.4277e+00, -2.0914e-01, -1.9965e+00,
          7.3284e-02, -5.9003e-01, -9.9907e-01,  1.2299e-01, -8.0105e-02,
          1.2543e+00,  8.7276e-01,  2.9519e-01,  6.1938e-01,  1.2256e+00,
         -2.1609e-01],
        [ 7.0405e-01,  9.3754e-02, -2.9164e-02, -2.3351e-02, -2.7254e-01,
         -6.6201e-01, -1.2737e+00,  6.7255e-01, -2.8363e-01, -6.3016e-01,
          2.9853e+00,  1.7805e+00,  2.4158e-01,  1.1367e+00,  1.4954e+00,
          1.8174e-01],
        [ 8.2383e-01, -8.8112e-01, -1.2737e+00,  6.5401e-02, -4.7465e-01,
          4.9482e-01,  7.0683e-01,  5.8325e-01, -9.1313e-01, -9.0717e-01,
          1.9697e+00, -8.7827e-01, -3.4570e-01,  2.8642e-01, -2.0495e+00,
          2.0563e+00],
        [ 9.1569e-01, -1.2349e+00, -1.5394e+00, -7.0736e-01, -5.5272e-01,
         -1.5898e+00, -6.0082e-01,  1.3908e+00,  2.7930e-01,  6.9085e-01,
          5.1686e-01,  6.5842e-01,  6.8905e-01,  9.7911e-01,  5.6687e-01,
         -6.5162e-01],
        [ 1.5812e-01, -1.2457e+00,  1.9778e+00,  5.4664e-02,  8.2410e-01,
         -2.6325e+00,  7.3439e-01, -1.0063e-02,  1.9677e-01,  2.4709e-01,
          1.4928e+00, -2.9880e-01,  2.1503e+00,  1.8539e+00,  8.1897e-02,
          4.5690e-01],
        [ 4.1733e-01, -1.3703e-01,  3.1289e-01, -1.1011e+00, -1.3120e+00,
         -2.0392e+00, -1.1899e+00,  6.1899e-01,  1.2533e+00, -1.9775e+00,
          8.3711e-01,  4.2185e-01,  5.4436e-01,  1.0219e+00, -9.9984e-01,
         -1.1049e+00],
        [ 2.3511e-01, -2.3007e+00, -7.5956e-01, -7.1586e-01, -4.7162e-01,
          1.9671e-01,  6.2589e-02,  9.0480e-01, -6.4017e-01, -1.6957e-01,
          7.3696e-01, -9.2881e-01, -5.9447e-01,  3.0396e-01,  6.7777e-03,
         -4.9320e-01],
        [ 4.3563e-01, -2.3411e-01,  1.1254e+00,  1.9592e-01, -1.8889e+00,
         -2.6352e+00,  7.4621e-01,  1.7433e+00, -5.3961e-01,  1.2617e+00,
          1.7898e+00,  2.7057e-01,  1.8180e+00,  1.7901e+00, -8.2793e-01,
          1.5555e+00],
        [ 9.7916e-02, -2.3332e+00, -2.9267e+00, -5.0806e-01, -7.4166e-01,
          4.6953e-01,  7.0740e-01,  9.7310e-01,  2.7399e-01,  6.2945e-01,
          1.6237e+00,  1.4580e+00, -2.0365e-01,  1.5745e+00, -1.2565e+00,
          1.0269e-01],
        [ 1.8389e-01, -3.0685e-01,  7.5610e-01, -1.2899e+00,  9.0083e-01,
         -8.9824e-01,  1.1224e+00,  9.1232e-01,  4.6421e-01, -2.4713e-01,
          1.5397e+00,  7.5857e-01,  1.2211e-01,  1.3789e+00, -1.3621e+00,
          1.6946e+00],
        [ 8.4649e-01, -3.3510e-01, -3.8854e-03,  4.8532e-01, -1.7783e+00,
         -1.0863e+00, -6.1646e-02,  1.9066e+00, -2.7216e-01,  8.5870e-01,
          1.2407e+00, -7.2768e-02,  1.1854e+00,  2.8140e+00,  4.0069e-01,
          7.5213e-01],
        [ 2.7837e-01, -1.2072e+00,  9.8222e-01, -1.0471e-01, -1.0825e+00,
         -9.6411e-04,  1.4368e+00,  1.3269e+00, -8.2408e-01, -2.9107e+00,
          2.6670e+00,  1.2564e-01,  3.1355e+00,  1.3564e+00, -1.1322e+00,
          1.7455e+00],
        [ 5.5862e-01,  6.3158e-02, -1.1971e+00, -6.8839e-01,  3.9886e-01,
         -9.8823e-01,  1.4100e+00,  3.3176e-01, -1.6303e+00,  1.1427e+00,
          4.9966e-01,  6.6720e-01,  1.2786e+00,  1.3973e+00, -5.7325e-01,
         -1.5150e+00]], device='cuda:0', grad_fn=<AddBackward0>)
Traceback (most recent call last):
  File "main.py", line 13, in <module>
    train_model("model", 10, gen)
  File "D:\github\Crispy\beta-predict\ode_torch.py", line 171, in train_model
    loss.backward()
  File "C:\Python36\lib\site-packages\torch\tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\Python36\lib\site-packages\torch\autograd\__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "C:\Python36\lib\site-packages\torch\autograd\function.py", line 77, in apply
    return self._forward_cls.backward(self, *args)
  File "d:\github\crispy\torchdiffeq\torchdiffeq\_impl\adjoint.py", line 83, in backward
    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
  File "d:\github\crispy\torchdiffeq\torchdiffeq\_impl\odeint.py", line 72, in odeint
    solution = solver.integrate(t)
  File "d:\github\crispy\torchdiffeq\torchdiffeq\_impl\solvers.py", line 31, in integrate
    y = self.advance(t[i])
  File "d:\github\crispy\torchdiffeq\torchdiffeq\_impl\dopri5.py", line 90, in advance
    self.rk_state = self._adaptive_dopri5_step(self.rk_state)
  File "d:\github\crispy\torchdiffeq\torchdiffeq\_impl\dopri5.py", line 100, in _adaptive_dopri5_step
    assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
AssertionError: underflow in dt 0.0

Also, if I try adams instead, it just hangs.
Any idea what is going on?
Thanks in advance!
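The assertion that fails, t0 + dt > t0, is a floating-point check: it trips when the adaptive step dt becomes so small that adding it to t0 no longer changes t0 at the tensor's precision. The sketch below reproduces that check for float32 using only the standard library (struct rounds Python floats to single precision); it illustrates the arithmetic and is not torchdiffeq code. Around t0 = 60 the spacing between adjacent float32 values is 2^-18, roughly 3.8e-6, so any step below about half of that is rounded away.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (double) to the nearest IEEE-754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dt_underflows(t0: float, dt: float) -> bool:
    """Mimic `assert t0 + dt > t0` for float32 t0 and dt; True means the assert would fail."""
    t0_32, dt_32 = to_float32(t0), to_float32(dt)
    # Adding two float32 values in double precision and rounding once gives the float32 sum.
    return not (to_float32(t0_32 + dt_32) > t0_32)


for t0 in (0.0, 60.0, 120.0):
    for dt in (1e-3, 1e-6, 0.0):
        print(f"t0={t0:6.1f}  dt={dt:g}  underflow={dt_underflows(t0, dt)}")
```

An error that reports dt = 0.0, as above, means the step-size controller kept shrinking dt until it vanished entirely, which often points at dynamics the solver cannot resolve at the requested tolerances (for example, very stiff or exploding dynamics).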

Tests Fail (gradient + odeint)

Hello,

Thank you for your interesting work. I am eager to try out your code and experiment with neural ODEs.
However, when installing and testing your code, the following tests fail:

....F........................
======================================================================
FAIL: test_adams_adjoint_against_dopri5 (gradient_tests.TestCompareAdjointGradient)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/path/to/my/home/dir/torchdiffeq/tests/gradient_tests.py", line 134, in test_adams_adjoint_against_dopri5
    self.assertLess(max_abs(y0.grad - adj_y0_grad), 5e-2)
AssertionError: tensor(0.5438) not less than 0.05

======================================================================
FAIL: test_adams (odeint_tests.TestSolverError) (ode='linear')
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/path/to/my/home/dir/torchdiffeq/tests/odeint_tests.py", line 52, in test_adams
    self.assertLess(rel_error(sol, y), error_tol)
AssertionError: tensor(0.0005, grad_fn=<MaxBackward1>) not less than 0.0001

----------------------------------------------------------------------
Ran 30 tests in 19.674s

FAILED (failures=2)

I tested this in a fresh Python 3.6.3 environment (created with pyenv) on Ubuntu 16.04. I installed numpy and scipy (with the --no-binary option) before installing torch and your package. pip list returns:

Package     Version     Location                                    
----------- ----------- --------------------------------------------
numpy       1.16.1      
Pillow      5.4.1       
pip         19.0.2      
scipy       1.2.1       
setuptools  28.8.0      
six         1.12.0      
torch       1.0.1.post2 
torchdiffeq 0.0.1       /path/to/my/home/dir/torchdiffeq
torchvision 0.2.1       

Do you have any idea what could be the issue there?

Thank you,
Best,
Max

Possible MacOSX visualization crash

When I run ode_demo.py with --viz on macOS, I get the following crash.

(pytorch) ➜  examples git:(master) python ode_demo.py --viz
2019-04-28 12:17:30.404 python[23732:2807027] -[NSApplication _setup:]: unrecognized selector sent to instance 0x7ffb17c3fd10
2019-04-28 12:17:30.406 python[23732:2807027] *** Terminating app due to uncaught exception 'NSInvalidArgumentException', reason: '-[NSApplication _setup:]: unrecognized selector sent to instance 0x7ffb17c3fd10'
*** First throw call stack:
(
	0   CoreFoundation                      0x00007fff4f8cf68b __exceptionPreprocess + 171
	1   libobjc.A.dylib                     0x00007fff76b60c76 objc_exception_throw + 48
	2   CoreFoundation                      0x00007fff4f9681e4 -[NSObject(NSObject) doesNotRecognizeSelector:] + 132
	3   CoreFoundation                      0x00007fff4f845b50 ___forwarding___ + 1456
	4   CoreFoundation                      0x00007fff4f845518 _CF_forwarding_prep_0 + 120
	5   libtk8.6.dylib                      0x0000000121a2e31d TkpInit + 413
	6   libtk8.6.dylib                      0x000000012198617e Initialize + 2622
	7   _tkinter.cpython-37m-darwin.so      0x000000011dfe7a0f _tkinter_create + 1183
	8   python                              0x0000000108639116 _PyMethodDef_RawFastCallKeywords + 230
	9   python                              0x0000000108775e42 call_function + 306
	10  python                              0x0000000108773aec _PyEval_EvalFrameDefault + 46092
	11  python                              0x000000010876749e _PyEval_EvalCodeWithName + 414
	12  python                              0x0000000108637de7 _PyFunction_FastCallDict + 231
	13  python                              0x00000001086ba381 slot_tp_init + 193
	14  python                              0x00000001086c4361 type_call + 241
	15  python                              0x0000000108638ae3 _PyObject_FastCallKeywords + 179
	16  python                              0x0000000108775ed5 call_function + 453
	17  python                              0x0000000108773be0 _PyEval_EvalFrameDefault + 46336
	18  python                              0x00000001086388d5 function_code_fastcall + 117
	19  python                              0x0000000108775dc7 call_function + 183
	20  python                              0x0000000108773aec _PyEval_EvalFrameDefault + 46092
	21  python                              0x000000010876749e _PyEval_EvalCodeWithName + 414
	22  python                              0x0000000108637de7 _PyFunction_FastCallDict + 231
	23  python                              0x000000010863bce2 method_call + 130
	24  python                              0x0000000108639752 PyObject_Call + 130
	25  python                              0x0000000108773d58 _PyEval_EvalFrameDefault + 46712
	26  python                              0x000000010876749e _PyEval_EvalCodeWithName + 414
	27  python                              0x0000000108638fe3 _PyFunction_FastCallKeywords + 195
	28  python                              0x0000000108775dc7 call_function + 183
	29  python                              0x0000000108773be0 _PyEval_EvalFrameDefault + 46336
	30  python                              0x000000010876749e _PyEval_EvalCodeWithName + 414
	31  python                              0x00000001087ca9a0 PyRun_FileExFlags + 256
	32  python                              0x00000001087c9e17 PyRun_SimpleFileExFlags + 391
	33  python                              0x00000001087f7d3f pymain_main + 9663
	34  python                              0x000000010860b66d main + 125
	35  libdyld.dylib                       0x00007fff7777a015 start + 1
	36  ???                                 0x0000000000000003 0x0 + 3

Following advice I found elsewhere, I inserted the following above the if args.viz line (at approximately line 57):

from sys import platform as sys_pf
if sys_pf == 'darwin':
    import matplotlib
    matplotlib.use("TkAgg")

With that change, it ran fine.

Why num_blocks?

Thanks for releasing the code; I've learned a lot from its design. I'm just confused by the num_blocks argument. Could the authors explain why we need multiple CNF blocks rather than repeating at the ODEnet block level?

Other examples

Thanks for the great work. This isn't really an issue, but I was wondering if you could also share the code for the generative latent function time-series model?
Best
Vahid

Problem running `ODEnet for MNIST`

Simply running python odenet_mnist.py --network odenet causes the following error:

Traceback (most recent call last):
  File "odenet_mnist.py", line 326, in <module>
    logits = model(x)
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "odenet_mnist.py", line 64, in forward
    out = self.relu(self.norm1(x))
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/nn/modules/normalization.py", line 217, in forward
    input, self.num_groups, self.weight, self.bias, self.eps)
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py", line 1323, in group_norm
    torch.backends.cudnn.enabled)
RuntimeError: Expected number of channels in input to be divisible by num_groups, but got input of shape [128, 1, 28, 28] and num_groups=32

Setting --downsampling-method conv explicitly solves this problem.
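The first error is PyTorch's GroupNorm shape check: the number of input channels must be divisible by num_groups. The script's norm(dim) helper builds nn.GroupNorm(min(32, dim), dim), so a layer created for 64 channels uses 32 groups and cannot accept a raw 1-channel MNIST batch of shape [128, 1, 28, 28]. The stdlib-only sketch below mirrors that arithmetic; the function names are hypothetical and it does not call PyTorch.

```python
def norm_num_groups(dim: int, max_groups: int = 32) -> int:
    """Group count chosen by norm(dim) in odenet_mnist.py: nn.GroupNorm(min(32, dim), dim)."""
    return min(max_groups, dim)


def groupnorm_accepts(input_shape: tuple, num_groups: int) -> bool:
    """GroupNorm needs the channel dimension (index 1 for NCHW) to be divisible by num_groups."""
    channels = input_shape[1]
    return channels % num_groups == 0


groups = norm_num_groups(64)                           # a norm layer built for 64 channels
print(groupnorm_accepts((128, 1, 28, 28), groups))     # raw MNIST images: rejected
print(groupnorm_accepts((128, 64, 26, 26), groups))    # after Conv2d(1, 64, 3, 1): accepted
```

In other words, the traceback suggests that a norm layer sized for 64 channels received the images before any convolution had expanded them to 64 channels.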

However, running python odenet_mnist.py --network odenet --adjoint True --downsampling-method conv produces the following error:

Traceback (most recent call last):
  File "odenet_mnist.py", line 333, in <module>
    loss.backward()
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/home/binjie/ENV/localENV/anaconda3/lib/python3.6/site-packages/torch/autograd/function.py", line 76, in apply
    return self._forward_cls.backward(self, *args)
  File "/mfs/binjie/server_conf/ENV/shareENV/CONF/projects/torchdiffeq/torchdiffeq/_impl/adjoint.py", line 74, in backward
    adj_time = adj_time - dLd_cur_t
RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #3 'other'

It seems that something goes wrong with tensor type conversion (CPU vs. CUDA) inside torchdiffeq.

Publish to PyPI?

It would be convenient if this package were published to PyPI, even before any plans to merge it into PyTorch.

odenet_mnist seems stuck and doesn't finish overnight

When I run it, it prints out the following:

import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--tol', type=float, default=1e-3)
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--nepochs', type=int, default=160)
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

parser.add_argument('--save', type=str, default='./experiment1')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()

if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    return nn.GroupNorm(min(32, dim), dim)


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut


class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = args.lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn


def one_hot(x, K):
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)

        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    logger = logging.getLogger()
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    for f in package_files:
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger


if __name__ == '__main__':

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    is_odenet = args.network == 'odenet'

    if args.downsampling_method == 'conv':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    elif args.downsampling_method == 'res':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]

    feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    criterion = nn.CrossEntropyLoss().to(device)

    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size
    )

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = data_gen.__next__()
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader)
                val_acc = accuracy(model, test_loader)
                if val_acc > best_acc:
                    torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    best_acc = val_acc
                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "Train Acc {:.4f} | Test Acc {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc
                    )
                )

Namespace(adjoint=False, batch_size=128, data_aug=True, debug=False, downsampling_method='conv', gpu=0, lr=0.1, nepochs=160, network='odenet', save='./experiment1', test_batch_size=1000, tol=0.001)
Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): GroupNorm(32, 64, eps=1e-05, affine=True)
  (2): ReLU(inplace)
  (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ODEBlock(
    (odefunc): ODEfunc(
      (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace)
      (conv1): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
      (conv2): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
    )
  )
  (8): GroupNorm(32, 64, eps=1e-05, affine=True)
  (9): ReLU(inplace)
  (10): AdaptiveAvgPool2d(output_size=(1, 1))
  (11): Flatten()
  (12): Linear(in_features=64, out_features=10, bias=True)
)
Number of parameters: 208266
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!

It then uses 100% CPU for hours; when I checked in the morning, nothing seemed to have happened.

It never finishes. Pressing Ctrl+C stops it at this point:

  File "~/torchdiffeq/torchdiffeq/_impl/dopri5.py", line 103, in _adaptive_dopri5_step
    y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU)
  File "~/torchdiffeq/torchdiffeq/_impl/rk_common.py", line 52, in _runge_kutta_step
    tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi)))
  File "~/torchdiffeq/torchdiffeq/_impl/misc.py", line 179, in <lambda>
    func = lambda t, y: (_base_nontuple_func_(t, y[0]),)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "./examples/odenet_mnist.py", line 108, in forward
    out = self.conv1(t, out)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "./examples/odenet_mnist.py", line 89, in forward
    return self._layer(ttx)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/conv.py", line 320, in forward
    self.padding, self.dilation, self.groups)
KeyboardInterrupt
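One way to tell whether an adaptive solver is still making progress or is stuck taking ever-smaller steps is to count function evaluations, remember the last time t seen, and stop after a budget. The example script already keeps an `nfe` counter on the ODE block, so this is the same idea with a cap added. A minimal pure-Python sketch, assuming a dynamics callable `func(t, y)`; `CountingFunc`, `max_nfe` and `NFEBudgetExceeded` are hypothetical names, not torchdiffeq options:

```python
class NFEBudgetExceeded(RuntimeError):
    """Raised when the dynamics have been evaluated more often than allowed."""


class CountingFunc:
    """Wrap a dynamics function f(t, y): count calls and remember the last t seen."""

    def __init__(self, func, max_nfe):
        self.func = func
        self.max_nfe = max_nfe
        self.nfe = 0
        self.last_t = None

    def __call__(self, t, y):
        self.nfe += 1
        self.last_t = t
        if self.nfe > self.max_nfe:
            raise NFEBudgetExceeded(
                "gave up after %d evaluations; last t = %r" % (self.max_nfe, t))
        return self.func(t, y)


def euler(func, y0, t0, t1, n_steps):
    """Fixed-step explicit Euler, used here only to exercise the wrapper."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = y + h * func(t, y)
        t += h
    return y
```

If `last_t` stops advancing while `nfe` keeps growing, the step size has collapsed rather than the model simply being slow.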

Training Odenet on Faces95

Can you help me load the Faces95 dataset and train an ODE-Net on it? I want to use it for a face recognition task.

OdeintAdjointMethod.apply throws RuntimeError given nan values.

The error below occurs during loss.backward() in code that works fine with 2D data and a 2D model and has been adapted to 3D.
Let me know if you have ideas on what might have caused this error and how to debug it.

  File "/torchdiffeq/torchdiffeq/_impl/adjoint.py", line 129, in odeint_adjoint
    ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options)

RuntimeError: Function 'torch::autograd::GraphRoot' returned nan values in its 0th output.
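Because the NaN is only reported once it reaches the backward pass, it helps to fail fast at the first non-finite value the dynamics produce and record the time t at which it happened. A minimal pure-Python sketch, assuming the dynamics are a callable `func(t, y)` returning a list of floats; `check_finite` is a hypothetical helper, and for tensors a check such as `torch.isfinite(dy).all()` would play the role of `math.isfinite`:

```python
import math


def check_finite(func):
    """Wrap dynamics func(t, y) so the first NaN/inf in dy/dt raises, reporting t."""
    def wrapped(t, y):
        dy = func(t, y)
        bad = [i for i, v in enumerate(dy) if not math.isfinite(v)]
        if bad:
            raise FloatingPointError(
                "non-finite dy/dt at t=%r in components %s" % (t, bad))
        return dy
    return wrapped


# Example dynamics that break down for t > 0.5.
def dynamics(t, y):
    return [-y[0] if t <= 0.5 else float("nan")]


safe_dynamics = check_finite(dynamics)
```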

how to avoid underflow in dt in dopri5?

After some iterations, I get the "underflow in dt" exception from dopri5.py. What should I look for in order to debug this problem?

Thanks in advance.
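For context, an adaptive solver rejects a step whenever the estimated error exceeds the tolerance and retries with a smaller dt. If the error estimate never gets below tolerance (for example because the dynamics blow up, become very stiff, or return NaN), dt keeps shrinking until it underflows. A minimal pure-Python sketch of that mechanism, using an embedded Heun/Euler pair (orders 2 and 1) instead of Dormand-Prince; the constants `safety`, `ifactor`, `dfactor` and `min_dt` are illustrative assumptions:

```python
import math


def adaptive_heun_euler(f, y0, t0, t1, rtol=1e-6, atol=1e-6, dt=0.1,
                        safety=0.9, ifactor=10.0, dfactor=0.2, min_dt=1e-12):
    """Integrate scalar y' = f(t, y) from t0 to t1 with step-size control.

    Uses an embedded Heun (order 2) / Euler (order 1) pair for the error
    estimate. Returns (y(t1), accepted_steps, rejected_steps) and raises
    RuntimeError once the proposed step drops below min_dt.
    """
    t, y = t0, y0
    accepted = rejected = 0
    while t < t1:
        if dt < min_dt:
            raise RuntimeError("underflow in dt %r at t=%r" % (dt, t))
        last = dt >= t1 - t
        h = t1 - t if last else dt
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_high = y + h * (k1 + k2) / 2      # Heun, order 2
        y_low = y + h * k1                  # Euler, order 1
        tol = atol + rtol * max(abs(y), abs(y_high))
        error_ratio = abs(y_high - y_low) / tol
        if error_ratio <= 1.0:              # accept the step
            t = t1 if last else t + h
            y = y_high
            accepted += 1
        else:                               # reject: retry from the same t
            rejected += 1
        # factor = safety * error_ratio^(-1/2), clipped to [dfactor, ifactor];
        # max() returns dfactor when error_ratio is NaN, so NaNs only shrink dt.
        if error_ratio == 0.0:
            factor = ifactor
        else:
            factor = min(ifactor, max(dfactor, safety * error_ratio ** -0.5))
        dt = h * factor
    return y, accepted, rejected


y_end, n_acc, n_rej = adaptive_heun_euler(lambda t, y: -y, 1.0, 0.0, 1.0)
err = abs(y_end - math.exp(-1.0))
```

So the underflow is usually a symptom: it is worth checking whether the dynamics produce NaN/inf or grow very large around the time where the solver stalls.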

Can it handle 2nd-order ODEs?

Hi, the code looks great! I'm wondering whether it can handle cases where the differential equation contains second derivatives, for example:

 d^2 y / dt^2 = g(t, y)
 dy/dt = f(t, y)
 y'(t_0) = yp_0
 y(t_0) = y_0

This would be very useful for many different applications.
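A second-order equation y'' = g(t, y) can be handled by any first-order solver by augmenting the state with the velocity v = y', which gives the system y' = v, v' = g(t, y) with initial state (y_0, yp_0). The solver code quoted elsewhere on this page works on tuples of states, so passing (y, v) as a tuple is one option; concatenating them into one tensor is another. A minimal pure-Python sketch with a classic RK4 stepper standing in for the library (names are my own):

```python
import math


def rk4_system(f, t0, state, t1, n_steps):
    """Classic fourth-order Runge-Kutta for a system given as a list of floats."""
    h = (t1 - t0) / n_steps
    t, s = t0, list(state)
    for _ in range(n_steps):
        k1 = f(t, s)
        k2 = f(t + h / 2, [x + h / 2 * k for x, k in zip(s, k1)])
        k3 = f(t + h / 2, [x + h / 2 * k for x, k in zip(s, k2)])
        k4 = f(t + h, [x + h * k for x, k in zip(s, k3)])
        s = [x + h / 6 * (a + 2 * b + 2 * c + d)
             for x, a, b, c, d in zip(s, k1, k2, k3, k4)]
        t += h
    return s


def second_order_as_first_order(g):
    """Turn y'' = g(t, y) into the first-order system [y, v]' = [v, g(t, y)]."""
    def f(t, s):
        y, v = s
        return [v, g(t, y)]
    return f


# Example: y'' = -y with y(0) = 0, y'(0) = 1 has the solution y = sin(t).
f = second_order_as_first_order(lambda t, y: -y)
y_end, v_end = rk4_system(f, 0.0, [0.0, 1.0], math.pi / 2, 100)
```

If g also depends on y', the same trick works with `g(t, y, v)` inside `f`.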

CNF Implementation details

Hi, I'm trying to experiment with Neural ODEs and reimplement some of them in TensorFlow.
I was able to implement a basic solver that works on the spiral problem from your examples.
However, I got stuck on the CNF implementation :/

Here is my current version of planar flow implemented as a Keras Model:

import tensorflow as tf

# Assumed initializer: w_initializer was not defined in the original snippet.
w_initializer = tf.random_normal_initializer(stddev=0.1)

class PlanarFlow(tf.keras.Model):
    def __init__(self, dim, scale=1.0, bias_scale=1.0, w_scale=1.0):
        super().__init__()
        self.weight = tf.Variable(w_scale * w_initializer([1, dim]), name='weight')
        self.bias = tf.Variable(bias_scale * w_initializer([1, 1]), name='bias')
        self.scale = tf.Variable(scale * w_initializer([1, dim]), name='scale')
        self.activation = tf.nn.tanh

    def linear(self, z):
        return tf.reduce_sum(z * self.weight, axis=-1, keepdims=True) + self.bias

    def call(self, inputs, **kwargs):
        t, z, logdet = inputs

        with tf.GradientTape() as g:
            g.watch(z)
            logits = self.linear(z)
            hfunc = self.activation(logits)

        new_z = self.scale * hfunc  # (batch_size, 2)
        # compute gradient: dh(z) / dz
        gradients = g.gradient(
            target=hfunc,
            sources=z,
        )
        # trace term: -u^T dh/dz, shape (batch_size, 1)
        new_logdet = - tf.matmul(gradients, tf.transpose(self.scale))
        # return dynamics gradients for z and log(p(z))
        return new_z, logdet + new_logdet

This implementation is missing the gating mechanism (a neural network, as mentioned in the paper), whose details are not described in the paper. How was the gating mechanism implemented in your case?

Secondly, please correct me if I have misunderstood the algorithm. Once I have a correct implementation of the planar flow, I need to combine several of them (Eq. 10), like this:

class MultipleFlow(tf.keras.Model):
    
    def __init__(self, num_flows, flow_factory=lambda: PlanarFlow(2)):
        super().__init__()
        self.flows = [flow_factory() for _ in range(num_flows)]
        
    def call(self, inputs, **kwargs):
        t, z, logdet = inputs
        for flow in self.flows:
            z_k, logdet_k = flow(inputs)
            z = z + z_k
            logdet = logdet + logdet_k            
        return t, z, logdet

# create CNF with M = 32 ???
cnf = MultipleFlow(num_flows=32)

Finally, I want to maximize the probability under the energy function from Fig. 4:

# pseudo code:
z_samples = tf.random_normal([512, 2])
# integrate dynamics
z_output, z_logdet = odeint(cnf, [z_samples, 0.0], tstart=1, tend=0, num_steps=100) 
# potential_energy => p(z_output) = exp(- U(z_output)) => - log(p(z_output)) = U(z_output)
loss = - potential_energy(z_output) - z_logdet
# maximize this 
loss = tf.reduce_mean(loss)

Is this approach correct?
Thank you in advance !
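For reference, Eq. 10 of the paper (as I read it) gives the dynamics of a planar CNF with M hidden units as dz/dt = sum_n sigma_n(t) u_n h(w_n^T z + b_n) and d log p(z)/dt = -sum_n sigma_n(t) u_n^T dh/dz. Both are time derivatives that the solver integrates, not updated states. A minimal pure-Python sketch with h = tanh and a hypothetical gate sigma_n(t) = sigmoid(a_n t + c_n) standing in for the gating network, which the paper does not specify in detail:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def planar_cnf_dynamics(t, z, units):
    """Return (dz/dt, dlogp/dt) for a planar CNF with gated hidden units.

    units: list of dicts with keys u, w (lists of len(z)), b (float) and
    a, c (floats for the hypothetical gate sigma(t) = sigmoid(a * t + c)).
    Uses h = tanh, so dh/dz = (1 - tanh(w.z + b)^2) * w.
    """
    dz = [0.0] * len(z)
    dlogp = 0.0
    for unit in units:
        gate = sigmoid(unit["a"] * t + unit["c"])
        pre = sum(wi * zi for wi, zi in zip(unit["w"], z)) + unit["b"]
        h = math.tanh(pre)
        dh = 1.0 - h * h                      # tanh'(pre)
        for i, ui in enumerate(unit["u"]):
            dz[i] += gate * ui * h
        u_dot_w = sum(ui * wi for ui, wi in zip(unit["u"], unit["w"]))
        dlogp -= gate * u_dot_w * dh          # -sigma * u^T dh/dz
    return dz, dlogp


units = [
    {"u": [0.5, -1.0], "w": [1.0, 2.0], "b": 0.1, "a": 1.0, "c": 0.0},
    {"u": [0.3, 0.7], "w": [-0.5, 0.4], "b": -0.2, "a": -2.0, "c": 0.5},
]
dz, dlogp = planar_cnf_dynamics(0.3, [0.2, -0.1], units)
```

Integrating (z, log p) jointly from t0 to t1 then gives log p(z(t1)) = log p(z(t0)) + integral of dlogp/dt.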

Training the network? Conceptual doubt.

Hi,

I have been reading the paper and trying to use the code from the examples for my own problem. I was also watching the video where Prof. David K. Duvenaud (at the NeurIPS conference) talks about why backpropagating through the operations of the ODE solver is bad. But in the examples, I see that loss.backward() is called, which, if I am not mistaken, computes the gradients the normal way, as for any nn.Module. Is the backward method overridden anywhere specifically for ODEBlock?
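For context: with plain odeint, loss.backward() differentiates through the solver's internal operations like any other module, while odeint_adjoint computes gradients in a custom autograd Function (OdeintAdjointMethod, visible in the tracebacks on this page) whose backward pass solves a second, augmented ODE backwards in time. A minimal pure-Python sketch of that adjoint computation for the scalar ODE dy/dt = theta * y with loss L = y(T), using fixed-step RK4 and names of my own choosing:

```python
import math


def rk4(f, state, t0, t1, n_steps):
    """Fixed-step RK4 for a list-valued state; t1 < t0 integrates backwards."""
    h = (t1 - t0) / n_steps
    t, s = t0, list(state)
    for _ in range(n_steps):
        k1 = f(t, s)
        k2 = f(t + h / 2, [x + h / 2 * k for x, k in zip(s, k1)])
        k3 = f(t + h / 2, [x + h / 2 * k for x, k in zip(s, k2)])
        k4 = f(t + h, [x + h * k for x, k in zip(s, k3)])
        s = [x + h / 6 * (a + 2 * b + 2 * c + d)
             for x, a, b, c, d in zip(s, k1, k2, k3, k4)]
        t += h
    return s


def adjoint_gradients(theta, y0, T, n_steps=100):
    """Gradients of L = y(T) for dy/dt = theta * y via the adjoint method."""
    # Forward pass: only the final state is kept, no intermediate activations.
    (yT,) = rk4(lambda t, s: [theta * s[0]], [y0], 0.0, T, n_steps)

    # Backward pass: integrate [y, a, g] from T to 0 with
    #   dy/dt = f = theta * y
    #   da/dt = -a * df/dy     = -a * theta
    #   dg/dt = -a * df/dtheta = -a * y
    def augmented(t, s):
        y, a, _ = s
        return [theta * y, -a * theta, -a * y]

    # a(T) = dL/dy(T) = 1 and g(T) = 0; at t = 0, a = dL/dy0 and g = dL/dtheta.
    _, dL_dy0, dL_dtheta = rk4(augmented, [yT, 1.0, 0.0], T, 0.0, n_steps)
    return yT, dL_dy0, dL_dtheta
```

The backward solve reconstructs y(t) as it goes, which is where the memory savings come from, at the cost of extra function evaluations (reported as NFE-B in the example scripts).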

More ODEBlocks

If I use more ODEBlocks in odenet_mnist.py, like

feature_layers = [ODEBlock(ODEfunc(64)), ODEBlock(ODEfunc(64))]

Is there anything I should change in odenet_mnist.py, for example in the NFE bookkeeping below?

        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0
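As written, the bookkeeping only reads and resets the first block. One option is to sum and reset the counters of every block that has one. A minimal sketch, assuming each ODE block exposes a settable `nfe` attribute as in the snippet; `pop_total_nfe` is a hypothetical helper, and `SimpleNamespace` objects stand in for real ODEBlock modules:

```python
from types import SimpleNamespace


def pop_total_nfe(blocks):
    """Return the summed NFE of all blocks with an `nfe` counter, resetting each to 0."""
    ode_blocks = [b for b in blocks if hasattr(b, "nfe")]
    total = sum(b.nfe for b in ode_blocks)
    for b in ode_blocks:
        b.nfe = 0
    return total


feature_layers = [SimpleNamespace(nfe=26), SimpleNamespace(nfe=32)]
nfe_forward = pop_total_nfe(feature_layers)
```

In the training loop, `nfe_forward = pop_total_nfe(feature_layers)` after the forward pass and `nfe_backward = pop_total_nfe(feature_layers)` after `loss.backward()` would replace the two `feature_layers[0]` blocks.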

Receptive field in CNN-based architecture, and more about the usage of conv blocks

Thanks for the very interesting paper and the implementation. Fabulous work!

I just have a question about the use of CNN blocks in the model. Please correct me if I'm wrong: it seems that a neural ODE with a single conv block will have an infinite receptive field even for a very short integration time. For discrete ODE-based designs such as Euler Net and Runge-Kutta Net, the receptive fields are all finite and depend on the number of layers/blocks. If so, (1) it seems that a conv block degenerates into an FC layer applied to the entire (flattened) input, since the concept of a "receptive field" no longer holds in this case, and (2) it seems that there's no need to use a larger/deeper model for a neural ODE if the goal is to cover a large, possibly global receptive field: a single block (perhaps together with a HyperNet) should be enough for everything. I'm not sure whether this assumption still holds for larger models such as ResNet50/ResNet101 on larger datasets, but my intuition is that a single-conv-block ODE might struggle to match their performance (e.g. see the comment in #32 about performance on CIFAR10). So I'm also wondering whether you have run any numerical experiments on larger datasets comparing neural ODEs with larger, especially deeper, models.

Thanks in advance!
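A small 1D illustration of the discrete case: if each evaluation of f applies one 3-tap convolution (a hypothetical all-positive kernel, so no cancellations hide the growth), every explicit Euler step widens the region an input cell can influence by one cell on each side, so N steps give a receptive field of 1 + 2N cells. For a multi-stage solver the stages are chained as well, so the growth is roughly per function evaluation rather than per step. A minimal pure-Python sketch:

```python
def conv3(x, kernel=(1.0, 1.0, 1.0)):
    """1D 'same' convolution with zero padding and a 3-tap kernel."""
    padded = [0.0] + list(x) + [0.0]
    return [sum(k * padded[i + j] for j, k in enumerate(kernel)) for i in range(len(x))]


def euler_steps(x, n_steps, h=0.1):
    """Apply n explicit Euler steps of dx/dt = conv3(x)."""
    for _ in range(n_steps):
        dx = conv3(x)
        x = [xi + h * di for xi, di in zip(x, dx)]
    return x


def support_width(x):
    """Number of cells between the first and last nonzero entries, inclusive."""
    nz = [i for i, v in enumerate(x) if v != 0.0]
    return nz[-1] - nz[0] + 1 if nz else 0


impulse = [0.0] * 41
impulse[20] = 1.0
widths = [support_width(euler_steps(impulse, n)) for n in range(10)]
```

For this linear, positive-kernel example the exact flow x(t) = exp(tK) x(0) is nonzero at every cell for any t > 0, which matches the intuition in the question, whereas the discretized forward pass only sees as far as its function evaluations reach.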

possible issue in rk_common.py

First of all, thank you for your work and sharing your code. Very eye-opening stuff!
I was reading through the flow of the code and encountered a little bump when I arrived at "rk_common.py".

In line 52, you appear to be reorganizing the contents of the zip object "zip(k, func(ti, yi))" into tuple form, but this tuple isn't assigned to anything. Nonetheless, it appears this line is necessary for the code to run, as "ode_demo.py" fails to run without this line. So my question is what is actually happening in line 52.

My guess is that the tuple declared in line 52 ought to be assigned to "k".
In line 59, "f1" is defined to be a tuple consisting of the last entry of every element k_ in k. However, the only time "k" is modified up to this point is when "k" is initially defined in line 48, meaning that line 59 appears to be merely copying the contents of "k" and hence the contents of "f0" into "f1", which is supposed to be the derivative of the state at "t1".
Assigning the tuple declared in line 52 to "k" would resolve this, but I'm not sure as to how or whether this is done.

k = tuple(map(lambda x: [x], f0))
for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
    ti = t0 + alpha_i * dt
    yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
    tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi)))

if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
    # This property (true for Dormand-Prince) lets us save a few FLOPs.
    yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))

y1 = yi
f1 = tuple(k_[-1] for k_ in k)
y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
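What line 52 relies on is a Python detail: `k` is a tuple of lists, and `k_.append(f_)` mutates those lists in place. A generator expression does nothing until something consumes it; wrapping it in `tuple(...)` forces that, and the resulting tuple of `None`s is simply discarded. So `k` is updated even though nothing is assigned to it, which is why `f1 = tuple(k_[-1] for k_ in k)` picks up the last stage rather than `f0`. A stand-alone demonstration with made-up values:

```python
f0 = (1.0, 2.0)                       # derivatives of two state tensors at t0
k = tuple(map(lambda x: [x], f0))     # k == ([1.0], [2.0])

new_stage = (10.0, 20.0)              # pretend output of func(ti, yi)
lazy = (k_.append(f_) for k_, f_ in zip(k, new_stage))
untouched = [list(k_) for k_ in k]    # generator not consumed yet: no appends

result = tuple(lazy)                  # consuming it runs each append
f1 = tuple(k_[-1] for k_ in k)
```

A plain loop, `for k_, f_ in zip(k, func(ti, yi)): k_.append(f_)`, would do the same thing more explicitly.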

argument '--downsampling-method' in odenet_mnist.py?

Is this a typo?
Line 16 has a hyphen in the argument name:
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])

Whereas line 287 has an underscore in the argument name:
if args.downsampling_method == 'conv':
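This is standard argparse behavior rather than a typo: for an optional argument, the attribute name (`dest`) is derived from the first long option string by stripping the leading dashes and converting the remaining hyphens to underscores. A quick check:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])

args = parser.parse_args(['--downsampling-method', 'res'])
defaults = parser.parse_args([])
```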

In some cases `odeint` works on GPU but fails on CPU

Hi!
I've come across an example that runs on the GPU but doesn't run on the CPU.
The cause seems to be x.norm() returning NaN on the CPU but the correct result on the GPU.
(https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/misc.py#L74)
This happens when the initial step size is selected at
https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/misc.py#L126

I don't have time to follow up on this at the moment, so I'm leaving it here for now.
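For context, initial-step selection in adaptive solvers usually follows the heuristic of Hairer, Norsett and Wanner: compare the scaled size of y0 with the scaled size of f(t0, y0), take a trial Euler step, and estimate the second derivative. I believe misc.py follows this scheme, but I haven't compared it line by line. Every quantity is derived from the norm, so a single NaN from the norm propagates into the step size. A minimal pure-Python sketch, assuming an RMS norm:

```python
import math


def rms_norm(v):
    """Root-mean-square norm of a list of floats."""
    return math.sqrt(sum(x * x for x in v) / len(v))


def select_initial_step(f, t0, y0, order, rtol, atol):
    """Starting step size heuristic (Hairer, Norsett & Wanner)."""
    f0 = f(t0, y0)
    scale = [atol + abs(y) * rtol for y in y0]
    d0 = rms_norm([y / s for y, s in zip(y0, scale)])
    d1 = rms_norm([g / s for g, s in zip(f0, scale)])
    # NaN < 1e-5 is False, so a NaN norm falls through to the division below.
    h0 = 1e-6 if (d0 < 1e-5 or d1 < 1e-5) else 0.01 * d0 / d1
    y1 = [y + h0 * g for y, g in zip(y0, f0)]
    f1 = f(t0 + h0, y1)
    d2 = rms_norm([(a - b) / s for a, b, s in zip(f1, f0, scale)]) / h0
    if max(d1, d2) <= 1e-15:
        h1 = max(1e-6, h0 * 1e-3)
    else:
        h1 = (0.01 / max(d1, d2)) ** (1.0 / (order + 1))
    return min(100 * h0, h1)
```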

Syntax error when running odenet_mnist.py

I was trying to run
python odenet_mnist.py --network odenet --adjoint True
and had the following error message:
File "odenet_mnist.py", line 307 model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device) ^ SyntaxError: invalid syntax
I tried both Python 2 and Python 3, and neither worked.
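Several `*` unpackings in a single call, as on line 307, are only valid syntax from Python 3.5 onward (PEP 448), so this error suggests the script ran under Python 2 or an older Python 3; checking the interpreter version would confirm it. On an older interpreter, concatenating the lists first and unpacking once is equivalent. A quick check with a stand-in function (`sequential` is hypothetical, in place of `nn.Sequential`):

```python
def sequential(*layers):
    """Stand-in for nn.Sequential: just records the layers it was given."""
    return list(layers)


downsampling_layers = ['conv_a', 'conv_b']
feature_layers = ['odeblock']
fc_layers = ['norm', 'relu', 'pool', 'flatten', 'linear']

# PEP 448 form used in odenet_mnist.py: a SyntaxError before Python 3.5.
model_pep448 = sequential(*downsampling_layers, *feature_layers, *fc_layers)

# Works on any Python version: concatenate first, unpack once.
model_compat = sequential(*(downsampling_layers + feature_layers + fc_layers))
```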

NotImplementedError error when using odeint_adjoint

I've attached a file with my code.

question.txt

If I use odeint, everything is OK, but with odeint_adjoint I get the following error:

C:\Users\mazhenir\AppData\Local\Continuum\anaconda3\python.exe D:/Projects/aira-gitlab/interaction_learning_potential/forum_question/question.py
Traceback (most recent call last):
  File "/question.py", line 80, in <module>
    res = model(Z)
  File "\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/question.py", line 58, in forward
    res = odeint_adjoint(self.fun, Z0, self.t)
  File "\torchdiffeq\torchdiffeq\_impl\adjoint.py", line 129, in odeint_adjoint
    ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options)
  File "\torchdiffeq\torchdiffeq\_impl\adjoint.py", line 18, in forward
    ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
  File "\torchdiffeq\torchdiffeq\_impl\odeint.py", line 72, in odeint
    solution = solver.integrate(t)
  File "\torchdiffeq\torchdiffeq\_impl\solvers.py", line 29, in integrate
    self.before_integrate(t)
  File "\torchdiffeq\torchdiffeq\_impl\dopri5.py", line 78, in before_integrate
    f0 = self.func(t[0].type_as(self.y0[0]), self.y0)
  File "\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "\torchdiffeq\torchdiffeq\_impl\adjoint.py", line 122, in forward
    return (self.base_func(t, y[0]),)
  File "\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 85, in forward
    raise NotImplementedError
NotImplementedError

Process finished with exit code 1

Any idea?

examples/ode_demo.py throws SyntaxError

Traceback (most recent call last):
  File "examples/ode_demo.py", line 25, in <module>
    from torchdiffeq import odeint
  File "/Users/jarednielsen/Documents/papers/torchdiffeq/torchdiffeq/__init__.py", line 1, in <module>
    from ._impl import odeint
  File "/Users/jarednielsen/Documents/papers/torchdiffeq/torchdiffeq/_impl/__init__.py", line 1, in <module>
    from .odeint import odeint
  File "/Users/jarednielsen/Documents/papers/torchdiffeq/torchdiffeq/_impl/odeint.py", line 4, in <module>
    from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
  File "/Users/jarednielsen/Documents/papers/torchdiffeq/torchdiffeq/_impl/fixed_adams.py", line 198
    print('Warning: Functional iteration did not converge. Solution may be incorrect.', file=sys.stderr)
                                                                                            ^
SyntaxError: invalid syntax

Any solutions?

ODEsolve in BP

According to Algorithm 1 in the paper, backpropagation is done with an ODE solve, but in the code's logger the value of NFE-B is always zero, and I cannot find the relevant code either. Any suggestions for better understanding the paper? Thanks a lot.

Possible fault in your evaluation process

Dear writers,

Thanks for your work, it is very interesting and the code is also very well written.

It seems to me that there is an issue with your performance evaluation process.

In the paper you claim that, for the problem of recognizing handwritten digits, replacing a 6-layer ResNet module with a single ODE layer yields similar results. The paper therefore seems to suggest that a single ODE layer is equivalent to a 6-layer ResNet module in some non-degenerate scenarios.

In this specific architecture, however, replacing the 6-layer ResNet module with a single ResNet layer yields the same result. Since the ODE layer itself is based on a single-layer ResNet architecture, your results seem to fail to demonstrate any gain from using an ODE layer in this context.

Thanks,

Yuval Frommer

MNIST: ODEBlock possibly redundant?

Hello,

Thank you for your work. It introduces a very interesting concept.

I have a question regarding your experimental section that acts as verification of the ODE model for MNIST classification.

Your ODE MNIST model in the paper is the following
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device),
with the ODE block in the middle as feature_layers and downsampling_method == conv. It has 208266 parameters overall and achieves a test error of 0.42%.

However, if you remove the middle block altogether and construct the following instead:
model = nn.Sequential(*downsampling_layers, *fc_layers).to(device),
with downsampling_layers and fc_layers exactly as before, you get a model with 132874 parameters that achieves a similar test error of under 0.6% after roughly 100 epochs.
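The difference between the two parameter counts matches the size of the ODE block. Assuming the ODEfunc consists of norm1, conv1, norm2, conv2, norm3 (only part of it is visible in the printout at the top of this page), with each ConcatConv2d wrapping a Conv2d(65, 64, 3x3) (64 feature channels plus one time channel) and each GroupNorm(32, 64) having an affine weight and bias per channel, the arithmetic is:

```python
def conv2d_params(in_ch, out_ch, k, bias=True):
    """Weights plus optional bias of a k x k Conv2d."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def groupnorm_params(channels, affine=True):
    """GroupNorm has one scale and one shift per channel when affine=True."""
    return 2 * channels if affine else 0


ode_block = 2 * conv2d_params(65, 64, 3) + 3 * groupnorm_params(64)
difference = 208266 - 132874
```

So the ODE block accounts for 75392 of the 208266 parameters, roughly 36%.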

Could it be that your experiment demonstrates the remarkable efficiency of your downsampling_layers rather than of the ODE block?

Thanks,

Simon

Can't pass further parameters to func?

Hi:)
Is it possible to pass additional values b to func(), where b is a constant parameter or something else that does not evolve in time? Something like odeint(func, y0, t, b)? Cheers :)
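A common pattern is to bind constant arguments before handing the function to the solver, with a closure or `functools.partial`. As far as I know, odeint_adjoint requires func to be an nn.Module, so there the natural equivalent is to store b as an attribute (or a registered buffer) of the module and read it inside forward(t, y). A minimal pure-Python sketch of the binding idea, with a simple RK4 stepper standing in for odeint (`decay` and `rk4` are hypothetical names):

```python
import functools
import math


def decay(t, y, b):
    """dy/dt = -b * y, where b is a constant that does not evolve in time."""
    return -b * y


def rk4(func, y0, t0, t1, n_steps=100):
    """Fixed-step RK4 for a scalar ODE; stands in for odeint(func, y0, t)."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        k1 = func(t, y)
        k2 = func(t + h / 2, y + h / 2 * k1)
        k3 = func(t + h / 2, y + h / 2 * k2)
        k4 = func(t + h, y + h * k3)
        y += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y


b = 2.0
func = functools.partial(decay, b=b)      # func(t, y) == decay(t, y, b=2.0)
y1 = rk4(func, 1.0, 0.0, 1.0)
y1_lambda = rk4(lambda t, y: decay(t, y, b), 1.0, 0.0, 1.0)
```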

Runs very slow on CUDA

I've been experimenting with torchdiffeq on Google Colab, but have found that running on the GPU (CUDA) is much slower than running on the CPU. Colab apparently runs on a K80, so I expected it to be much faster.

I created a colab notebook that shows a simple benchmark, using modified code from ode_demo.py. Using the code I get an average of 1.72 seconds for a forward & backward pass on the GPU, compared to 0.32 seconds on the CPU. colab benchmark

Why Interpolate in solver?

Hi,

Many thanks for this great work. Could you please explain why you need to do a linear interpolation in the forward solver:

https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/solvers.py#L95:#L97

It seems that this should always return y1, so I am confused about why you would need to interpolate between y0 and y1. The only case where this could matter is when the time-grid elements differ from t[], but I am not sure why we would ever encounter such a scenario.
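One situation where the grid and the requested times differ is a fixed-grid solver run with a step size that does not line up with the output times: the solver steps along its own grid, and each requested t that falls between two grid points is filled in by interpolating between the bracketing states. A minimal pure-Python sketch of that pattern with explicit Euler (names are my own; I haven't checked it against solvers.py line by line):

```python
def linear_interp(t0, t1, y0, y1, t):
    """Linearly interpolate between (t0, y0) and (t1, y1) at time t."""
    if t == t0:
        return y0
    if t == t1:
        return y1
    return y0 + (t - t0) / (t1 - t0) * (y1 - y0)


def fixed_grid_euler(f, y0, t_out, step):
    """Euler on a uniform grid starting at t_out[0]; report y at every t in t_out."""
    solution = [y0]
    t0, y = t_out[0], y0
    j = 1
    while j < len(t_out):
        t1 = t0 + step
        y1 = y + step * f(t0, y)
        # Every requested time inside (t0, t1] is filled in by interpolation.
        while j < len(t_out) and t1 >= t_out[j]:
            solution.append(linear_interp(t0, t1, y, y1, t_out[j]))
            j += 1
        t0, y = t1, y1
    return solution


# Grid of 0.25 but outputs requested at 0.1 and 0.5.
sol = fixed_grid_euler(lambda t, y: 1.0, 0.0, [0.0, 0.1, 0.5], 0.25)
```

When the requested times coincide with grid points, `linear_interp` just returns y1, which is the case described in the question.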

Some hyper-parameters and step acceptance

Dear authors,

I am a bit confused about the hyper-parameters _DORMAND_PRINCE_SHAMPINE_TABLEAU and DPS_C_MID chosen in torchdiffeq/_impl/dopri5.py; they seem overcomplicated. Do you have any intuition about how to choose them?

Another question concerns line 109 of torchdiffeq/_impl/dopri5.py: I have found that accept_step is always True numerically. What happens if accept_step is False? It looks like an infinite loop. When I force accept_step to False at this point, the code runs out of memory.

Thanks in advance!
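The tableau entries are not tuned hyper-parameters: they are the Butcher tableau of the Dormand-Prince 5(4) pair, fixed by requiring the method to match the Taylor expansion of the true solution to 5th order (and the embedded solution to 4th order). A minimal check with exact fractions, assuming the standard published coefficients (I believe dopri5.py uses the same ones, but haven't compared them digit by digit). It tests only the row-sum and quadrature conditions, not the full set of order conditions:

```python
from fractions import Fraction as F

# Standard Dormand-Prince 5(4) coefficients.
C = [F(0), F(1, 5), F(3, 10), F(4, 5), F(8, 9), F(1), F(1)]
A = [
    [],
    [F(1, 5)],
    [F(3, 40), F(9, 40)],
    [F(44, 45), F(-56, 15), F(32, 9)],
    [F(19372, 6561), F(-25360, 2187), F(64448, 6561), F(-212, 729)],
    [F(9017, 3168), F(-355, 33), F(46732, 5247), F(49, 176), F(-5103, 18656)],
    [F(35, 384), F(0), F(500, 1113), F(125, 192), F(-2187, 6784), F(11, 84)],
]
B5 = [F(35, 384), F(0), F(500, 1113), F(125, 192), F(-2187, 6784), F(11, 84), F(0)]
B4 = [F(5179, 57600), F(0), F(7571, 16695), F(393, 640), F(-92097, 339200),
      F(187, 2100), F(1, 40)]


def row_sums_match_nodes():
    """Each stage time c_i equals the sum of its row of A."""
    return all(sum(row, F(0)) == c for row, c in zip(A, C))


def quadrature_order(b):
    """Largest p such that sum_i b_i c_i^(k-1) == 1/k for all k = 1..p."""
    p = 0
    while sum(bi * ci ** p for bi, ci in zip(b, C)) == F(1, p + 1):
        p += 1
    return p
```

The last row of A equals the 5th-order weights with a final zero, which is the property checked in rk_common.py (`tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]`).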

Merge into PyTorch?

Thanks for this useful library!

Would you consider submitting a pull request to PyTorch to merge it in? (I don't have any official affiliation with PyTorch/Facebook, just a user who would find it useful to have numerical ODE solvers without having to import another library.) If you don't have time, would you mind if I gave it a go?

Poisson Process Likelihoods

Hello, I'm trying to reimplement the part of the paper about modeling an inhomogeneous Poisson process; however, I don't see how the compensator (the negative term of the log-likelihood) can be computed "in a single call to an ODE solver". Do you just evaluate the intensity at random points in [t_start, t_end] and thereby compute a Monte Carlo estimate of the integral? Thanks!
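One way to get the compensator without Monte Carlo is to append Lambda(t) = integral of lambda(s) ds from t_start to t to the ODE state, with dLambda/dt = lambda(t). A single solve of the augmented system, with the event times among the requested output times, then yields lambda(t_i) at every event and Lambda(t_end), and the log-likelihood is sum_i log lambda(t_i) - Lambda(t_end). I believe this is what "a single call" refers to, but that is my reading. A minimal pure-Python sketch with hypothetical dynamics dz/dt = -z and intensity lambda(t) = z(t):

```python
import math


def rk4_step(f, t, s, h):
    """One classic RK4 step for a list-valued state."""
    k1 = f(t, s)
    k2 = f(t + h / 2, [x + h / 2 * k for x, k in zip(s, k1)])
    k3 = f(t + h / 2, [x + h / 2 * k for x, k in zip(s, k2)])
    k4 = f(t + h, [x + h * k for x, k in zip(s, k3)])
    return [x + h / 6 * (a + 2 * b + 2 * c + d)
            for x, a, b, c, d in zip(s, k1, k2, k3, k4)]


def dynamics(t, s):
    z, _ = s
    lam = z                # hypothetical intensity: lambda(t) = z(t) > 0
    return [-z, lam]       # dz/dt, dLambda/dt


def solve_at(times, s0, n_sub=100):
    """Integrate from times[0], returning the state at each requested time."""
    out = [list(s0)]
    s = list(s0)
    for ta, tb in zip(times[:-1], times[1:]):
        h = (tb - ta) / n_sub
        for i in range(n_sub):
            s = rk4_step(dynamics, ta + i * h, s, h)
        out.append(s)
    return out


def log_likelihood(event_times, t_start, t_end, z0):
    times = [t_start] + sorted(event_times) + [t_end]
    states = solve_at(times, [z0, 0.0])
    log_intensities = sum(math.log(s[0]) for s in states[1:-1])
    compensator = states[-1][1]
    return log_intensities - compensator
```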

The right way to integrate t

Hi. This repository is really appreciated.

In your functions, the ODE-Net never really takes the t parameter into account.

If I understand the article correctly, for the ResNet case that is equivalent to repeating the same block with exactly the same parameters everywhere along the flow map, or to building a ResNet from a single block whose layers the inputs pass through multiple times. In a recurrent setting with time series, I see why that's okay, but for an image recognition task I'm not so sure...

It doesn't seem like an optimal way to represent the function f(h, t). What's a good way to take t into account? Or shouldn't we?
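For what it's worth, the MNIST example appears to feed t in already: the ConcatConv2d layers in the printout at the top of this page take 65 input channels, i.e. the 64 feature channels plus one constant channel filled with t. A minimal pure-Python sketch of that concatenation on nested lists (channels x height x width); `concat_time` and `conv1x1` are hypothetical helpers mirroring the idea, not the repository's code:

```python
def concat_time(t, x):
    """Prepend a constant channel equal to t to a feature map x[c][h][w]."""
    height, width = len(x[0]), len(x[0][0])
    time_channel = [[float(t)] * width for _ in range(height)]
    return [time_channel] + x


def conv1x1(x, weights):
    """Single-output 1x1 convolution: weighted sum over channels at each pixel."""
    height, width = len(x[0]), len(x[0][0])
    return [[sum(wc * x[c][i][j] for c, wc in enumerate(weights))
             for j in range(width)] for i in range(height)]


features = [[[1.0, 2.0], [3.0, 4.0]] for _ in range(64)]   # 64 channels, 2x2
weights = [0.5] + [0.01] * 64                            # time weight + 64 feature weights
out_t0 = conv1x1(concat_time(0.0, features), weights)
out_t1 = conv1x1(concat_time(1.0, features), weights)
```

The same weights are shared across t, but the layer's output still changes with t through the extra channel.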

Definition of the dynamics of the ODENet for MNIST.

Hi, in odenet_mnist.py the dynamics of the ODENet share the same parameters across all t, which saves parameters (the dynamics function f depends only on Z(t), i.e. f(Z(t)) instead of f(Z(t), theta(t))). I wonder how to define the dynamics function in practice, and why the parameters of a single ResBlock are enough for the ODENet here. Thanks!
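If one does want f(Z(t), theta(t)), a simple option is to make the parameters themselves a function of t, for example theta(t) = theta0 + t * theta1. This is a hypothetical linear parameterization that doubles the parameter count here; a small hypernetwork mapping t to weights is the more general version. A minimal pure-Python sketch for a linear dynamics function:

```python
def make_time_dependent_linear(w0, w1):
    """Dynamics f(t, z) = W(t) z with W(t) = w0 + t * w1 (square matrices as lists)."""
    def f(t, z):
        n = len(z)
        return [sum((w0[i][j] + t * w1[i][j]) * z[j] for j in range(n))
                for i in range(n)]
    return f


w0 = [[0.0, 1.0], [-1.0, 0.0]]
w1 = [[0.1, 0.0], [0.0, 0.1]]
f = make_time_dependent_linear(w0, w1)
```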

OOM during backward pass on a model with ~600k parameters

Hey Ricky,

I'm running out of memory during the backward pass on a 16 GB GPU when running the adjoint method with rtol 1e-5, atol 1e-5, and a network with 631058 parameters.

I'm not sure why this happens, given that augmented_dynamics runs within torch.no_grad() and the tensors saved during the forward pass should not be that large.

Any thoughts on what is happening and how to debug it?

The model is a 3D U-Net (UNet) followed by a few 3D convolutions (node_layers).

Model(
  (prenet_layers): PrenetLayer(
    (initval_layers): Identity()
    (image_layers): Sequential(
      (0): Conv3d(3, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): UNet(
        (in_norm): GroupNorm(8, 8, eps=1e-05, affine=True)
        (in_act): ReLU(inplace)
        (down1): down(
          (mpconv): Sequential(
            (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): double_conv(
              (conv): Sequential(
                (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): LeakyReLU(negative_slope=0.1, inplace)
                (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (5): LeakyReLU(negative_slope=0.1, inplace)
              )
            )
          )
        )
        (down2): down(
          (mpconv): Sequential(
            (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): double_conv(
              (conv): Sequential(
                (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): LeakyReLU(negative_slope=0.1, inplace)
                (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (5): LeakyReLU(negative_slope=0.1, inplace)
              )
            )
          )
        )
        (down3): down(
          (mpconv): Sequential(
            (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): double_conv(
              (conv): Sequential(
                (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): LeakyReLU(negative_slope=0.1, inplace)
                (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (5): LeakyReLU(negative_slope=0.1, inplace)
              )
            )
          )
        )
        (down4): down(
          (mpconv): Sequential(
            (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): double_conv(
              (conv): Sequential(
                (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): LeakyReLU(negative_slope=0.1, inplace)
                (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (5): LeakyReLU(negative_slope=0.1, inplace)
              )
            )
          )
        )
        (up0): up(
          (up): Upsample(scale_factor=2.0, mode=trilinear)
          (conv): double_conv(
            (conv): Sequential(
              (0): Conv3d(128, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): LeakyReLU(negative_slope=0.1, inplace)
              (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (5): LeakyReLU(negative_slope=0.1, inplace)
            )
          )
        )
        (up1): up(
          (up): Upsample(scale_factor=2.0, mode=trilinear)
          (conv): double_conv(
            (conv): Sequential(
              (0): Conv3d(64, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): LeakyReLU(negative_slope=0.1, inplace)
              (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (5): LeakyReLU(negative_slope=0.1, inplace)
            )
          )
        )
        (up2): up(
          (up): Upsample(scale_factor=2.0, mode=trilinear)
          (conv): double_conv(
            (conv): Sequential(
              (0): Conv3d(32, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): LeakyReLU(negative_slope=0.1, inplace)
              (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (4): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (5): LeakyReLU(negative_slope=0.1, inplace)
            )
          )
        )
        (up3): up(
          (up): Upsample(scale_factor=2.0, mode=trilinear)
          (conv): double_conv(
            (conv): Sequential(
              (0): Conv3d(16, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): LeakyReLU(negative_slope=0.1, inplace)
              (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (4): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (5): LeakyReLU(negative_slope=0.1, inplace)
            )
          )
        )
        (out): Conv3d(8, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        (out_norm): GroupNorm(8, 8, eps=1e-05, affine=True)
      )
    )
  )
  (node_layers): Sequential(
    (0): ODEBlock(
      (odefunc): ODEfunc(
        (tanh): Tanh()
        (conv_emb1): ConcatConv3d(
          (_layer): Conv3d(2, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_emb1): GroupNorm(4, 4, eps=1e-05, affine=True)
        (conv_emb2): ConcatConv3d(
          (_layer): Conv3d(5, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_emb2): GroupNorm(4, 4, eps=1e-05, affine=True)
        (norm_img_pre): GroupNorm(8, 8, eps=1e-05, affine=True)
        (conv_img): ConcatConv3d(
          (_layer): Conv3d(9, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_img): GroupNorm(4, 4, eps=1e-05, affine=True)
        (conv1): ConcatConv3d(
          (_layer): Conv3d(9, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_conv1): GroupNorm(4, 4, eps=1e-05, affine=True)
        (conv2): ConcatConv3d(
          (_layer): Conv3d(5, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_conv2): GroupNorm(4, 4, eps=1e-05, affine=True)
        (conv_out): ConcatConv3d(
          (_layer): Conv3d(5, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        )
      )
    )
  )
  (postnet_layers): Sequential(
    (0): Identity()
  )
)

Support higher order autodiff?

Thanks a lot for your work! However, it seems that backward after grad is not supported yet. Here is a minimal example:

# dy/dt = a b y, t = 0...1
# y1 = y0 exp(a b)
# dy1/da = b y0 exp(a b)
# dy1/dy0 = exp(a b)

import torch
from torch import nn
from torch.autograd import grad
from torchdiffeq import odeint_adjoint as odeint


class Func(nn.Module):
    def __init__(self):
        super(Func, self).__init__()
        self.a = nn.Parameter(torch.tensor(2.0))
        self.b = nn.Parameter(torch.tensor(3.0))

    def forward(self, t, y):
        return self.a * self.b * y


if __name__ == '__main__':
    func = Func()
    y0 = torch.tensor(4.0, requires_grad=True)
    t = torch.tensor([0.0, 1.0])
    y1 = odeint(func, y0, t)[1]
    print(y1)
    y1.backward(retain_graph=True)

    dy1_da = grad(y1, func.a, create_graph=True)[0]
    print(dy1_da)
    dy1_da.backward(retain_graph=True)

    dy1_dy0 = grad(y1, y0, create_graph=True)[0]
    print(dy1_dy0)
    dy1_dy0.backward(retain_graph=True)

Neither dy1_da nor dy1_dy0 has a grad_fn, so dy1_da.backward() and dy1_dy0.backward() throw errors. It would be nice if you could support these operations; then we could build more complex applications on top of your package.
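Since y1 = y0 exp(ab) has a closed form, the first- and second-order derivatives the example asks for can be written down and used as a reference once double backward works: d²y1/da² = b² y0 exp(ab) and d²y1/(da db) = y0 exp(ab)(1 + ab) are among the quantities a double backward through dy1_da would have to produce. A minimal pure-Python sketch with the example's values (a = 2, b = 3, y0 = 4) that also checks the formulas in the comments by central finite differences:

```python
import math


def y1(a, b, y0):
    """Exact solution at t = 1 of dy/dt = a * b * y, y(0) = y0."""
    return y0 * math.exp(a * b)


def central_diff(f, x, eps=1e-6):
    """Central finite-difference approximation of f'(x)."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


a, b, y0 = 2.0, 3.0, 4.0
dy1_da = b * y0 * math.exp(a * b)
dy1_dy0 = math.exp(a * b)
d2y1_da2 = b * b * y0 * math.exp(a * b)
d2y1_dadb = y0 * math.exp(a * b) * (1 + a * b)
```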

Error when trying to use gpu other than gpu 0

When I run the ODE-Net MNIST script on a GPU other than GPU 0, as follows:

python odenet_mnist.py --network odenet --gpu 2

I get the following error:

Traceback (most recent call last):
  File "odenet_mnist.py", line 327, in <module>
    logits = model(x)
  File "/home/rajat/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/rajat/.local/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/rajat/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "odenet_mnist.py", line 109, in forward
    out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
  File "/home/rajat/workspace/torchdiffeq/torchdiffeq/_impl/odeint.py", line 83, in odeint
    solution = solver.integrate(t)
  File "/home/rajat/workspace/torchdiffeq/torchdiffeq/_impl/solvers.py", line 31, in integrate
    y = self.advance(t[i])
  File "/home/rajat/workspace/torchdiffeq/torchdiffeq/_impl/dopri5.py", line 90, in advance
    self.rk_state = self._adaptive_dopri5_step(self.rk_state)
  File "/home/rajat/workspace/torchdiffeq/torchdiffeq/_impl/dopri5.py", line 119, in _adaptive_dopri5_step
    dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5
  File "/home/rajat/workspace/torchdiffeq/torchdiffeq/_impl/misc.py", line 169, in _optimal_step_size
    factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
RuntimeError: arguments are located on different GPUs at /pytorch/aten/src/THC/generic/THCTensorMathPointwise.cu:266

I'm using PyTorch 1.0.0, and the GPU is an NVIDIA 1080 Ti.

Cannot run ffjord cnf examples on GPU

I am trying to run some examples from https://github.com/rtqichen/ffjord and cannot get any of the CNF VAE demos to run on the GPU. The other CNFVAE variants give the same error.

ffjord$ python train_vae_flow.py --dataset mnist --flow cnf_rank --rank 64 --dims 512-512 --num_blocks 1

beta = 0.0100
.../random/miniconda/lib/python3.7/site-packages/torch/nn/_reduction.py:49: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
  warnings.warn(warning.format(ret))
Traceback (most recent call last):
  File "train_vae_flow.py", line 358, in <module>
    run(args, kwargs)
  File "train_vae_flow.py", line 281, in run
    tr_loss = train(epoch, train_loader, model, optimizer, args, logger)
  File ".../ffjord/vae_lib/optimization/training.py", line 42, in train
    loss.backward()
  File ".../miniconda/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../miniconda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I have replaced all the ReLU(..., inplace=True) with inplace=False, but I'm not sure what else to try. I am wondering whether the issue could be in the torchdiffeq library. The other examples in the ffjord repo seem to work on the GPU, so maybe the issue is isolated to the odeint method of torchdiffeq (which only the CNF VAEs use)?

I'm using PyTorch 1.0.1 with CUDA 9.2. Thanks in advance for any help!

Problem when rtol and atol are iterables

When I use the fixed_adams solver and pass multiple rtol and atol values, I get an error: the function _has_converged in misc.py cannot handle them. It would be better to handle them the same way _compute_error_ratio does.
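A sketch of the requested behaviour, assuming the state is a tuple of tensors and that a tolerance is either a scalar or a list/tuple with one entry per state tensor. `has_converged` and `_per_state` are hypothetical names written for illustration; this is not torchdiffeq's actual `_has_converged`. Scalars are broadcast to every state tensor before the elementwise test.

```python
import torch


def _per_state(tol, n):
    """Hypothetical helper: one tolerance per state tensor (scalars are broadcast)."""
    if isinstance(tol, (list, tuple)):
        if len(tol) != n:
            raise ValueError(f"expected {n} tolerances, got {len(tol)}")
        return list(tol)
    return [tol] * n


def has_converged(y0, y1, rtol, atol):
    """True if every element satisfies |y1 - y0| < atol + rtol * max(|y0|, |y1|)."""
    rtols = _per_state(rtol, len(y0))
    atols = _per_state(atol, len(y0))
    return all(
        bool((torch.abs(y1_ - y0_)
              < atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_))).all())
        for y0_, y1_, rtol_, atol_ in zip(y0, y1, rtols, atols)
    )
```

For example, with `y0 = (torch.tensor([1.0]), torch.tensor([100.0]))` and `y1 = (torch.tensor([1.0001]), torch.tensor([100.5]))`, per-state tolerances `rtol=[1e-3, 1e-2], atol=[0.0, 0.0]` accept both states, while a single `rtol=1e-3, atol=0.0` rejects the second one.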

Can't use grad on the solution when using odeint_adjoint

Hello,
This is a very interesting module that perfectly suits my needs (I am currently solving ODEs manually with Euler's method, with the accumulation of gradients and the huge memory footprint that implies).
The problem is simple: I can't call grad or backward on the solution returned by odeint_adjoint.
Minimal example triggering the problem:

import torch
from torchdiffeq import odeint, odeint_adjoint

torch.set_default_tensor_type(torch.DoubleTensor)

class f(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, t, y):
        return y

x0 = torch.tensor([1.], requires_grad=True)
t = torch.linspace(0., 1., 20, requires_grad=True)

y = odeint_adjoint(f(), x0, t, method='dopri5').view(-1)

[dydt] = torch.autograd.grad(torch.sum(y), [t])

The error is:

Traceback (most recent call last):
  File "test_torchdifeq.py", line 41, in <module>
    [dydt] = torch.autograd.grad(torch.sum(y), [t])[0]
  File "/home//p3env/lib/python3.6/site-packages/torch/autograd/__init__.py", line 145, in grad
    inputs, allow_unused)
  File "/home//p3env/lib/python3.6/site-packages/torch/autograd/function.py", line 76, in apply
    return self._forward_cls.backward(self, *args)
  File "/home//boulot/lib/torchdiffeq/torchdiffeq/_impl/adjoint.py", line 78, in backward
    if len(adj_params) == 0:
  File "/home//p3env/lib/python3.6/site-packages/torch/tensor.py", line 411, in __len__
    raise TypeError("len() of a 0-d tensor")
TypeError: len() of a 0-d tensor

This works perfectly well with plain odeint, but then I don't get the O(1)-memory adjoint.
I'm pretty sure the solution is simple, but I can't get it to work.

Thanks in advance for any responses.
Leander
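For this particular ODE the gradient has a closed form, which gives something to check any fix against. With dy/dt = y and y(t_0) = x0, the solution is y(t) = x0 * exp(t - t_0). Moving an output time t_i (i >= 1) changes sum(y) by y(t_i) = f(t_i, y(t_i)), while moving t_0 shifts every later value and contributes -sum_{j>=1} y(t_j). The helper `expected_grad_wrt_t` is a hypothetical plain-Python sketch (no torch required); the gradient from the non-adjoint `odeint` run in the script above should be close to it, up to the solver's discretization error.

```python
import math


def expected_grad_wrt_t(x0, ts):
    """Closed-form d/dt_i of sum_j y(t_j) for dy/dt = y, y(ts[0]) = x0."""
    t0 = ts[0]
    ys = [x0 * math.exp(t - t0) for t in ts]
    grad = list(ys)          # i >= 1: d y(t_i) / d t_i = f(t_i, y(t_i)) = y(t_i)
    grad[0] = -sum(ys[1:])   # y(t_0) = x0 is pinned, so moving t_0 shifts all later values
    return grad


n = 20
ts = [i / (n - 1) for i in range(n)]  # same grid as torch.linspace(0., 1., 20), up to rounding
g = expected_grad_wrt_t(1.0, ts)      # g[-1] is exp(1) for this grid
```

If what is actually needed is dy/dt at the output times (as the name `dydt` suggests), evaluating the vector field on the solution, e.g. `f()(t, y)`, gives it directly without any backward pass.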

misc issue

I am trying to use a different dataset, but I get an error from _select_initial_step(): f0 does not have the same length as scale, which leaves me wondering what your f0 looks like. The output of my ODE block isn't the same shape as its input, so I'm confused about how this code works.

d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale))
d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale))
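The two lines above pair each state tensor with its derivative, so the function passed to odeint must return dy/dt with the same structure as y0: the same number of tensors and, for each one, the same shape, because a solver step forms y + h * f(t, y). An ODE block whose output shape differs from its input therefore cannot be integrated as-is; the usual approach is to change dimensions in ordinary layers before or after the ODE block, as the MNIST example does with its downsampling layers. Below is a sketch of the d0/d1 computation with explicit shape checks. `initial_step_norms` is a hypothetical helper, the RMS norm stands in for `_norm` as an assumption, and the default tolerances are arbitrary.

```python
import torch


def _rms_norm(x):
    # Assumed RMS norm: ||x||_2 / sqrt(number of elements)
    return x.norm() / (x.numel() ** 0.5)


def initial_step_norms(y0, f0, rtol=1e-3, atol=1e-6):
    """Hypothetical re-creation of the d0/d1 lines with explicit shape checks."""
    if len(f0) != len(y0):
        raise ValueError(f"func returned {len(f0)} tensors but y0 has {len(y0)}")
    for i, (y0_, f0_) in enumerate(zip(y0, f0)):
        if f0_.shape != y0_.shape:
            raise ValueError(
                f"state {i}: dy/dt has shape {tuple(f0_.shape)}, "
                f"but y0 has shape {tuple(y0_.shape)}"
            )
    scale = tuple(atol + torch.abs(y0_) * rtol for y0_ in y0)
    d0 = tuple(_rms_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale))
    d1 = tuple(_rms_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale))
    return d0, d1
```

Calling it with `y0 = (torch.ones(4, 8),)` and an `f0` of shape `(4, 10)`, as an ODE block that changes width would produce, raises a `ValueError` naming both shapes instead of failing further down in the solver.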

Multiple GPUs not supported for MNIST example

When running on a machine with multiple GPUs and using the adjoint method:

python odenet_mnist.py --gpu 2 --adjoint True

I get the following error running the MNIST example:

Traceback (most recent call last):
  File "odenet_mnist.py", line 350, in <module>
    loss.backward()
  File ".../lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File ".../lib/python3.6/site-packages/torch/autograd/function.py", line 76, in apply
    return self._forward_cls.backward(self, *args)
  File ".../torchdiffeq/torchdiffeq/_impl/adjoint.py", line 67, in backward
    func_i = func(t[i], ans_i)
  File ".../lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File ".../torchdiffeq/torchdiffeq/_impl/adjoint.py", line 122, in forward
    return (self.base_func(t, y[0]),)
  File ".../lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "odenet_mnist.py", line 108, in forward
    out = self.conv1(t, out)
  File ".../lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "odenet_mnist.py", line 87, in forward
    tt = torch.ones_like(x[:, :1, :, :]) * t
RuntimeError: binary_op(): expected both inputs to be on same device, but input a is on cuda:2 and input b is on cuda:0

The problem appears to be fixed by changing https://github.com/rtqichen/torchdiffeq/blob/master/examples/odenet_mnist.py#L124 to:

self.integration_time = self.integration_time.type_as(x).to(x.device)
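A likely explanation: `Tensor.type_as(x)` converts to the tensor type of `x` (e.g. `torch.cuda.FloatTensor`), and a bare CUDA tensor type places the result on the current default CUDA device, typically cuda:0, rather than on `x`'s device; hence the extra `.to(x.device)`. The sketch below shows two alternatives that should behave the same, using a hypothetical module `TimedBlock`: `Tensor.to(x)`, which matches both the dtype and the device of `x` in one call, and registering the times as a buffer so that `model.to(device)` moves them together with the parameters. The check runs on CPU, so it only exercises the dtype/device matching, not the multi-GPU case itself.

```python
import torch


class TimedBlock(torch.nn.Module):
    """Hypothetical stand-in for an ODE block that owns its integration times."""

    def __init__(self):
        super().__init__()
        # A buffer is moved by module.to(...) / module.cuda(...) along with parameters
        self.register_buffer("integration_time", torch.tensor([0.0, 1.0]))

    def times_for(self, x):
        # .to(x) matches both x.dtype and x.device in a single call
        return self.integration_time.to(x)
```

Selecting the GPU up front, with `torch.cuda.set_device` or the `CUDA_VISIBLE_DEVICES` environment variable, is another common way to avoid tensors landing on cuda:0.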
