Giter Club home page Giter Club logo

xuelang_ai's Introduction

雪浪制造AI挑战赛

比赛;链接:https://tianchi.aliyun.com/competition/information.htm?spm=5176.100067.5678.2.37791756dvC7eR&raceId=231666 初赛73/2403,复赛22/2403,所有代码+介绍

初赛思路: 对于初赛,题目要求只需要将所有的图片分类为:正常or瑕疵就行,刚开始的思路就很简单,直接把正常图片放在一个文件夹,瑕疵图片放在另一个文件夹,(首先对所有图片进行resize,原图大小为2560 * 1920,GPU也跑不动,因此resize到320 * 320)然后选一个模型,用keras训练走了一遍……然后发现效果巨烂(acc大概0.8左右)。 然后分析了一下数据集,发现了很多问题:

  1. 数据集不平衡,正常图片的数量大约是瑕疵图片的两倍;

  2. 瑕疵在整张图片中非常小,resize更是损失了很多信息,但是不经过resize又不能训练;

  3. 模型怎么选合适?优化器怎么选合适?……

经过试验和摸索,选择了优化数据集的方式如下:

(1) 原图resize到499 * 499大小;

(2) 对于有瑕疵的图片,对瑕疵部位进行多次采样,将图片中与瑕疵所在的标注框的交并比大于一定阈值的图像区域resize到499 * 499大小,也作为瑕疵样本参与训练;

(3) 分别测试VGG系列模型以及Resnet、Inception等多种模型,选出最合适的模型构架;

(4) 尝试不同的优化器、加dropout、GAP层、BN等调参过程优化。

其中第1,2步处理的依据为:对于有些瑕疵,其具有全局特征,比如边扎洞等,需要用整张图像进行训练;但是对于毛洞等局部小瑕疵,resize并不能很好地保留特征,因此需要对瑕疵部位进行重点采样并训练,确保模型可以识别瑕疵特征。同时,对瑕疵的采样过程也充实了瑕疵样本,这样可以让瑕疵样本和正常样本数量达到平衡。 最后选用VGG-13模型为基础,在最终的参数调整之后,准确率达0.91。

初赛用到的所有代码都放在了初赛文件夹,可以通过main文件直接运行,使用方法参考初赛文件夹下的readme文件,数据可以在官网下载。

复赛思路:

虽然0.91的准确率让我们顺利进了复赛,但是复赛的难度也有所加大。 复赛要求将瑕疵分为10类,加上正常图片,需要将所有图片进行11分类,但是数据集并没有明显扩充(只是将初赛的两个测试集给了label,可以用于训练集的扩充,但是对于11分类问题,有的类别的样本依然是非常少,甚至只有几十张图片)。初赛时候的简单思路显然已经行不通,首要面对的问题就是如何扩充数据集。 受到初赛一个队伍的处理思路(源代码的github链接:https://github.com/lightfate/XueLang-YOLOhasst) 的启发,我们开始采用滑动窗的方式处理图片,最后的处理流程为:

(1) 将初赛的测试数据转换为训练数据所需的格式,即同名瑕疵图片和xml标签文件在同一路径下,便于后续一起处理;

(2) 将所有图片分为11类,并保存在'./data/raw'路径下的11个子文件夹中;

(3) 将所有图片进行滑动窗分割,正常图片滑动步长为256,滑动窗大小为512 * 512;而对于瑕疵样本,由于瑕疵所在的区域比较小,如果采用全局滑动窗裁剪,得到的图片里面有大部分图片不包含瑕疵,因此在裁剪结果中仅保留与瑕疵所在的标注框交并比大于某一个阈值的滑动窗作为得到的瑕疵图;由于瑕疵图片的样本数量不平衡,因此每一类瑕疵图片的滑动窗的滑动步长可以有所不同,以得到理想数量的瑕疵训练样本;同时由于不同的瑕疵的特点,对于小瑕疵可以设置交并比大一些的图进行保留,大瑕疵可以设置交并比小一些的图进行保留,以保留瑕疵的重要特征。

(4) 滑动窗裁剪之后得到的正常样本数量会远远超出瑕疵训练样本的数量,以一定比例随机剔除正常样本,使得每一类样本的数量相对平衡;

(5) 以一定的比例划分训练集合验证集;

(6) 利用keras的数据ImageDataGenerator对训练数据进行数据增强;

(7) 使用keras的现有模型InceptionResNetV2及其imagenet参数进行模型的初始化,修改全连接层,改为11分类输出,自定义11分类损失函数,并调整其他网络参数,进行训练。 (8) 由于训练数据使用的是原图裁剪得到的512 * 512小图,因此在预测的时候也要使用512的滑动窗对待预测图片进行滑动裁剪,滑动步长为256,裁剪得到((2560-512+256)/256) * ((1920-512+256)/256)=63张小图,每一张图都得到一个预测结果。最后,判断63个预测结果中有无瑕疵类的预测结果,如果没有,则该图为正常图片;如果有,则判断哪一类瑕疵的预测概率值最大,取最大值的瑕疵概率即为该图的瑕疵类别。

以上思路经过调参测试发现模型对瑕疵的分类准确率较高,但是对正常图片的识别率较低,容易将正常图片中的折叠、花纹等识别为瑕疵。考虑到最后的得分计算公式: AUC * 0.7+mAP * 0.3,其中mAP为Mean Average Precision 其中AUC是二分类模型的重要评估指标,也就是说,最后成绩的70%来源于模型能不能将正常样本和瑕疵样本区分清楚,而瑕疵内部的分类权重并不算很高。因此,最后的策略为:

(1) 选用一个较好的二分类模型,使得2分类的AUC达到0.93以上;

(2) 对于待遇测图片,先用二分类进行预测,如果二分类判断其为正常图片,则不再进行瑕疵预测;如果二分类判断其为瑕疵图片,将该图片送给11分类模型,预测得到该瑕疵的类别。

最终,模型得到0.73左右的评分,位于复赛20左右。

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.