Comments (3)

BenjaminBossan commented on May 28, 2024

Great explanation @ottonemo. To add a tiny thing: the call to self.double() looks suspicious; you should try working with float32 instead. I assume you ran into a dtype error because the model and the data had different dtypes, but it's better to cast the input data to float32 than to cast the model to float64.
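
For illustration, here is a minimal sketch of the dtype handling I mean (the array below is just a placeholder, not your actual data):

import numpy as np
import torch

# placeholder input data that happens to be float64
X = np.arange(500, dtype=np.float64).reshape(-1, 10)

# cast the data to float32 instead of calling self.double() on the module;
# PyTorch modules (and hence skorch modules) default to float32 parameters
X = torch.as_tensor(X, dtype=torch.float32)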

ottonemo commented on May 28, 2024

Hey there,

thanks for the question(s).

I stepped through your code and there were some things I noticed.

  1. You add an extra dimension in the .forward call and specify batch_first=True in the LSTM layer. I'm not sure that is correct; I would have expected the dimension to be added at dimension 2 (last), not at dimension 0 (first). For a single-value time series I would expect x to have shape (batch, time, 1).

  2. The input_size parameter of the LSTM layer is set to 10, but you are saying that you have scalar input values, so the input size should be 1 (you only have one number per time-step). This parameter does not define the length of the input time series but the number of features/variables it has. Since you define the semantics of the output_size parameter yourself, it is fine as it is but needs some modifications (see the next point). I would strongly advise naming it appropriately (e.g., output_sequence_length), though.

  3. The fully-connected output layer is your way of generating a fixed-length sequence (you will always generate output_size elements). You have to make sure to generate a batched sequence, though. The output of your module should match your target sequence: when your y_true has shape (batch_size, sequence_length), your module should also produce (batch_size, sequence_length) - in your case (10, 2). Since your current code produces a tensor of shape (2,), the loss function needs to broadcast this output to match the (10, 2) shape, which will lead to wrong results (you are using the same output for all of the different target values).

  4. You don't need to concatenate anything to get the last output of the LSTM. See the return values section for LSTM in the PyTorch documentation: in the bidirectional case the concatenated version is returned to you. Since you name the variable last hidden state, I assume that you don't want to use the predictions (shape: B x T x lstm_output_size) but the hidden state (second return value of LSTM.forward, shape: (2 * num_layers, batch_size, lstm_hidden_size)). In that case you need to change your code to use the second return value. The rest of the code, however, looks as though you want to use the LSTM's output of the last layer at the last time-step, so using lstm(x)[0][:, -1, :] should be good (see the short shape check after this list).
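
To make the shapes above concrete, here is a small standalone check (the sizes are arbitrary example values, not taken from your code):

import torch

batch_size, seq_len, hidden_size, num_layers = 10, 10, 5, 2

# a batch of single-value time series: (batch, time, 1)
x = torch.randn(batch_size, seq_len, 1)

lstm = torch.nn.LSTM(
    input_size=1,  # one feature per time-step, not the sequence length
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,
    bidirectional=True,
)

outputs, (h_n, c_n) = lstm(x)
print(outputs.shape)  # torch.Size([10, 10, 10]) -> (batch, time, 2 * hidden_size)
print(h_n.shape)      # torch.Size([4, 10, 5])   -> (2 * num_layers, batch, hidden_size)

# output of the last layer at the last time-step, both directions concatenated
last_output = outputs[:, -1, :]
print(last_output.shape)  # torch.Size([10, 10]) -> (batch, 2 * hidden_size)

# and why an unbatched (2,) output is a problem against a (10, 2) target:
# MSELoss broadcasts the prediction and emits a size-mismatch warning
pred, target = torch.randn(2), torch.randn(10, 2)
loss = torch.nn.MSELoss()(pred, target)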

Additionally (and personally), I would refrain from using the terms train/test for anything other than train/test splits, e.g. X_train or y_test. The things you pass to .fit() are almost certainly both training data. The standard notation is X for feature data and y for target values. This avoids a lot of confusion when it comes to evaluation.

Here's a modified version of your code in which I applied the above comments. This is not checked for correctness as I don't understand your task fully but I hope it helps you nevertheless:

Modified code
import torch
from skorch import NeuralNet
from skorch.callbacks import EarlyStopping

torch.manual_seed(69420)

time_series_data = torch.tensor(range(1, 501), dtype=torch.float64)
sequence_length = 10
output_length = 2


def get_train_test(input_seq, seq_len, output_len):
    # this value avoids going outside the boundaries of the input seq
    num_seqs = len(input_seq) - seq_len - output_len + 1

    train = [input_seq[x : x + seq_len] for x in range(num_seqs)]
    test = [input_seq[x + seq_len : x + seq_len + output_len] for x in range(num_seqs)]

    return torch.stack(train, dim=0), torch.stack(test, dim=0)


train_seq, test_seq = get_train_test(time_series_data, sequence_length, output_length)

print(test_seq[-1], train_seq[-1])
print(test_seq[0], train_seq[0])


class TimeSeriesPredictor(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size

        # Bidirectional LSTM layers
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Fully connected layer for output
        # hidden_size is doubled since the lstm is bidirectional
        self.fc = torch.nn.Linear(hidden_size * 2, output_size)

        self.double()

    def forward(self, x):
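        # x arrives as (batch, seq_len); add a feature dimension -> (batch, seq_len, 1)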
        x = x.unsqueeze(dim=2)

        # LSTM layer
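        # outputs has shape (batch, seq_len, 2 * hidden_size) because the LSTM is bidirectional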
        outputs, (_last_hidden_state, _last_cell_state) = self.lstm(x)

        # Extract the last output from both directions
        last_output = outputs[:, -1, :]

        # Fully connected layer for output
        output = self.fc(last_output)

        return output


train_early_stop = EarlyStopping(
    monitor="train_loss",
    # patience=5,
    # threshold=1e-4,
    threshold_mode="abs",
    lower_is_better=True,
    load_best=False,
)

model = NeuralNet(
    module=TimeSeriesPredictor,
    module__input_size=1,
    module__hidden_size=200,
    module__num_layers=6,
    module__output_size=output_length,
    max_epochs=200,
    batch_size=10,
    callbacks=[train_early_stop],
    criterion=torch.nn.MSELoss,
    optimizer=torch.optim.Adam,
    verbose=1,
    warm_start=True,
)

model.fit(X=train_seq[:-10], y=test_seq[:-10])

print(model.predict(train_seq[-10:]))

chris-fj commented on May 28, 2024

Hello, everyone. Thanks for your answers, especially @ottonemo. I greatly appreciate that you took the time to go through my code and explain in detail the things that could be improved. I now understand why I had the problem regarding the warning, and I know how to solve it in future cases. Thanks also to @BenjaminBossan, because that's exactly why I added self.double(), in a moment when I was not thinking clearly. Downcasting to float32 has also improved the runtime.
