Comments (4)

anthony-wang commented on June 18, 2024

from crabnet.

sgbaird commented on June 18, 2024

Good point. I have been unsure whether people sometimes use "validation data" to adjust the hyperparameters of the model (in some sense I think we all do this, if a bit more manually, as we run models while debugging and developing), hence the idea of a third "test" set. Just wanted to make sure. I agree with your definition. Thanks Anthony!

sgbaird commented on June 18, 2024

Ok, I think my suspicion is confirmed (validation data is used to improve CrabNet's training results). I had been looking for where the validation set is used during training, and I found a place in train that appears to use the validation data to improve generalization.

CrabNet/crabnet/model.py

Lines 115 to 119 in e884482

    if learning_time:
        act_v, pred_v, _, _ = self.predict(self.data_loader)
        mae_v = mean_absolute_error(act_v, pred_v)
        self.optimizer.update_swa(mae_v)
        minima.append(self.optimizer.minimum_found)

Near the end of training (not sure why this is on the 2nd to last epoch instead of the last epoch), it goes into L273:

CrabNet/crabnet/model.py

Lines 272 to 273 in e884482

    if not (self.optimizer.discard_count >= self.discard_n):
        self.optimizer.swap_swa_sgd()

See the PyTorch blog post "PyTorch 1.6 now includes Stochastic Weight Averaging".
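The gating idea behind the two snippets above can be sketched in a few lines. This is a toy, not CrabNet's optimizer: the class name GatedWeightAverager, the plain-list weights, and the acceptance rule (accept a weight snapshot only when the validation MAE beats the best seen so far, and count every rejection) are assumptions made for illustration. Only the names discard_count and discard_n are borrowed from the code quoted above.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class GatedWeightAverager:
    """Toy validation-gated weight averaging (hypothetical, not CrabNet's code)."""

    discard_n: int = 3                 # assumed: rejection budget before averaging is abandoned
    best_mae: float = float("inf")     # best validation MAE seen so far
    discard_count: int = 0             # number of rejected snapshots
    n_averaged: int = 0                # number of accepted snapshots
    avg_weights: List[float] = field(default_factory=list)

    def update(self, weights: List[float], val_mae: float) -> bool:
        """Fold `weights` into the running mean only if `val_mae` improved."""
        if val_mae >= self.best_mae:
            self.discard_count += 1
            return False
        self.best_mae = val_mae
        self.n_averaged += 1
        if not self.avg_weights:
            self.avg_weights = list(weights)
        else:
            k = self.n_averaged
            # Incremental mean: avg += (w - avg) / k
            self.avg_weights = [a + (w - a) / k for a, w in zip(self.avg_weights, weights)]
        return True

    def use_average(self) -> bool:
        """Mirror of the quoted condition: swap only if the budget was not exhausted."""
        return not (self.discard_count >= self.discard_n)


averager = GatedWeightAverager()
averager.update([1.0, 2.0], val_mae=10.0)      # accepted
averager.update([3.0, 4.0], val_mae=9.0)       # accepted, average becomes [2.0, 3.0]
averager.update([100.0, 100.0], val_mae=12.0)  # rejected, average unchanged
print(averager.avg_weights, averager.discard_count, averager.use_average())
```

The point of the sketch is that whichever data set supplies val_mae decides which weight snapshots survive, so that data set influences the final model even though no gradients are computed on it.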

Recently, I ran into an error where pred_v was all np.nan for one of the crabnet-hyperparameter combinations, so I decided to check the size of pred_v using the original repo and example_materials_property. pred_v turned out to be the same size as the validation data (727), not the training batch size (batch_size==256), and this was consistent across consecutive jumps to the same breakpoint. The total number of training data points is 3433 in this case.
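The size check described above can be written as a small helper. This is only a sketch: describe_predictions and the split_sizes dictionary are hypothetical names, and the sizes are the ones quoted in this comment (3433 training points, 727 validation points, batch size 256).

```python
import math


def describe_predictions(preds, split_sizes):
    """Report length, NaN count, and which known split sizes match len(preds)."""
    n_nan = sum(1 for p in preds if math.isnan(p))
    matches = [name for name, n in split_sizes.items() if n == len(preds)]
    return {"n_pred": len(preds), "n_nan": n_nan, "matching_splits": matches}


# Sizes quoted in the comment above; the keys are illustrative labels.
split_sizes = {"train": 3433, "val": 727, "train_batch": 256}

# Stand-in for the all-NaN pred_v seen at the breakpoint.
pred_v = [float("nan")] * 727
report = describe_predictions(pred_v, split_sizes)
print(report)
```

A prediction array whose length matches only the validation split is what indicates that self.data_loader is serving validation data at that point.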

@anthony-wang @Kaaiian

sgbaird commented on June 18, 2024

Empirical tests

Empirical tests also seem to confirm this (similar to my question about using "dummy" validation data). 42, 43, and 44 refer to the torch seeds used. The 3 numbers that follow the colon (:) are the train/val/test MAEs.

1. Using validation data to calculate MAE for update_swa

seed: train MAE, val MAE, test MAE
42: 6.53, 10.7, 9.82
43: 6.38, 10.3, 9.63
44: 6.49, 10.5, 9.95

2. Shuffling validation predictions just prior to update_swa

Epoch 39 failed to improve.
Discarded: 1/3 weight updates ♻🗑️

42: 9.77, 11.9, 11.4
43: 9.19, 11.5, 10.7
44: 9.77, 11.9, 11.4
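For readers unfamiliar with the shuffling control, here is a minimal sketch of why it works, using synthetic numbers rather than CrabNet output. The data, the seed, and the helper name mae are all made up for illustration. Shuffling predictions relative to their targets keeps the distribution of predictions but destroys the pairing, so the MAE passed on no longer carries a useful signal.

```python
import random


def mae(actual, predicted):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


rng = random.Random(0)  # fixed seed so the sketch is repeatable

# Synthetic targets and predictions that are within 1.0 of their targets.
actual = [float(i) for i in range(100)]
predicted = [a + rng.uniform(-1.0, 1.0) for a in actual]

# Shuffle a copy of the predictions so they no longer line up with the targets.
shuffled = predicted[:]
rng.shuffle(shuffled)

mae_aligned = mae(actual, predicted)
mae_shuffled = mae(actual, shuffled)
print(f"aligned MAE:  {mae_aligned:.3f}")
print(f"shuffled MAE: {mae_shuffled:.3f}")
```

In this toy setup the shuffled MAE is far larger than the aligned one, which is the uninformative signal that experiment 2 feeds to update_swa.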

3. Using training data to calculate MAE for update_swa

44: 6.35, 10.6, 9.85

4. Add val.csv data to train.csv and leave val.csv intact, use train.csv for update_swa

This is relevant to #19

42: 6.49, 6.96, 9.39
43: 6.58, 6.98, 9.38
44: 6.68, 6.89, 9.64

5. Add val.csv data to train.csv and drop all but 1 datapoint from val.csv, use train.csv for update_swa

This is a check that val.csv isn't affecting the test score in ways other than what I've described here. Since the train and test results match those immediately above, I think SWA is the only place where val.csv affects the training process.

42: 6.49, 27.3, 9.39
43: 6.58, 26, 9.38
44: 6.68, 22.1, 9.64

Takeaways

  • The validation data affects the training process through SWA.
  • The average test MAE is somewhat lower for (4) than for (1) (9.47 vs. 9.80, averaging the three seeds listed above), but the difference is marginal enough that it might be the opposite for other datasets. In general, I would think the current setup (1) is probably more robust to extrapolation. Adding the 25% validation data back into the training data in the matbench benchmark may not necessarily improve model performance (see #19).
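The per-experiment averages can be recomputed directly from the test MAEs listed above. The dictionary keys below are informal labels, not names from the repository.

```python
from statistics import mean

# Test MAEs for seeds 42, 43, 44, copied from the experiments above.
test_mae = {
    "(1) validation MAE for update_swa": [9.82, 9.63, 9.95],
    "(2) shuffled validation predictions": [11.4, 10.7, 11.4],
    "(4) val.csv merged into train.csv": [9.39, 9.38, 9.64],
}

mean_test_mae = {name: mean(values) for name, values in test_mae.items()}

for name, value in mean_test_mae.items():
    print(f"{name}: mean test MAE = {value:.2f}")
```

With only three seeds per experiment, these means are a rough comparison rather than a statistically robust one.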
