
ssd-pytorch's Introduction

SSD: Single-Shot MultiBox Detector object detection model, implemented in PyTorch


Contents

  1. Repository updates (Top News)
  2. Performance
  3. Environment
  4. Downloads
  5. Training Steps (How2train)
  6. Prediction Steps (How2predict)
  7. Evaluation Steps (How2eval)
  8. Reference

Top News

2022-03: Major update: added step and cosine learning-rate schedules, a choice of Adam or SGD optimizers, automatic learning-rate scaling based on batch_size, and image cropping.
The original repository referenced in the BiliBili videos is available at: https://github.com/bubbliiiing/ssd-pytorch/tree/bilibili

2021-10: Major update: added a MobileNetV2 backbone option, extensive comments, many adjustable parameters, a reorganization of the code modules, and FPS measurement, video prediction, and batch prediction features.

Performance

Training dataset    Weights file                   Test dataset    Input size    mAP 0.5:0.95    mAP 0.5
VOC07+12            ssd_weights.pth                VOC-Test07      300x300       -               78.55
VOC07+12            mobilenetv2_ssd_weights.pth    VOC-Test07      300x300       -               71.32

Environment

torch == 1.2.0

Downloads

The ssd_weights.pth file needed for training and the backbone weights can be downloaded from Baidu Netdisk.
Link: https://pan.baidu.com/s/1iUVE50oLkzqhtZbUL9el9w
Extraction code: jgn8

The VOC dataset can be downloaded from the link below. It already contains the training, test, and validation sets (the validation set is identical to the test set), so no further splitting is required:
Link: https://pan.baidu.com/s/1-1Ej6dayrx3g0iAA88uY5A
Extraction code: ph32

Training Steps

a. Training on the VOC07+12 dataset

  1. Dataset preparation
    This project trains on data in the VOC format. Before training, download the VOC07+12 dataset, extract it, and place it in the repository root directory.

  2. Dataset processing
    Set annotation_mode=2 in voc_annotation.py, then run voc_annotation.py to generate 2007_train.txt and 2007_val.txt in the root directory.

  3. Start training
    The default parameters of train.py are set up for the VOC dataset, so training can be started simply by running train.py.

  4. Prediction with the trained weights
    Prediction uses two files, ssd.py and predict.py. First modify model_path and classes_path in ssd.py; both parameters must be changed.
    model_path points to the trained weights file in the logs folder.
    classes_path points to the txt file listing the detection classes.

    Once these changes are made, run predict.py to detect objects. When prompted, enter the path of an image.

b. Training on your own dataset

  1. Dataset preparation
    This project trains on data in the VOC format, so prepare your own dataset accordingly before training.
    Before training, place the annotation files in the Annotation folder under VOCdevkit/VOC2007.
    Before training, place the image files in the JPEGImages folder under VOCdevkit/VOC2007.

  2. Dataset processing
    Once the dataset is in place, use voc_annotation.py to generate the 2007_train.txt and 2007_val.txt files used for training.
    Modify the parameters in voc_annotation.py. For a first training run you only need to change classes_path, which points to the txt file listing the detection classes.
    When training on your own dataset, create your own cls_classes.txt listing the classes you want to distinguish.
    The contents of model_data/cls_classes.txt would then be, for example:

cat
dog
...

Set classes_path in voc_annotation.py so that it points to cls_classes.txt, then run voc_annotation.py; an illustrative excerpt follows below.
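
For illustration only (assuming classes_path is a plain variable near the top of voc_annotation.py, as the step above implies), the change amounts to:

#-----------------------------------------------------#
#   voc_annotation.py (illustrative excerpt)
#   Point classes_path at the class list for your own dataset.
#-----------------------------------------------------#
classes_path = 'model_data/cls_classes.txt'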

  3. Start training
    There are quite a few training parameters, all in train.py; read the comments carefully after downloading the repository. The most important one is again classes_path in train.py.
    classes_path points to the txt file listing the detection classes, the same txt used by voc_annotation.py. It must be modified when training on your own dataset!
    After modifying classes_path, run train.py to start training; after several epochs the weights are written to the logs folder.

  4. Prediction with the trained weights
    Prediction uses two files, ssd.py and predict.py. Modify model_path and classes_path in ssd.py.
    model_path points to the trained weights file in the logs folder.
    classes_path points to the txt file listing the detection classes.

    Once these changes are made, run predict.py to detect objects. When prompted, enter the path of an image.

Prediction Steps

a. Using pretrained weights

  1. After downloading and extracting the repository, download the weights from Baidu Netdisk, place them in model_data, run predict.py, and enter
img/street.jpg
  2. Settings inside predict.py also enable FPS testing and video detection.

b. Using your own trained weights

  1. Train as described in the training steps.
  2. In ssd.py, modify model_path and classes_path in the block below so that they correspond to your trained files; model_path points to the weights file in the logs folder, and classes_path lists the classes that model_path was trained to detect.
_defaults = {
    #--------------------------------------------------------------------------#
    #   To predict with your own trained model you MUST change model_path and classes_path!
    #   model_path points to the weights file in the logs folder,
    #   classes_path points to the txt under model_data.
    #   If a shape mismatch occurs, also check the model_path and classes_path used during training.
    #--------------------------------------------------------------------------#
    "model_path"        : 'model_data/ssd_weights.pth',
    "classes_path"      : 'model_data/voc_classes.txt',
    #---------------------------------------------------------------------#
    #   Input image size for prediction; use the same size as during training.
    #---------------------------------------------------------------------#
    "input_shape"       : [300, 300],
    #-------------------------------#
    #   Backbone choice:
    #   vgg or mobilenetv2
    #-------------------------------#
    "backbone"          : "vgg",
    #---------------------------------------------------------------------#
    #   Only predicted boxes with a score above this confidence are kept.
    #---------------------------------------------------------------------#
    "confidence"        : 0.5,
    #---------------------------------------------------------------------#
    #   IoU threshold used for non-maximum suppression.
    #---------------------------------------------------------------------#
    "nms_iou"           : 0.45,
    #---------------------------------------------------------------------#
    #   Sizes of the prior (anchor) boxes.
    #---------------------------------------------------------------------#
    'anchors_size'      : [30, 60, 111, 162, 213, 264, 315],
    #---------------------------------------------------------------------#
    #   Whether to use letterbox_image to resize the input without distortion.
    #   Repeated tests suggest that a plain resize (letterbox_image disabled) works better here.
    #---------------------------------------------------------------------#
    "letterbox_image"   : False,
    #-------------------------------#
    #   Whether to use CUDA;
    #   set to False if you have no GPU.
    #-------------------------------#
    "cuda"              : True,
}
  3. Run predict.py and enter
img/street.jpg
  4. Settings inside predict.py also enable FPS testing and video detection; a minimal inference sketch follows below.
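
For orientation only, here is a minimal prediction sketch. It assumes that ssd.py exposes an SSD class whose detect_image method takes a PIL image and returns the image with boxes drawn; check predict.py in the repository for the exact interface.

# Minimal single-image prediction sketch (hypothetical usage; see predict.py for the real flow).
from PIL import Image

from ssd import SSD   # the wrapper class configured by the _defaults dict above

ssd    = SSD()
image  = Image.open("img/street.jpg")   # the path entered at the predict.py prompt
result = ssd.detect_image(image)        # assumed interface: PIL image in, annotated PIL image out
result.show()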

Evaluation Steps

a. Evaluating the VOC07+12 test set

  1. This project evaluates in the VOC format. VOC07+12 already includes a test split, so there is no need to use voc_annotation.py to generate the txt files under ImageSets.
  2. In ssd.py, modify model_path and classes_path. model_path points to the trained weights file in the logs folder; classes_path points to the txt file listing the detection classes.
  3. Run get_map.py to obtain the evaluation results, which are saved in the map_out folder.

b. Evaluating your own dataset

  1. This project evaluates in the VOC format.
  2. If voc_annotation.py was already run before training, the code automatically splits the dataset into training, validation, and test sets. To change the proportion of the test set, modify trainval_percent in voc_annotation.py. trainval_percent sets the ratio of (training set + validation set) to test set; by default (training + validation) : test = 9 : 1. train_percent sets the ratio of training set to validation set within (training set + validation set); by default training : validation = 9 : 1 (a worked example follows after this list).
  3. After splitting off the test set with voc_annotation.py, modify classes_path in get_map.py; it points to the txt listing the detection classes, the same txt used for training, and it must be changed when evaluating your own dataset.
  4. In ssd.py, modify model_path and classes_path. model_path points to the trained weights file in the logs folder; classes_path points to the txt file listing the detection classes.
  5. Run get_map.py to obtain the evaluation results, which are saved in the map_out folder.
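
As a worked example of the default split ratios described above (illustrative only; the exact rounding inside voc_annotation.py may differ slightly):

# Worked example of the default dataset split (illustrative only).
num_images       = 1000
trainval_percent = 0.9   # (train + val) : test = 9 : 1
train_percent    = 0.9   # train : val = 9 : 1 within (train + val)

num_trainval = int(num_images * trainval_percent)   # 900 images for train + val
num_test     = num_images - num_trainval            # 100 images for test
num_train    = int(num_trainval * train_percent)    # 810 images for train
num_val      = num_trainval - num_train             # 90 images for val
print(num_train, num_val, num_test)                 # 810 90 100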

Reference

https://github.com/pierluigiferrari/ssd_keras
https://github.com/kuhung/SSD_keras


ssd-pytorch's Issues

Error when training on multiple GPUs

Traceback (most recent call last):
  File "train.py", line 107, in <module>
    out = net(images)
  File "/home/walker2/anaconda3/envs/pytorch1.2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/walker2/anaconda3/envs/pytorch1.2/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
    return self.gather(outputs, self.output_device)
  File "/home/walker2/anaconda3/envs/pytorch1.2/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/walker2/anaconda3/envs/pytorch1.2/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/home/walker2/anaconda3/envs/pytorch1.2/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/walker2/anaconda3/envs/pytorch1.2/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/walker2/anaconda3/envs/pytorch1.2/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in forward
    assert all(map(lambda i: i.is_cuda, inputs))
AssertionError

Adding the following to train.py
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
and training on a single GPU works fine.

bug

Hi, while reading your code I noticed a small problem that can cause an infinite loop.
In the __getitem__ method of SSDDataset in the dataloader file, the while True loop is presumably meant to keep searching until a valid item is found, but if the returned label array y has length 0, index is never changed, so the loop never terminates.
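
A minimal sketch of the fix this report implies (variable names follow the snippet quoted in a later issue below): advance index before retrying, so the loop can eventually terminate.

# Illustrative fragment for SSDDataset.__getitem__ (not the repository's exact code).
while True:
    img, y = self.get_random_data(lines[index], self.image_size[0:2])
    if len(y) == 0:
        # No valid boxes for this sample: try the next index instead of looping on the same one.
        index = (index + 1) % n
        continue
    break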

How the numbers in 'steps': [8, 16, 32, 64, 100, 300] are computed

Hello, I have a question after watching your Bilibili video explanation.
'feature_maps': [38, 19, 10, 5, 3, 1],
'steps': [8, 16, 32, 64, 100, 300],
How are these step values computed? According to the video it is roughly 300/8 ≈ 38, which fits the first two values (8 and 16), but the third value 32 does not fit: if it were 30, then 300/30 = 10. Why is it 32 rather than 30?
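
One consistent way to read these numbers (an observation, not an official answer): the steps are the effective downsampling strides of the corresponding layers, which come in powers of two for the convolutional stages (8, 16, 32, 64), and the feature-map sizes follow by rounding 300/step up, rather than the steps being computed from the feature maps:

import math

steps        = [8, 16, 32, 64, 100, 300]
feature_maps = [38, 19, 10, 5, 3, 1]

# ceil(300 / step) reproduces every feature-map size, including 300 / 32 = 9.375 -> 10.
print([math.ceil(300 / s) for s in steps])                    # [38, 19, 10, 5, 3, 1]
print([math.ceil(300 / s) for s in steps] == feature_maps)    # True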

mAP computation and visualization

Hello, while re-implementing get_dr_txt.py I found that for some images the detector finds no objects at all (everything is background), so no corresponding txt file is generated and get_map.py then raises an error. Have you run into this problem?
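
A common workaround, sketched here under the assumption that get_map.py expects one detection-results txt per image (the folder layout below is the usual convention and is not verified against this repository): create the txt unconditionally, even when it ends up empty.

# Illustrative per-image loop body for a get_dr_txt.py-style script.
import os

os.makedirs("map_out/detection-results", exist_ok=True)   # assumed output layout
image_id   = "000001"   # hypothetical image id
detections = []         # whatever the detector returned; may be empty

# Open the file unconditionally so images with zero detections still produce a (blank) txt.
with open(f"map_out/detection-results/{image_id}.txt", "w") as f:
    for class_name, score, left, top, right, bottom in detections:
        f.write(f"{class_name} {score:.6f} {left} {top} {right} {bottom}\n")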

Training without pretrained weights

I simply commented out the code that loads the pretrained weights, but the mAP of the resulting model is very low, and training for more epochs does not help. How should a model be trained from scratch without pretrained weights?

When running get_dr_txt.py I get the following error; the images I use are 320*320. How can this be fixed? Thanks!

Traceback (most recent call last):

  File "C:\ssd-pytorch-master\get_dr_txt.py", line 77, in <module>
    ssd.detect_image(image_id, image)

  File "C:\ssd-pytorch-master\get_dr_txt.py", line 25, in detect_image
    preds = self.net(photo)

  File "C:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)

  File "C:\ssd-pytorch-master\nets\ssd.py", line 69, in forward
    self.priors

  File "C:\ssd-pytorch-master\nets\ssd_layers.py", line 32, in forward
    self.num_classes).transpose(2, 1)

RuntimeError: shape '[1, 8732, 2]' is invalid for input of size 19180
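
For what it is worth, the numbers in the error are consistent with an input-size mismatch: the target shape [1, 8732, 2] assumes the 8732 priors of a 300x300 input, while a 320x320 input yields larger feature maps and therefore more priors (9590, and 9590 x 2 classes = 19180). A quick arithmetic check:

# Prior-box counts for a 300x300 versus a 320x320 input (arithmetic check only).
boxes_per_cell = [4, 6, 6, 6, 4, 4]

maps_300 = [38, 19, 10, 5, 3, 1]
maps_320 = [40, 20, 10, 5, 3, 1]   # roughly what the same backbone produces at 320x320

priors_300 = sum(f * f * b for f, b in zip(maps_300, boxes_per_cell))
priors_320 = sum(f * f * b for f, b in zip(maps_320, boxes_per_cell))

print(priors_300)       # 8732
print(priors_320)       # 9590
print(priors_320 * 2)   # 19180, the "input of size 19180" in the error (2 classes here)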

After putting the anchor boxes generated by k-means clustering into the YOLO anchors file and training, why does the mAP drop to only about 1%?

Is the problem how the coordinates are laid out? I inserted them row by row.
The resulting yolo_anchors file contains: 8, 11, 12, 17, 12, 48, 27, 30, 27, 12, 47, 51, 65, 16, 111, 33, 226, 99
acc: 69.40%
[[  6.95440729  10.14634146]
 [  8.21884498  21.98373984]
 [ 10.11550152  11.83739837]
 [ 13.27659574  17.75609756]
 [ 18.3343465   33.82113821]
 [ 41.09422492  46.50406504]
 [ 47.41641337  15.2195122 ]
 [107.47720365  32.97560976]
 [226.3343465   99.77235772]]

ValueError: invalid literal for int() with base 10

Traceback (most recent call last):
  File "E:/python/Semantic_Segmentation/ssd/voc_annotation.py", line 30, in <module>
    convert_annotation(year, image_id, list_file)
  File "E:/python/Semantic_Segmentation/ssd/voc_annotation.py", line 20, in convert_annotation
    b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text))
ValueError: invalid literal for int() with base 10: '45.70000076293945'

This error occurs when using the VOC2012 dataset. What is causing it?
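
The error message points at the cause: some VOC2012 annotations store box coordinates as floats (e.g. '45.70000076293945'), which int() cannot parse directly. A common fix, sketched here, is to parse through float first:

# Sketch of a tolerant coordinate parse for convert_annotation (illustrative).
def parse_coord(text):
    # int('45.7...') raises ValueError, so go through float first.
    return int(float(text))

b = (parse_coord(xmlbox.find('xmin').text), parse_coord(xmlbox.find('ymin').text),
     parse_coord(xmlbox.find('xmax').text), parse_coord(xmlbox.find('ymax').text))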

train.py raises an error when training on my own dataset; asking for help

Loading weights into state dict...
Finished!
Traceback (most recent call last):
  File "/data/iiot-data/caixh/ssd-pytorch/train.py", line 108, in <module>
    out = net(images)
  File "/home/amax/anaconda3/envs/iiot/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/amax/anaconda3/envs/iiot/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
    return self.gather(outputs, self.output_device)
  File "/home/amax/anaconda3/envs/iiot/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 165, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/amax/anaconda3/envs/iiot/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/home/amax/anaconda3/envs/iiot/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/amax/anaconda3/envs/iiot/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/amax/anaconda3/envs/iiot/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 54, in forward
    assert all(map(lambda i: i.is_cuda, inputs))
AssertionError
How can this bug be fixed? (The traceback is the same multi-GPU gather assertion as in the first issue above.)

Training image size

Besides the (300, 300, 3) input size, is it possible to train on (460, 460, 3) images?

A small suggestion

Please consider not putting the dataset inside the repository; it makes the download very large and genuinely hard to complete. Thank you.

A few questions

Are SSD and SSDLite the same model, differing only in the backbone? Also, what do the numbers in SSD300 and SSDLite320 mean?

After reproducing the training, the mAP does not reach the reported value

I have been following your Weibo posts and videos and reproduced your SSD-pytorch, training on the VOC2012 dataset, but I cannot match your model's results even though I changed no parameters. At Epoch=150 the loss is around 1.4 + 0.25. Which parameters need to be tuned? I would sincerely appreciate your guidance. Thanks!

RuntimeError: Legacy autograd function with non-static forward method is deprecated.

Environment

torch >= 1.3

predict.py raises an error:

RuntimeError: Legacy autograd function with non-static forward method is deprecated.

In PyTorch 1.3 and later, the forward method of an autograd Function must be defined as a static method; modifying the following snippet makes it work.
Before:

if self.phase == "test":
  output = self.detect.forward(
      loc.view(loc.size(0), -1, 4),
      self.softmax(conf.view(conf.size(0), -1, self.num_classes)),
      self.priors              
  )

After:

if self.phase == "test":
    output = self.detect.forward(
        loc.view(loc.size(0), -1, 4),
        self.softmax(conf.view(conf.size(0), -1, self.num_classes)),
        self.priors              
    )
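
The fix the poster describes amounts to turning the Detect layer into a new-style autograd Function with a static forward and calling it through .apply. A generic sketch of that pattern is below; the class name, arguments, and body are illustrative, not this repository's actual Detect implementation.

import torch

# Generic sketch of the static-forward style required by PyTorch >= 1.3 (names illustrative).
class DetectFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loc_data, conf_data, prior_data):
        # Real code would decode boxes and run NMS here; this placeholder just concatenates.
        return torch.cat([loc_data, conf_data], dim=-1)

# Old (deprecated): instantiate the Function and call its forward method.
# New: call the Function through .apply without instantiating it.
loc    = torch.zeros(1, 8732, 4)
conf   = torch.zeros(1, 8732, 21)
priors = torch.zeros(8732, 4)
output = DetectFn.apply(loc, conf, priors)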

Prediction error

self.priors.type(type(x.data))  # default boxes

TypeError: forward() takes 4 positional arguments but 9 were given

Error with predict.py

Hello, when training on my own dataset the training itself runs fine, but when I use predict.py for prediction, the whole image is covered with predicted boxes and class labels. How can this be fixed?

train.py hangs after the weights finish downloading

Yesterday I tried to run this SSD, but after launching train.py nothing happens. I tried two machines in our lab with exactly the same result, including an Nvidia AGX with 32 GB of GPU memory, so I am sure it is not the hardware. Other networks train fine; only this SSD will not start training.

Error when running predict.py

RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.

get_random_data applied repeatedly?

The get_random_data function in utils/dataloader and the get_random_data function in nets/ssd_training might perform data augmentation twice. I think passing the parameter "random = False" at line 261 of nets/ssd_training would avoid the repetition.

About the fit_one_epoch function

The first two parameters of fit_one_epoch are model_train and model.

In train.py, model_train = model.train(), so at that point I understand model and model_train to be the same object.
But when multiple GPUs are used, model_train = torch.nn.DataParallel(model), and then they are no longer the same, right?

During training it is the parameters of model_train that get updated, yet in the end it is model that gets saved. What is the purpose of doing it this way?
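
A short note on why this works (general PyTorch behaviour, not specific to this repository): nn.DataParallel only wraps the original module, so the two objects share the same parameter tensors, and saving the unwrapped model also keeps the state-dict keys free of the "module." prefix:

import torch
import torch.nn as nn

model       = nn.Linear(4, 2)
model_train = nn.DataParallel(model)   # wraps model; model_train.module is model

# The wrapper and the original share the exact same parameter tensors,
# so updating model_train during training also updates model.
print(model_train.module is model)                                  # True
print(next(model_train.parameters()) is next(model.parameters()))   # True

# Saving the unwrapped model keeps clean keys ("weight") instead of "module.weight".
print(list(model.state_dict().keys()))         # ['weight', 'bias']
print(list(model_train.state_dict().keys()))   # ['module.weight', 'module.bias']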

Out of GPU memory

RuntimeError: CUDA out of memory. Tried to allocate 1.62 GiB (GPU 0; 6.00 GiB total capacity; 921.75 MiB already allocated; 3.54 GiB free; 4.25 MiB cached)

It needs 1.62 GiB and there are supposedly 3.54 GiB free, yet it still fails. Reducing the batch size to 1 does not help either, and none of the blog posts I found solve it. What is going on?

What is ssd_weights.pth for?

I modified part of the network structure and added a few layers, and now I get an error. What exactly is ssd_weights.pth? Can this weights file still be used after layers have been added to the network? If not, how do I train an ssd_weights.pth of my own, or train without it?

Prediction error

Hi, I ran into the problem shown below during prediction. What is the specific cause?
(The error screenshot from the original issue is not reproduced here.)

What does x = self.vgg[k](x) mean in the forward pass of nets/ssd.py?

def forward(self, x):
    #---------------------------#
    #   x is 300, 300, 3
    #---------------------------#
    sources = list()
    loc     = list()
    conf    = list()

    #---------------------------#
    #   Get the output of conv4_3,
    #   whose shape is 38, 38, 512
    #---------------------------#
    if self.backbone_name == "vgg":
        for k in range(23):
            x = self.vgg[k](x)

What does the last line here, x = self.vgg[k](x), mean?
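
For context, a general reading of this pattern (assuming self.vgg is an nn.ModuleList of the VGG layers, as in typical SSD implementations): indexing the list with k selects the k-th layer, and calling it applies that layer to x, so the loop pushes x through the first 23 layers one by one and stops at the conv4_3 output.

import torch
import torch.nn as nn

# Minimal sketch: an nn.ModuleList is a list of layers applied by indexing.
layers = nn.ModuleList([
    nn.Conv2d(3, 8, 3, padding=1),   # stand-ins for the first VGG layers
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 8, 3, padding=1),
    nn.ReLU(inplace=True),
])

x = torch.randn(1, 3, 300, 300)
for k in range(len(layers)):
    x = layers[k](x)   # same idea as x = self.vgg[k](x): apply layer k to the running tensor
print(x.shape)         # torch.Size([1, 8, 300, 300])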

Error when running train.py from your ssd-pytorch code

Error message:
Traceback (most recent call last):
  File "D:/python_code/ssd网络/ssd-pytorch-master/train.py", line 111, in <module>
    images = batch[0]
TypeError: 'ExceptionWrapper' object does not support indexing

Process finished with exit code 1

The relevant batch code:
for iteration, batch in enumerate(gen):

if Use_Data_Loader:
    train_dataset = SSDDataset(lines[:num_train], (Config["min_dim"], Config["min_dim"]))
    gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=8, pin_memory=True,
                     drop_last=True, collate_fn=ssd_dataset_collate)

Training enters an infinite loop and never starts

This project has a serious bug, the same one reported in #15: at line 106 of dataloader.py I added a print statement and it loops forever. Please fix it, author.
img, y = self.get_random_data(lines[index], self.image_size[0:2])
if len(y) == 0:
    print("--------------------------")
    index = (index + 1) % n
    continue

Error when I change to 6 classes

Traceback (most recent call last):
  File "D:/python project/study/CNN/ssd-pytorch-flower/train.py", line 37, in <module>
    model.load_state_dict(pretrained_dict)
  File "D:\工作\python3.6\lib\site-packages\torch\nn\modules\module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SSD:
    size mismatch for conf.0.weight: copying a param with shape torch.Size([84, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([28, 512, 3, 3]).
    size mismatch for conf.0.bias: copying a param with shape torch.Size([84]) from checkpoint, the shape in current model is torch.Size([28]).
    size mismatch for conf.1.weight: copying a param with shape torch.Size([126, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([42, 1024, 3, 3]).
    size mismatch for conf.1.bias: copying a param with shape torch.Size([126]) from checkpoint, the shape in current model is torch.Size([42]).
    size mismatch for conf.2.weight: copying a param with shape torch.Size([126, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([42, 512, 3, 3]).
    size mismatch for conf.2.bias: copying a param with shape torch.Size([126]) from checkpoint, the shape in current model is torch.Size([42]).
    size mismatch for conf.3.weight: copying a param with shape torch.Size([126, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([42, 256, 3, 3]).
    size mismatch for conf.3.bias: copying a param with shape torch.Size([126]) from checkpoint, the shape in current model is torch.Size([42]).
    size mismatch for conf.4.weight: copying a param with shape torch.Size([84, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([28, 256, 3, 3]).
    size mismatch for conf.4.bias: copying a param with shape torch.Size([84]) from checkpoint, the shape in current model is torch.Size([28]).
    size mismatch for conf.5.weight: copying a param with shape torch.Size([84, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([28, 256, 3, 3]).
    size mismatch for conf.5.bias: copying a param with shape torch.Size([84]) from checkpoint, the shape in current model is torch.Size([28]).
When I switch to a 6-class problem I get the error above. I have already changed the class settings in config.py, voc_annotation.py, and voc_classes.txt, but the error persists. Is the problem with ssd_weights.pth, and how should I fix it?
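
A note on the numbers, plus a common workaround (a sketch, not necessarily the author's intended fix): the checkpoint was trained on VOC with 21 classes (20 classes + background), so its conf heads have 4 x 21 = 84 or 6 x 21 = 126 output channels, while a 6-class setup (6 + background = 7) needs 4 x 7 = 28 or 6 x 7 = 42. One standard way to keep the pretrained backbone while skipping the mismatched heads is to load only the checkpoint tensors whose names and shapes match:

import torch

def load_matching_weights(model, checkpoint_path):
    # Load only the checkpoint tensors whose names and shapes match the given model.
    # A common workaround when the class count (and so the conf head shapes) has changed.
    pretrained_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict      = model.state_dict()
    compatible      = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    print(f"loaded {len(compatible)}/{len(model_dict)} tensors from {checkpoint_path}")

# Hypothetical usage: load_matching_weights(model, 'model_data/ssd_weights.pth')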
