
Vehicle-Car-detection-and-multilabel-classification: Vehicle Detection and Multi-Label Attribute Recognition

A lightweight PyTorch-based framework that uses YOLOv3-tiny and a bilinear CNN (B-CNN) to detect vehicles in street scenes and recognize their attributes via multi-label classification.

Vehicle detection and recognition results are as follows:


Usage

python VehicleDC.py -src_dir your_imgs_dir -dst_dir your_result_dir

Trained model files (vehicle detection and multi-label classification models) on Baidu drive:

Trained models: vehicle detection
Trained models: vehicle classification
Before running VehicleDC.py, download the model files provided above or use your own pretrained models. If you use the provided models, place car_540000.weights (for detection) in the project root directory and epoch_39.pth (for multi-label recognition) in root/checkpoints/.
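As a quick sanity check, you can verify this layout before launching; the following is a minimal sketch based on the setup instructions above (the constant names are made up for illustration):

    import os

    # Expected locations of the provided pretrained models (per the setup above).
    DET_WEIGHTS = "./car_540000.weights"        # YOLOv3-tiny detector weights
    CLS_WEIGHTS = "./checkpoints/epoch_39.pth"  # multi-label classifier checkpoint

    for path in (DET_WEIGHTS, CLS_WEIGHTS):
        if not os.path.isfile(path):
            raise FileNotFoundError("Missing model file: " + path)
    print("Model files found; VehicleDC.py should be able to load them.")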

Brief Introduction

(1). The program consists of two modules:

<1>. Vehicle detection module: only model loading and inference code are provided; for training code, refer to pytorch_yolo_v3.
<2>. Multi-label recognition module: predicts a vehicle's body color, heading direction and vehicle type; both training and testing code are provided.

Combining these two modules, vehicle detection and multi-label recognition can be performed at the same time. On this basis, a degree of structured information can be extracted from outdoor intelligent-traffic scenes.

(2). Detailed module introduction

<1>. VehicleDC.py

This module wraps the vehicle detection and multi-label recognition interfaces; you need to specify a source image directory and a result output directory. The main class is Car_DC. The function __init__ loads and initializes the two models for vehicle detection and vehicle recognition. The function detect_classify processes the images one by one: each input image is first preprocessed into a uniform format, and all raw vehicle bounding boxes in the image are produced; the function process_predict then performs NMS and coordinate transformation to obtain the final detections. The program then calls cls_draw_bbox, which handles each detection in turn: it crops the ROI (region of interest) corresponding to the detection box from the original image and feeds the ROI into the multi-label vehicle classifier. The classifier applies the B-CNN algorithm (see the paper link) to classify the vehicle attributes in the ROI. B-CNN is mainly used for training end-to-end fine-grained classification; this program adapts the network structure from the paper to balance inference speed and accuracy, using ResNet-18 as the B-CNN backbone instead of the VGG-16 used in the paper.
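The bilinear-pooling head described above can be sketched as follows. This is an illustrative reimplementation, not the repository's exact code; the class name and the per-attribute class counts (BCNNResNet18, n_colors, and so on) are assumptions for the example.

    import torch
    import torch.nn as nn
    import torchvision

    class BCNNResNet18(nn.Module):
        """Bilinear CNN (B-CNN) with a ResNet-18 backbone: a sketch."""

        def __init__(self, n_colors=6, n_directions=2, n_types=6):
            super().__init__()
            resnet = torchvision.models.resnet18()
            # Keep everything up to the last conv stage: output is (N, 512, H, W).
            self.features = nn.Sequential(*list(resnet.children())[:-2])
            # One linear head over the 512x512 bilinear feature; the logits are
            # split per attribute (color / direction / type) downstream.
            self.fc = nn.Linear(512 * 512, n_colors + n_directions + n_types)

        def forward(self, x):
            X = self.features(x)                        # (N, 512, H, W)
            N, C, H, W = X.shape
            X = X.view(N, C, H * W)
            # Bilinear pooling: outer product of channel features over locations.
            X = torch.bmm(X, X.transpose(1, 2)) / (H * W)   # (N, 512, 512)
            X = X.view(N, C * C)
            # Signed square root + L2 normalization, as in the original B-CNN paper.
            X = torch.sign(X) * torch.sqrt(torch.abs(X) + 1e-5)
            X = nn.functional.normalize(X)
            return self.fc(X)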

Timing statistics

Vehicle detection: single-image inference takes about 18 ms on a single GTX 1050 Ti GPU.
Vehicle multi-label recognition: single-image inference takes about 7 ms on a single GTX TITAN GPU and about 10 ms on a single GTX 1050 Ti GPU.

<2>. Vehicle multi-label data module: dataset.py (the dataset cannot be released for now due to confidentiality agreements)

Training and test data are stored in subdirectories by class; each subdirectory name is the label, in the form Color_Direction_Type, e.g. Yellow_Rear_suv.
The Vehicle class overrides the __init__, __getitem__ and __len__ methods of data.Dataset (a sketch follows below):
The function __init__ initializes the data paths and the labels; since the labels are multi-label, the cross-entropy loss is simply computed on each segment of the output vector.
The function __getitem__ iteratively returns one sample and its label; the returned data is preprocessed (e.g. normalized) first. The function __len__ returns the total number of samples.
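A minimal sketch of such a dataset class, assuming the Color_Direction_Type directory layout described above; the label vocabularies and the input resolution are assumptions, not the repository's exact values:

    import os

    import torch
    from PIL import Image
    from torch.utils import data
    from torchvision import transforms

    class Vehicle(data.Dataset):
        """Multi-label vehicle dataset: one subdirectory per Color_Direction_Type combination."""

        # Hypothetical label vocabularies; the real project defines its own.
        COLORS = ["Black", "Blue", "Gray", "Red", "White", "Yellow"]
        DIRECTIONS = ["Front", "Rear"]
        TYPES = ["saloonCar", "suv", "van", "passengerCar", "truck", "unknown"]

        def __init__(self, root):
            self.samples = []
            for sub_dir in sorted(os.listdir(root)):
                # The subdirectory name encodes the three labels.
                color, direction, car_type = sub_dir.split("_")
                label = torch.tensor([self.COLORS.index(color),
                                      self.DIRECTIONS.index(direction),
                                      self.TYPES.index(car_type)])
                for f_name in os.listdir(os.path.join(root, sub_dir)):
                    self.samples.append((os.path.join(root, sub_dir, f_name), label))
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        def __getitem__(self, idx):
            img_path, label = self.samples[idx]
            img = Image.open(img_path).convert("RGB")
            return self.transform(img), label

        def __len__(self):
            return len(self.samples)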

<3>. Vehicle multi-label training and testing module: train_vehicle_multilabel.py

This module handles training and testing of the vehicle multi-label classifier. Cross-entropy is chosen as the loss function. Note that because this is multi-label classification, the per-label losses must be summed when computing the total loss: loss = loss_color + loss_direction + 2.0 * loss_type. Empirically, weighting the vehicle-type loss by a factor of 2 gives better results.
Training also proceeds in two steps (sketched below): (1) freeze all layers of ResNet-18 except the fully connected layer and fine-tune until convergence; (2) unfreeze all the layers frozen in the first step and fine-tune further, adjusting the weights of all layers until the whole model converges.
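The weighted multi-label loss and the two-stage freezing schedule could look roughly like this; the slicing boundaries reuse the hypothetical class counts from the B-CNN sketch above:

    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()

    def multilabel_loss(logits, labels, n_colors=6, n_directions=2):
        """Sum per-attribute cross-entropies; vehicle type is weighted 2x (per the text)."""
        color_logits = logits[:, :n_colors]
        dir_logits = logits[:, n_colors:n_colors + n_directions]
        type_logits = logits[:, n_colors + n_directions:]
        loss_color = criterion(color_logits, labels[:, 0])
        loss_direction = criterion(dir_logits, labels[:, 1])
        loss_type = criterion(type_logits, labels[:, 2])
        return loss_color + loss_direction + 2.0 * loss_type

    def set_stage(model, stage):
        """Stage 1: train only the final FC layer; stage 2: fine-tune all layers."""
        for name, param in model.named_parameters():
            param.requires_grad = (stage == 2) or name.startswith("fc")

In the first stage the optimizer should then be built over only the trainable parameters, e.g. filter(lambda p: p.requires_grad, model.parameters()).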


Issues

Detection fails after switching to yolov3.weights?

First of all, thank you for this excellent work!
On other road-scene images some vehicles are not detected, so I wanted to replace the tiny model with full yolov3, but then detection fails entirely. What could be the reason?

Vehicle heading direction

Hi, what is the output of the vehicle heading detection, and where is it reflected in the script?

Question about the multi-label classification

Hi, I see you use B-CNN for the multi-label classification. In VehicleDC.py it is written as:
"X = torch.bmm(X, torch.transpose(X, 1, 2)) / (1 ** 2) # Bi-linear CNN"
I read the B-CNN code and did not find this operation.
What is the basis and purpose of doing it this way?
Maybe I missed something; please advise. Many thanks.

download link not working?

Trained models: vehicle detection https://pan.baidu.com/s/1HwTCVGTmdqkeLnqnxfNL8Q
Trained models: vehicle classification https://pan.baidu.com/s/1XmzjvCgOrrVv0NWTt4Fm3g
I cannot download the model files from the provided links. Could you please update them?
Thank you

Training on a single attribute

When training on vehicle type only, why is the training accuracy high while the test accuracy stays around 70%?

Detection and recognition speed

Hello, I ran your program on 1080 and 1660 Ti machines, but neither the detection speed nor the recognition speed reaches the 18 ms / 10 ms you describe. Are there any other tricks? Hoping for a quick reply, thanks.

Recognition accuracy

What recognition accuracy do you get on your test dataset?

Errors when recognizing other images

Thank you very much for your outstanding contribution. I have a question: what requirements does this network place on input pictures? Why do other input pictures raise errors while the pictures in your test set do not? The error is: tile cannot extend outside image. If you can reply, I will be very grateful!

conv.weight.data error

Hi Captain Even
I got the following error message while running VehicleDC.py. I use the yolov3.cfg from pytorch yolo v3, and the weights are downloaded from your network disk.
Hope I can get your help, many thanks.
cheers
Edward

DR_model = Car_DC(src_dir=args.src_dir, dst_dir=args.dst_dir)

File "D:/Dev/Vehicle-Car-detection-and-multilabel-classification-master/VehicleDC.py", line 240, in init
self.detector.load_weights(car_det_weights_path)

File "D:\Dev\vehicle\Vehicle-Car-detection-and-multilabel-classification-master\darknet.py", line 440, in load_weights
conv_weights = conv_weights.view_as(conv.weight.data)

RuntimeError: invalid argument 2: size '[512 x 256 x 3 x 3]' is invalid for input with 209204 elements at ..\aten\src\TH\THStorage.cpp:84

Need some advice about training

Hi!
First of all, thanks for sharing these codes, especially the trained models. I've tested the detection part and the results are astonishingly good given the fact that yolov3-tiny has such a simple structure. So I'm really curious about your training approach.
1. What dataset did you use for detection? I've tried COCO and VOC but the results were far worse than yours.
2. Are there any pre-trained weights for yolov3-tiny, or did you just train from scratch?
Hope you can help me!

Can I add attributes of my own to recognize?

The current model recognizes three attributes: body color, heading direction and vehicle type. Can I modify the code to add attributes I need, for example vehicle brand recognition? Can this network handle a fine-grained classification task like vehicle brand?

Chinese discussion thread

How did you build the dataset? I mainly want to know how the dataset labels were prepared. We are also working on vehicle detection and multi-attribute recognition.

Dear Sirs, training my own data

I am interested in vehicle logo detection, and I think YOLO is the best approach.
But I am new to YOLO and hope you can help.
Do you have experience with vehicle logo detection?
My knowledge of training on my own data is limited, and my computer's performance is poor.
Sorry for my poor wording; English is not my native language.

I need your trained model

Hi
thanks for sharing your nice project,
but I cannot download your model from the Baidu website; it is in Chinese, I think. If possible, please send it to my email address:
[email protected]

thanks again

Can the model from RepNet-MDNet-VehicleReID be used in this project?

I downloaded the model from the RepNet-MDNet-VehicleReID project and used it in this project, and got the following error.
=> device: cuda:0
=> ./car_540000.weights loaded.
=> car detection model initiated.
Traceback (most recent call last):
File "VehicleDC.py", line 385, in
DR_model = Car_DC(src_dir=args.src_dir, dst_dir=args.dst_dir)
File "VehicleDC.py", line 246, in init
model_path=local_model_path)
File "VehicleDC.py", line 121, in init
self.net.load_state_dict(torch.load(model_path))
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 839, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Cls_Net:
Missing key(s) in state_dict: "features.0.weight", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.4.0.conv1.weight", "features.4.0.bn1.weight", "features.4.0.bn1.bias", "features.4.0.bn1.running_mean", "features.4.0.bn1.running_var", "features.4.0.conv2.weight", "features.4.0.bn2.weight", "features.4.0.bn2.bias", "features.4.0.bn2.running_mean", "features.4.0.bn2.running_var", "features.4.1.conv1.weight", "features.4.1.bn1.weight", "features.4.1.bn1.bias", "features.4.1.bn1.running_mean", "features.4.1.bn1.running_var", "features.4.1.conv2.weight", "features.4.1.bn2.weight", "features.4.1.bn2.bias", "features.4.1.bn2.running_mean", "features.4.1.bn2.running_var", "features.5.0.conv1.weight", "features.5.0.bn1.weight", "features.5.0.bn1.bias", "features.5.0.bn1.running_mean", "features.5.0.bn1.running_var", "features.5.0.conv2.weight", "features.5.0.bn2.weight", "features.5.0.bn2.bias", "features.5.0.bn2.running_mean", "features.5.0.bn2.running_var", "features.5.0.downsample.0.weight", "features.5.0.downsample.1.weight", "features.5.0.downsample.1.bias", "features.5.0.downsample.1.running_mean", "features.5.0.downsample.1.running_var", "features.5.1.conv1.weight", "features.5.1.bn1.weight", "features.5.1.bn1.bias", "features.5.1.bn1.running_mean", "features.5.1.bn1.running_var", "features.5.1.conv2.weight", "features.5.1.bn2.weight", "features.5.1.bn2.bias", "features.5.1.bn2.running_mean", "features.5.1.bn2.running_var", "features.6.0.conv1.weight", "features.6.0.bn1.weight", "features.6.0.bn1.bias", "features.6.0.bn1.running_mean", "features.6.0.bn1.running_var", "features.6.0.conv2.weight", "features.6.0.bn2.weight", "features.6.0.bn2.bias", "features.6.0.bn2.running_mean", "features.6.0.bn2.running_var", "features.6.0.downsample.0.weight", "features.6.0.downsample.1.weight", "features.6.0.downsample.1.bias", "features.6.0.downsample.1.running_mean", "features.6.0.downsample.1.running_var", "features.6.1.conv1.weight", "features.6.1.bn1.weight", "features.6.1.bn1.bias", "features.6.1.bn1.running_mean", "features.6.1.bn1.running_var", "features.6.1.conv2.weight", "features.6.1.bn2.weight", "features.6.1.bn2.bias", "features.6.1.bn2.running_mean", "features.6.1.bn2.running_var", "features.7.0.conv1.weight", "features.7.0.bn1.weight", "features.7.0.bn1.bias", "features.7.0.bn1.running_mean", "features.7.0.bn1.running_var", "features.7.0.conv2.weight", "features.7.0.bn2.weight", "features.7.0.bn2.bias", "features.7.0.bn2.running_mean", "features.7.0.bn2.running_var", "features.7.0.downsample.0.weight", "features.7.0.downsample.1.weight", "features.7.0.downsample.1.bias", "features.7.0.downsample.1.running_mean", "features.7.0.downsample.1.running_var", "features.7.1.conv1.weight", "features.7.1.bn1.weight", "features.7.1.bn1.bias", "features.7.1.bn1.running_mean", "features.7.1.bn1.running_var", "features.7.1.conv2.weight", "features.7.1.bn2.weight", "features.7.1.bn2.bias", "features.7.1.bn2.running_mean", "features.7.1.bn2.running_var", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict: "conv1_1.weight", "conv1_1.bias", "conv1_3.weight", "conv1_3.bias", "conv1.0.weight", "conv1.0.bias", "conv1.2.weight", "conv1.2.bias", "conv2_1.weight", "conv2_1.bias", "conv2_3.weight", "conv2_3.bias", "conv2.0.weight", "conv2.0.bias", "conv2.2.weight", "conv2.2.bias", "conv3_1.weight", "conv3_1.bias", "conv3_3.weight", "conv3_3.bias", "conv3_5.weight", "conv3_5.bias", "conv3.0.weight", "conv3.0.bias", "conv3.2.weight", "conv3.2.bias", "conv3.4.weight", "conv3.4.bias", "conv4_1_1.weight", "conv4_1_1.bias", "conv4_1_3.weight", "conv4_1_3.bias", "conv4_1_5.weight", "conv4_1_5.bias", "conv4_1.0.weight", "conv4_1.0.bias", "conv4_1.2.weight", "conv4_1.2.bias", "conv4_1.4.weight", "conv4_1.4.bias", "conv4_2_1.weight", "conv4_2_1.bias", "conv4_2_3.weight", "conv4_2_3.bias", "conv4_2_5.weight", "conv4_2_5.bias", "conv4_2.0.weight", "conv4_2.0.bias", "conv4_2.2.weight", "conv4_2.2.bias", "conv4_2.4.weight", "conv4_2.4.bias", "conv5_1_1.weight", "conv5_1_1.bias", "conv5_1_3.weight", "conv5_1_3.bias", "conv5_1_5.weight", "conv5_1_5.bias", "conv5_1.0.weight", "conv5_1.0.bias", "conv5_1.2.weight", "conv5_1.2.bias", "conv5_1.4.weight", "conv5_1.4.bias", "conv5_2_1.weight", "conv5_2_1.bias", "conv5_2_3.weight", "conv5_2_3.bias", "conv5_2_5.weight", "conv5_2_5.bias", "conv5_2.0.weight", "conv5_2.0.bias", "conv5_2.2.weight", "conv5_2.2.bias", "conv5_2.4.weight", "conv5_2.4.bias", "FC6_1_1.weight", "FC6_1_1.bias", "FC6_1_4.weight", "FC6_1_4.bias", "FC6_1.0.weight", "FC6_1.0.bias", "FC6_1.3.weight", "FC6_1.3.bias", "FC6_2_1.weight", "FC6_2_1.bias", "FC6_2_4.weight", "FC6_2_4.bias", "FC6_2.0.weight", "FC6_2.0.bias", "FC6_2.3.weight", "FC6_2.3.bias", "FC7_1.weight", "FC7_1.bias", "FC7_2.weight", "FC7_2.bias", "FC_8.weight", "FC_8.bias", "attrib_classifier.weight", "attrib_classifier.bias", "arc_fc_br2.weight", "arc_fc_br3.weight", "shared_layers.0.0.weight", "shared_layers.0.0.bias", "shared_layers.0.2.weight", "shared_layers.0.2.bias", "shared_layers.1.0.weight", "shared_layers.1.0.bias", "shared_layers.1.2.weight", "shared_layers.1.2.bias", "shared_layers.2.0.weight", "shared_layers.2.0.bias", "shared_layers.2.2.weight", "shared_layers.2.2.bias", "shared_layers.2.4.weight", "shared_layers.2.4.bias", "branch_1_feats.0.0.0.weight", "branch_1_feats.0.0.0.bias", "branch_1_feats.0.0.2.weight", "branch_1_feats.0.0.2.bias", "branch_1_feats.0.1.0.weight", "branch_1_feats.0.1.0.bias", "branch_1_feats.0.1.2.weight", "branch_1_feats.0.1.2.bias", "branch_1_feats.0.2.0.weight", "branch_1_feats.0.2.0.bias", "branch_1_feats.0.2.2.weight", "branch_1_feats.0.2.2.bias", "branch_1_feats.0.2.4.weight", "branch_1_feats.0.2.4.bias", "branch_1_feats.1.0.weight", "branch_1_feats.1.0.bias", "branch_1_feats.1.2.weight", "branch_1_feats.1.2.bias", "branch_1_feats.1.4.weight", "branch_1_feats.1.4.bias", "branch_1_feats.2.0.weight", "branch_1_feats.2.0.bias", "branch_1_feats.2.2.weight", "branch_1_feats.2.2.bias", "branch_1_feats.2.4.weight", "branch_1_feats.2.4.bias", "branch_1_fc.0.0.weight", "branch_1_fc.0.0.bias", "branch_1_fc.0.3.weight", "branch_1_fc.0.3.bias", "branch_1_fc.1.weight", "branch_1_fc.1.bias", "branch_1.0.0.0.0.weight", "branch_1.0.0.0.0.bias", "branch_1.0.0.0.2.weight", "branch_1.0.0.0.2.bias", "branch_1.0.0.1.0.weight", "branch_1.0.0.1.0.bias", "branch_1.0.0.1.2.weight", "branch_1.0.0.1.2.bias", "branch_1.0.0.2.0.weight", "branch_1.0.0.2.0.bias", "branch_1.0.0.2.2.weight", "branch_1.0.0.2.2.bias", "branch_1.0.0.2.4.weight", "branch_1.0.0.2.4.bias", "branch_1.0.1.0.weight", 
"branch_1.0.1.0.bias", "branch_1.0.1.2.weight", "branch_1.0.1.2.bias", "branch_1.0.1.4.weight", "branch_1.0.1.4.bias", "branch_1.0.2.0.weight", "branch_1.0.2.0.bias", "branch_1.0.2.2.weight", "branch_1.0.2.2.bias", "branch_1.0.2.4.weight", "branch_1.0.2.4.bias", "branch_1.1.0.0.weight", "branch_1.1.0.0.bias", "branch_1.1.0.3.weight", "branch_1.1.0.3.bias", "branch_1.1.1.weight", "branch_1.1.1.bias", "branch_2_feats.0.0.0.weight", "branch_2_feats.0.0.0.bias", "branch_2_feats.0.0.2.weight", "branch_2_feats.0.0.2.bias", "branch_2_feats.0.1.0.weight", "branch_2_feats.0.1.0.bias", "branch_2_feats.0.1.2.weight", "branch_2_feats.0.1.2.bias", "branch_2_feats.0.2.0.weight", "branch_2_feats.0.2.0.bias", "branch_2_feats.0.2.2.weight", "branch_2_feats.0.2.2.bias", "branch_2_feats.0.2.4.weight", "branch_2_feats.0.2.4.bias", "branch_2_feats.1.0.weight", "branch_2_feats.1.0.bias", "branch_2_feats.1.2.weight", "branch_2_feats.1.2.bias", "branch_2_feats.1.4.weight", "branch_2_feats.1.4.bias", "branch_2_feats.2.0.weight", "branch_2_feats.2.0.bias", "branch_2_feats.2.2.weight", "branch_2_feats.2.2.bias", "branch_2_feats.2.4.weight", "branch_2_feats.2.4.bias", "branch_2_fc.0.0.weight", "branch_2_fc.0.0.bias", "branch_2_fc.0.3.weight", "branch_2_fc.0.3.bias", "branch_2_fc.1.weight", "branch_2_fc.1.bias", "branch_2.0.0.0.0.weight", "branch_2.0.0.0.0.bias", "branch_2.0.0.0.2.weight", "branch_2.0.0.0.2.bias", "branch_2.0.0.1.0.weight", "branch_2.0.0.1.0.bias", "branch_2.0.0.1.2.weight", "branch_2.0.0.1.2.bias", "branch_2.0.0.2.0.weight", "branch_2.0.0.2.0.bias", "branch_2.0.0.2.2.weight", "branch_2.0.0.2.2.bias", "branch_2.0.0.2.4.weight", "branch_2.0.0.2.4.bias", "branch_2.0.1.0.weight", "branch_2.0.1.0.bias", "branch_2.0.1.2.weight", "branch_2.0.1.2.bias", "branch_2.0.1.4.weight", "branch_2.0.1.4.bias", "branch_2.0.2.0.weight", "branch_2.0.2.0.bias", "branch_2.0.2.2.weight", "branch_2.0.2.2.bias", "branch_2.0.2.4.weight", "branch_2.0.2.4.bias", "branch_2.1.0.0.weight", "branch_2.1.0.0.bias", "branch_2.1.0.3.weight", "branch_2.1.0.3.bias", "branch_2.1.1.weight", "branch_2.1.1.bias".
How should I adjust things to get this running?
