This repository contains training, validation, and inference code for object detection using the [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) model. "YOLOX is an anchor-free version of YOLO, with a simpler design but better performance! It aims to bridge the gap between research and industrial communities. For more details, please refer to our report on Arxiv."
The PyTorch Lightning 2.x framework is used as a wrapper for the YOLOX model to enable reproducible training and validation, advanced metric logging, and overall flexibility and transparency during the training process. This repository is designed to be extensible, allowing object detection training on custom datasets with configurable network architectures.
Additional utilities in this repo:
- Ability to convert weights to ONNX, TensorRT, and OpenVINO format
Clone the repository into your working directory and install the dependencies.

```shell
git clone <repo_link>
cd <dir>
pip install -r requirements.txt  # or: pip install -v -e .
```
Download weights pretrained on COCO2017 for each YOLOX architecture from the original YOLOX repository and place them in `weights/pretrained_weights/`.
Verify that lightning, pytorch, and the other dependencies installed correctly by running:

```shell
python3 main.py fit --config config.yaml --print_config
```
YOLOX can be trained for custom object detection tasks if labeled data is provided.
Data should be separated into image and label directories. For example, the dataset should have the following generic structure:
```
image_dir/
    image1.png
    image2.png
    ...
label_dir/
    image1.txt
    image2.txt
    ...
```
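Each `imageN.txt` is expected to hold one bounding box per line. As a rough sketch, a parser for yolo-format labels (`class_id x_center y_center width height`, all coordinates normalized to [0, 1]) might look like this; the function name is illustrative, and your label files may use a different format entirely.

```python
# Minimal sketch of a yolo-format label parser; assumes each line is
# "class_id x_center y_center width height" with normalized coordinates.
# Your dataset's label format may differ (coco, pascal_voc, ...).
from pathlib import Path

def parse_yolo_labels(label_path):
    """Return a list of (class_id, x_c, y_c, w, h) tuples for one image."""
    boxes = []
    for line in Path(label_path).read_text().splitlines():
        parts = line.split()
        if len(parts) != 5:
            continue  # skip empty or malformed lines
        cls, xc, yc, w, h = parts
        boxes.append((int(cls), float(xc), float(yc), float(w), float(h)))
    return boxes
```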
ATTENTION: Make sure that the albumentations bounding box transformation matches the format of the bounding boxes within the label files (i.e. yolo, albumentations, coco, or pascal_voc). See this link for clarification.
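For reference, the four formats describe the same box in different coordinate conventions. The small sketch below converts one pascal_voc box to the other three (the function name is illustrative; the image size is needed for the normalized formats):

```python
# Sketch of the four bounding-box conventions albumentations supports,
# shown by converting one pascal_voc box. Function name is illustrative.
def pascal_voc_to_formats(x_min, y_min, x_max, y_max, img_w, img_h):
    """Return the same box expressed in each of the four conventions."""
    w, h = x_max - x_min, y_max - y_min
    return {
        "pascal_voc": (x_min, y_min, x_max, y_max),        # absolute corners
        "coco": (x_min, y_min, w, h),                      # absolute top-left + size
        "albumentations": (x_min / img_w, y_min / img_h,
                           x_max / img_w, y_max / img_h),  # normalized corners
        "yolo": ((x_min + w / 2) / img_w, (y_min + h / 2) / img_h,
                 w / img_w, h / img_h),                    # normalized center + size
    }
```

A box stored as `0.3 0.4 0.4 0.4` in a yolo-format label file would therefore need `bbox_params=A.BboxParams(format="yolo", ...)` in the augmentation pipeline.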
Data augmentations are defined in the constructor of the lightning data module. Change them as necessary here. They will be saved during training with the best set of weights and tensorboard metrics.
Read the tutorial for training on custom datasets here
This repo uses the Lightning CLI to maintain a high level of abstraction during model training and finetuning. All of the training hyperparameters are configurable from `config.yaml`. See the Lightning Docs for more information about the various hooks, flags, and callbacks available in Lightning.
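A Lightning CLI config is organized into `trainer`, `model`, and `data` sections. The fragment below is purely illustrative of that shape; the actual keys and values in this repo's `config.yaml` may differ.

```yaml
# Illustrative Lightning CLI config fragment; the specific keys and
# values in this repo's config.yaml may differ.
trainer:
  max_epochs: 100
  accelerator: gpu
  devices: 1
model:
  learning_rate: 0.001
data:
  batch_size: 16
  num_workers: 4
```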
Training can be started with this command:

```shell
python3 main.py fit --config config.yaml  # or: ./run_train.sh
```
```shell
python3 demo.py --mode dir --image_dir /path/to/images/ --ckpt /path/to/lightning/weights --conf 0.3 --nms 0.4
```
This command will do the following:
- Load the YOLOX weights specified by `ckpt`
- Iterate through all of the images in `image_dir`
- Perform bounding box/class prediction on each image
- Draw bounding boxes on each image if there is a detection within the confidence and nms thresholds
- Save the images with bounding boxes drawn in a folder called `vis_results/` in the same parent directory provided by `ckpt`
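The confidence and NMS thresholding step above can be sketched in plain Python as follows. This is an assumed illustration of standard greedy non-maximum suppression, not the repo's actual postprocessing code; detections are `(x_min, y_min, x_max, y_max, score)` tuples.

```python
# Pure-Python sketch of confidence filtering + greedy NMS; an assumed
# helper for illustration, not the repo's actual postprocessing.
def iou(a, b):
    """Intersection-over-union of two (x_min, y_min, x_max, y_max, ...) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter) if inter else 0.0

def filter_detections(dets, conf_thresh=0.3, nms_thresh=0.4):
    """Drop low-confidence boxes, then greedily suppress overlapping ones."""
    dets = sorted((d for d in dets if d[4] >= conf_thresh),
                  key=lambda d: d[4], reverse=True)
    keep = []
    for d in dets:
        if all(iou(d, k) < nms_thresh for k in keep):
            keep.append(d)
    return keep
```

The `--conf` and `--nms` flags shown in the command above play the roles of `conf_thresh` and `nms_thresh` here.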
After weights are converted to onnx, rename them with the following convention `<architecture>_<purpose>.onnx` and place them in `weights/` for evaluation.
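A small helper for building and splitting names under that convention might look like this. It is a hypothetical sketch (the function names and the example `yolox_s`/`coco` values are illustrative, not taken from the repo):

```python
# Hypothetical helpers for the <architecture>_<purpose>.onnx naming
# convention; function names and example values are illustrative.
from pathlib import Path

def onnx_weight_name(architecture: str, purpose: str) -> str:
    """Build a conventional weight filename, e.g. 'yolox_s_coco.onnx'."""
    return f"{architecture}_{purpose}.onnx"

def parse_onnx_weight_name(path):
    """Split a conventional filename back into (architecture, purpose)."""
    stem = Path(path).stem                     # drop the .onnx suffix
    architecture, _, purpose = stem.rpartition("_")
    if not architecture:
        raise ValueError(f"{path} does not follow <architecture>_<purpose>.onnx")
    return architecture, purpose
```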
- Log the loss functions during validation
- Compute and log average precision and mean average precision during training and validation
- Inference function (numpy input, numpy output with bounding boxes and labels drawn)
- Visualize data augmentations from within the lightning data module (should be specifiable from `config.yaml`)
- Ability to specify `yolox_custom` width/depth
- Ability to have warmup epochs right before training
- Move model creation outside of `lit_yolox.py`
- Moved data augmentation creation within DroneNetDataModule
- Implement `predict` subcommand in Lightning CLI
- Implement `BasePredictionWriter` as an alternative to using `demo.py`
- Verify onnx conversion and correct inference within the new Lightning YOLOX framework
- Completed README and tutorials for how to train with new data
- Implement YOLOX C++ postprocessing
- Reset optimizer/learning rate scheduler states after loading a checkpoint in Lightning CLI
- Mosaic augmentations?
- Further error checking as needed
- Add EMA capability
- Add distributed data parallel (DDP) training capability
- Refactor Dataset classes to remove unnecessary outputs from `__getitem__()`
The YOLOX architecture was created by its original authors. This Python wrapper repo was developed and maintained by Bassam Bikdash. For any inquiries or support, please contact Bassam Bikdash at [email protected].