Comments (4)
from crabnet.
Good point. I think I've been a bit confused as to whether sometimes people use "validation data" to adjust hyperparameters of the model (in some sense I think we all do this, but maybe a bit more manually as we run models while debugging/developing), hence the idea of a third "test" set. Just wanted to make sure. I agree with your definition. Thanks Anthony!
Ok, I think my suspicion is confirmed: validation data is used to improve CrabNet's training results. I had been looking for where the validation set might be used during the training process, and I found a spot in `train` that appears to use the validation data to improve generalization.
Lines 115 to 119 in e884482
Near the end of training (I'm not sure why this happens on the second-to-last epoch instead of the last), it goes into L273:
Lines 272 to 273 in e884482
See PyTorch 1.6 now includes Stochastic Weight Averaging (blog)
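To make the mechanism concrete, here is a minimal sketch of a validation-gated SWA update using `torch.optim.swa_utils`. This is not CrabNet's actual code; the model, data, and `best_mae` bookkeeping are all hypothetical stand-ins, and only the gating pattern itself is the point:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

# Hypothetical tiny model standing in for CrabNet's network.
model = nn.Linear(4, 1)
swa_model = AveragedModel(model)

def mae(pred, target):
    return (pred - target).abs().mean().item()

# Hypothetical validation split.
X_val = torch.randn(32, 4)
y_val = torch.randn(32, 1)

best_mae = float("inf")
for epoch in range(5):
    # ... a training step on the training set would go here ...
    with torch.no_grad():
        val_mae = mae(model(X_val), y_val)
    # Fold the current weights into the running average only when the
    # validation MAE improves -- this is how validation data can leak
    # into the final (averaged) weights.
    if val_mae < best_mae:
        best_mae = val_mae
        swa_model.update_parameters(model)
```

Gating on validation MAE like this means the averaged weights are a function of the validation set, even though no gradients flow through it.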
Recently, I ran into an error where `pred_v` was all `np.nan` in one of the crabnet-hyperparameter combinations, and decided to check the size of `pred_v` using the original repo and `example_materials_property`. `pred_v` turned out to be the same size as the validation data (727), compared to the training `batch_size==256`, and was consistent across consecutive jumps to the same breakpoint. The total number of training data points is 3433 in this case.
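The same check can be scripted as a simple length comparison against each split. The names and array below are illustrative stand-ins (not CrabNet's variables), using the sizes quoted above:

```python
import numpy as np

# Illustrative sizes matching the numbers above; pred_v is a
# hypothetical stand-in for the array observed at the breakpoint.
n_train, n_val, batch_size = 3433, 727, 256
pred_v = np.zeros(n_val)

# If pred_v matches the validation-set size rather than the batch size,
# the predictions being scored are over the validation data.
assert len(pred_v) == n_val and len(pred_v) != batch_size

# The all-np.nan failure mode mentioned above can be caught the same way:
assert not np.isnan(pred_v).any()
```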
Empirical tests
Empirical tests also seem to confirm this (similar to my question about using "dummy" validation data). `42`, `43`, and `44` refer to the torch seeds used. The three numbers after the colon (`:`) are the train/val/test MAEs.
1. Using validation data to calculate `mae` for `update_swa`
seed: train MAE, val MAE, test MAE
42: 6.53, 10.7, 9.82
43: 6.38, 10.3, 9.63
44: 6.49, 10.5, 9.95
2. Shuffling validation predictions just prior to `swa_update`
Epoch 39 failed to improve.
Discarded: 1/3 weight updates ♻🗑️
42: 9.77, 11.9, 11.4
43: 9.19, 11.5, 10.7
44: 9.77, 11.9, 11.4
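The shuffle in (2) is a clean control: it preserves the distribution of predictions but destroys their pairing with the targets, so the MAE fed to the SWA gate inflates toward that of a random pairing and most updates get discarded. A self-contained illustration with synthetic data (not CrabNet's arrays):

```python
import numpy as np

rng = np.random.default_rng(42)
target = rng.normal(size=727)
pred = target + rng.normal(scale=0.1, size=727)  # well-correlated predictions

# Correctly paired MAE vs. MAE after shuffling away the pairing.
mae_paired = np.abs(pred - target).mean()
mae_shuffled = np.abs(rng.permutation(pred) - target).mean()

# Destroying the pairing inflates the MAE, so updates gated on it
# would mostly be rejected -- consistent with the degraded results in (2).
assert mae_shuffled > mae_paired
```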
3. Using training data to calculate `mae` for `update_swa`
44: 6.35, 10.6, 9.85
4. Adding `val.csv` data to `train.csv`, leaving `val.csv` intact, and using `train.csv` for `update_swa`
This is relevant to #19.
42: 6.49, 6.96, 9.39
43: 6.58, 6.98, 9.38
44: 6.68, 6.89, 9.64
5. Adding `val.csv` data to `train.csv`, dropping all but one datapoint from `val.csv`, and using `train.csv` for `update_swa`
This is just to make sure that `val.csv` isn't affecting the test score in ways other than what I've described here. Since the results match those immediately above (4), I think SWA is the only place where `val.csv` affects the training process.
42: 6.49, 27.3, 9.39
43: 6.58, 26, 9.38
44: 6.68, 22.1, 9.64
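The data manipulation behind (4) and (5) amounts to a couple of pandas operations. The frames below are synthetic stand-ins for `train.csv`/`val.csv`, and the `formula`/`target` columns are an assumption about the file layout:

```python
import pandas as pd

# Synthetic stand-ins for train.csv / val.csv (columns assumed).
train = pd.DataFrame({"formula": ["Al2O3", "SiO2"], "target": [10.0, 9.0]})
val = pd.DataFrame({"formula": ["TiO2"], "target": [8.5]})

# (4): fold the validation rows into training; val.csv is left intact.
train_plus_val = pd.concat([train, val], ignore_index=True)

# (5): additionally shrink the validation set to a single row so it
# cannot meaningfully steer the SWA updates.
val_stub = val.iloc[:1].reset_index(drop=True)

print(len(train_plus_val), len(val_stub))  # 3 1
```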
Takeaways
- The validation data affects the training process through SWA
- The average test MAE is somewhat lower for (4) than for (1) (9.74 vs. 9.80), but the difference is marginal enough that it might go the other way for other datasets. In general, I would think the current setup (1) is probably more robust to extrapolation. Adding the 25% validation data back into the training data in the matbench benchmark may not necessarily improve model performance (#19).