Contributors: Ang Li, Fenghao Yang, Yikai Mao, Jinghao Miao
First, run `get_data.sh` to download the Food-101 dataset.
To run the code, run `main.py`. The following hyperparameter arguments can be passed on the command line:

- `device_id` (int, default=0): ID of the GPU to use
- `model` (str, default='custom'): model to use; one of `baseline`, `custom`, `resnet18`, `vgg16`
- `pt_ft` (int, default=1): whether the model is partially fine-tuned
- `bz` (int, default=32): batch size
- `shuffle_data` (int, default=1): whether to shuffle the data
- `normalization_mean` (tuple, default=(0.485, 0.456, 0.406)): per-channel mean for z-score normalization of the images
- `normalization_std` (tuple, default=(0.229, 0.224, 0.225)): per-channel standard deviation for z-score normalization of the images
- `epoch` (int, default=30): number of epochs
- `criterion` (str, default='cross_entropy'): loss function to use
- `optimizer` (str, default='adam'): optimizer to use
- `lr` (float, default=1e-4): learning rate
- `weight_decay` (float, default=1e-4): weight decay
- `lr_scheduling` (int, default=0): whether to enable learning rate scheduling
- `lr_scheduler` (str, default='steplr'): learning rate scheduler
- `step_size` (int, default=7): period of learning rate decay
- `gamma` (float, default=0.1): multiplicative factor of learning rate decay
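The arguments above can be parsed with `argparse`; the sketch below is a hypothetical reconstruction of how such a parser might look (the actual parser in `main.py` may name or type things differently, e.g. the tuple arguments are approximated here with `nargs=3`):

```python
import argparse

def build_parser():
    # Hypothetical sketch of the hyperparameter flags listed above;
    # defaults mirror the README, but main.py is the source of truth.
    p = argparse.ArgumentParser(description="Food-101 training")
    p.add_argument("--device_id", type=int, default=0)
    p.add_argument("--model", type=str, default="custom")
    p.add_argument("--pt_ft", type=int, default=1)
    p.add_argument("--bz", type=int, default=32)
    p.add_argument("--shuffle_data", type=int, default=1)
    p.add_argument("--normalization_mean", type=float, nargs=3,
                   default=(0.485, 0.456, 0.406))
    p.add_argument("--normalization_std", type=float, nargs=3,
                   default=(0.229, 0.224, 0.225))
    p.add_argument("--epoch", type=int, default=30)
    p.add_argument("--criterion", type=str, default="cross_entropy")
    p.add_argument("--optimizer", type=str, default="adam")
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--weight_decay", type=float, default=1e-4)
    p.add_argument("--lr_scheduling", type=int, default=0)
    p.add_argument("--lr_scheduler", type=str, default="steplr")
    p.add_argument("--step_size", type=int, default=7)
    p.add_argument("--gamma", type=float, default=0.1)
    return p

# Example: override the model and learning rate, keep other defaults.
args = build_parser().parse_args(["--model", "resnet18", "--lr", "1e-3"])
print(args.model, args.lr)  # prints: resnet18 0.001
```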
Running the code directly with

```
python3 main.py
```

will train and test the baseline model. The results are saved to `results.pkl` once training finishes and can then be loaded by `visualization.ipynb` to plot the loss/accuracy curves.
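If you want to inspect the saved results outside the notebook, they can be loaded with `pickle`. The exact structure of `results.pkl` depends on `engine.py`; the sketch below assumes a dict of per-epoch metric lists and writes a dummy file first so it is self-contained:

```python
import pickle

# Assumption: results.pkl holds a dict of per-epoch metric lists
# (keys here are illustrative; check engine.py for the real ones).
# A dummy file is written first so this snippet runs on its own.
dummy = {"train_loss": [2.1, 1.4, 0.9], "val_acc": [0.31, 0.48, 0.55]}
with open("results.pkl", "wb") as f:
    pickle.dump(dummy, f)

# Load the results and report the final value of each metric.
with open("results.pkl", "rb") as f:
    results = pickle.load(f)

for name, values in results.items():
    print(f"{name}: final = {values[-1]}")
```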
- `main.py`: entry point for the entire pipeline
- `prepare_data.py`: loads the dataset
- `data.py`: pre-processes the data, splits it into train, validation, and test sets, and creates the dataloaders
- `model.py`: implementations of the baseline, custom, ResNet-18, and VGG-16 models
- `engine.py`: prepares, trains, and tests the model, and saves the results to `results.pkl`
- `visualization.ipynb`: notebook to plot graphs and visualize weight maps and feature maps