
Comments (4)

patrick-kidger avatar patrick-kidger commented on June 4, 2024

That's not normal behaviour. Can you provide an MWE demonstrating this? [1] This usually occurs when you have a mistake in your training loop; it's usually nothing to do with torchsde.

The cache_size parameter doesn't have units like that -- it specifies the size of the LRU cache used in the Brownian Interval, which is measured in number of samples. That's pretty technical; you should almost certainly leave it alone. (If you're trying to reduce memory consumption then I'd look elsewhere -- the memory consumed by the Brownian Interval is generally negligible.)
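As a rough illustration of what a bound like cache_size does, here is a toy LRU cache keyed by query time (a simplified sketch for intuition only, not torchsde's actual Brownian Interval implementation; the keys and values are made up):

```python
from collections import OrderedDict

class LRUCache:
    """Minimal LRU cache bounded by a number of entries.
    Illustrates the kind of structure a cache_size-style
    parameter controls: it caps entries, not bytes."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._store = OrderedDict()

    def get(self, key):
        if key not in self._store:
            return None
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key]

    def put(self, key, value):
        if key in self._store:
            self._store.move_to_end(key)
        self._store[key] = value
        if len(self._store) > self.capacity:
            self._store.popitem(last=False)  # evict least recently used

cache = LRUCache(capacity=2)
cache.put("t=0.1", "sample_a")
cache.put("t=0.2", "sample_b")
cache.get("t=0.1")              # touch t=0.1 so it becomes most recent
cache.put("t=0.3", "sample_c")  # evicts t=0.2, the least recently used
print(cache.get("t=0.2"))  # → None
print(cache.get("t=0.1"))  # → sample_a
```

Because the capacity is a fixed entry count, memory held by such a cache stays bounded regardless of how long training runs, which is why it's an unlikely culprit for growing memory usage.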

[1] I'd note that if you're using adaptive step sizes without the adjoint method then this can plausibly occur because of the complexity changing during training -- typically NFEs go up over training -- but that's usually not substantial.

from torchsde.

lxuechen avatar lxuechen commented on June 4, 2024

Hi,

I think the chances of there being a subtle memory bug in torchsde are quite low, though I don't want to rule out that possibility entirely. If you have a minimal reproducible example, I'd also be happy to take a look and see if I can help solve the problem. It's very hard to reason about things without context.


rubick1896 avatar rubick1896 commented on June 4, 2024

Thanks for the reply. It will take some time to put together a minimal reproducible example, since the model also relies on a framework in the middle and I don't control the training loop directly. In the meantime, I'd like to provide more context in words.

I am trying to model a rolling forecasting problem: given x1, x2, ..., xt, predict xt+1; at the next time step you have x1, x2, ..., xt+1 and predict xt+2. The training strategy is to pick a random split point t, look back t time steps to form a single training case, and use the value at t+1 as the label. So unlike standard NN training, the training cases are not the same in each epoch; they are random samples drawn from one long time series.

If my description is not clear, maybe take a look at the code here.
https://github.com/zalandoresearch/pytorch-ts/blob/master/pts/transform/split.py
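The sampling scheme described above can be sketched roughly as follows (a hypothetical illustration of the idea, not the pytorch-ts implementation; the function name and min_context parameter are made up):

```python
import random

def sample_training_case(series, min_context=3, rng=random):
    """Pick a random split point t and return (context, target):
    the first t values as the input window and the value at t+1
    as the label. Each call draws a different training case from
    the same long series."""
    t = rng.randrange(min_context, len(series) - 1)
    context = series[:t]   # x1 .. xt
    target = series[t]     # x_{t+1}
    return context, target

rng = random.Random(0)
series = list(range(10, 30))  # a toy time series
context, target = sample_training_case(series, rng=rng)
assert target == series[len(context)]  # label immediately follows the window
```

The key property for the memory question is that the context length t varies between gradient updates, so per-step memory can fluctuate even if nothing is leaking.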

I am using an RNN to encode a training case and feed that to torchsde.

I just want to confirm that this setup doesn't change anything fundamental, and that I should still expect memory usage to be constant across epochs. If so, I will work on a minimal reproducible example.


lxuechen avatar lxuechen commented on June 4, 2024

The expectation is that VRAM usage should stay roughly the same across different gradient updates. If you're using adaptive solvers, then yes, it's quite possible that the learned dynamics become harder to solve, requiring many more function evaluations.

Somewhat contrary to what Patrick suggested, I've seen quite a few cases where the NFE at the start of training differs substantially from the NFE at the end.

Your description still doesn't give us the setting at a granularity where we can help. You mention data x1, x2, ..., xt+1 and target xt+2. Does t change across different gradient updates?

I'm closing this issue for now since not enough background has been given. Feel free to reopen if you can provide more context, or at the very least a small example that points to the problem.


