“train_mi.py”文件是mi-Net模型。
“MI_net.py”文件是MI-net模型。
“MI_net_DS.py”文件是 MI-net with deep supervision模型。
"MI_Res.py"文件是MI-net with Res 模型。
"loader.py"文件是针对MUSK1/2数据集的导入方法。
"pre_.py"文件是针对fox/elephant/tiger数据集的导入方法,导入数据集的同时对数据集的顺序进行了打乱,由于没有找到文中的第三个数据集,故第三个数据集的导入函数尚未编写。
"main.py"文件是训练、测试文件,可以通过修改注释改变导入的数据集和模型。
CPU的实验manjaro系统下完成,利用anaconda建立虚拟环境,使用CPU版本的pytorch进行实验,python版本为3.7.11。GPU服务器的环境为ubuntu系统,CUDA11.1。
对每个网络的每个数据集作十折交叉验证,即将数据集分为十份,每次取9份做训练,取1份做测试,总共做十次,使得每份数据都有做过测试集,取十次结果的平均数。总共训练五次求平均数,每次都在训练集上做多次迭代训练再在测试集上做验证.迭代次数根据训练模型在测试集上的泛化性能决定,例如实验中发现MI-net_DS网络在fox数据集上迭代100次时泛化性能最好。可能网络的规模和深度比较小,用于传输数据的时间比例过高,利用GPU训练后速度反而有所下降,但是训练的效果提高了,可能是因为GPU的计算精度更高。下表中标粗的部分就是利用GPU重新训练后的结果。
训练过程采用随机梯度下降算法,学习率设置为0.01
MUSK1 | MUSK2 | fox | elepant | tiger | |
---|---|---|---|---|---|
mi-net | 81.957% | 79.216% | 62.200% | 85.000% | 81.700% |
MI-net | 87.174% | 85.098% | 62.400% | 86.400% | 81.400% |
MI-Net-DS | 92.826% | 86.078% | 64.500% | 87.200% | 82.400% |
MI-Net-RS | 88.913% | 85.686% | 62.700% | 87.800% | 82.600% |
优化器的lr=0.01,momentum=0.5
其中在MUSK2数据集上需要在优化器中添加weight-decay参数以进行权重衰减再训练才可以得到文中的预测效果
其中momentum=0.9,weight-decay=0.003 , DS。
momentum=0.9,weight-decay=0.003 , RS。
得到上述结果的迭代次数
MUSK1 | MUSK2 | fox | elephant | tiger | |
---|---|---|---|---|---|
mi-net | 120 | 300 | 200 | 150 | 200 |
MI-net | 200 | 60 | 200 | 200 | 200 |
MI-net-DS | 80 | 60 | 100 | 150 | 100 |
MI-net-RS | 150 | 60 | 150 | 100 | 100 |
对MI-net-DS网络分别应用max 、mean、lse池化函数,lse函数的r参数设置为2
MUSK1 | MUSK2 | fox | elephant | tiger | |
---|---|---|---|---|---|
max | 92.826% | 83.333% | 64.500% | 87.200% | 82.400% |
mean | 86.739% | 78.039% | 64.400% | 85.800% | 84.200% |
lse | 87.391% | 81.373% | 63.900% | 87.100% | 84.00% |
这篇文章的核心**就是在基于包分类的方法的基础上,对神经网络输出的最后一层的结果采用注意力机制,使池化方式也是可学习的,从而使最能代表包特征的instance在结果中占据的权重最大。
其中在benchmark数据集上的网络结构就是将MI-net的池化函数变为注意力机制
注意力机制的代码在MIL_pooling.py中,被写为attention类
MUSK1 | MUSK2 | fox | elephant | tiger | |
---|---|---|---|---|---|
Attention | 87.174% | 82.157% | 59.100% | 83.500% | 85.900% |
Attention-gate | 87.174% | 82.7451% | 59.100% | 84.700% | 83.900% |
其中fox、musk2数据集使用文中的参数无法得到最优结果
fox : lr = 0.005 , momentum=0.5,weight-decay=0
musk2: lr = 0.0005 , momentum=0.9, weight-decay=0.003