
World is a ball: a deep learning code framework

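This project aims to reduce the effort of tasks such as classification, recognition (localization + classification) and regression, so that graduate students and developers can concentrate on the core of their project instead of repeatedly writing the same boilerplate code. After cloning the repository, adapt the following parts to your needs:

  • Create your own data loader in the dataloader folder

  • Create your own model class in the model folder

  • Fine-tune the training code in Application (optional)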

The following examples demonstrate how to use the project.

1. Object recognition (localization + classification), using the Oxford-IIIT Pet Dataset as an example

1.1 Data loading (dataloader)

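The data-loading class must inherit from torch.utils.data.Dataset and implement three methods to suit your data: __init__(), __len__() and __getitem__(). The most important one is __getitem__(), which returns the data and its labels. Since this example performs object detection, it returns three elements: the image, the bounding box and the class.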

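import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# parse_xml_file is not shown in this README; it is assumed to be a project helper
# that reads an annotation XML file and returns the four bounding-box coordinates.


class OXFordData(Dataset):
    def __init__(self, data_path, mode='train'):
        self.data_path = data_path
        self.image_path = os.path.join(self.data_path, 'images')
        self.xml_path = os.path.join(data_path, 'annotations', 'xmls')
        # Each line of the split file starts with "<image name> <class id>"
        if mode == 'train':
            self.txt_file = os.path.join(data_path, 'annotations', 'trainval.txt')
        elif mode == 'test':
            self.txt_file = os.path.join(data_path, 'annotations', 'test.txt')
        else:
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.data_lines = self.get_list()
        self.transforms = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

    def get_list(self):
        # One sample per line of the split file
        with open(self.txt_file, 'r') as f:
            lines = f.readlines()
        return lines

    def __len__(self):
        return len(self.data_lines)

    def __getitem__(self, index):
        try:
            line = self.data_lines[index].strip()
            base_name = line.split(' ')[0]
            file_name, label_class = f"{base_name}.jpg", line.split(' ')[1]
            # Class ids in the file start at 1; shift them to 0-based indices
            label_class = torch.tensor(int(label_class) - 1, dtype=torch.long)

            img_path = os.path.join(self.image_path, file_name)
            image = Image.open(img_path).convert('RGB')

            if self.transforms:
                image = self.transforms(image)

            # The image has been resized to 256x256, so the box should be expressed
            # in a size-independent frame (e.g. normalized to [0, 1]).
            xml_file = os.path.join(self.xml_path, f"{base_name}.xml")
            label_bd = torch.tensor(parse_xml_file(xml_file), dtype=torch.float)
        except Exception as e:
            # Uncomment the next line to see the error message when a sample fails to load
            # print(f"Error loading data: {e}")
            # Failed samples return Nones, so the batching code must filter them out
            return None, None, None

        return image, label_bd, label_class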
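The dataset calls parse_xml_file, which this README does not show. Below is a minimal sketch of what such a helper might look like, assuming PASCAL VOC-style XML files with an `<object><bndbox>` entry and a `<size>` element; it is a hypothetical illustration, not necessarily the project's own implementation. With `normalize=True` it divides the coordinates by the original image width and height, so the box stays valid after the image is resized to 256x256.

```python
import xml.etree.ElementTree as ET


def parse_xml_file(xml_file, normalize=True):
    """Hypothetical helper: read the first bounding box from a VOC-style XML file.

    xml_file may be a path or a file object. Returns [xmin, ymin, xmax, ymax];
    if normalize is True, x values are divided by the image width and y values
    by the image height taken from the <size> element.
    """
    root = ET.parse(xml_file).getroot()
    box = root.find('object/bndbox')
    coords = [float(box.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
    if normalize:
        width = float(root.find('size/width').text)
        height = float(root.find('size/height').text)
        coords = [coords[0] / width, coords[1] / height,
                  coords[2] / width, coords[3] / height]
    return coords
```

Normalizing is a design choice: it keeps the regression targets in a fixed range regardless of the original image size. If you keep pixel coordinates instead, scale them by the resize factor.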

The return value of __getitem__ can be a list, a tuple or a dict; whichever structure you choose, make the corresponding changes in Applications.py.
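Because __getitem__ returns (None, None, None) when a sample fails to load, PyTorch's default collate function would raise an error on any batch containing such a sample. If Applications.py does not already handle this, one option is a collate function that drops failed samples before batching. This is a sketch under that assumption; the name skip_none_collate is hypothetical.

```python
from torch.utils.data import default_collate


def skip_none_collate(batch):
    """Hypothetical collate_fn: drop samples whose image is None, then batch the rest."""
    batch = [sample for sample in batch if sample[0] is not None]
    if not batch:
        return None  # every sample in this batch failed; the training loop should skip it
    # default_collate stacks each tuple position: [images, boxes, classes]
    return default_collate(batch)


# Usage, assuming the OXFordData class above:
# loader = torch.utils.data.DataLoader(OXFordData('path/to/data'), batch_size=32,
#                                      shuffle=True, collate_fn=skip_none_collate)
```

If __getitem__ returns a dict instead of a tuple, default_collate still works, but the filter condition has to test the relevant dict key rather than `sample[0]`.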

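1.2 Network design

Once data loading is in place, design the network model, which should inherit from torch.nn.Module. For this task, the head of a pretrained ResNet-34 is replaced with a localization head and a classification head: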

import torch.nn as nn
from torchvision.models import resnet34, ResNet34_Weights


class PetsNet(nn.Module):
    def __init__(self, num_classes=37):
        super(PetsNet, self).__init__()
        self.num_classes = num_classes
        # ImageNet-pretrained ResNet-34 backbone
        self.backbone = resnet34(weights=ResNet34_Weights.DEFAULT)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()    # remove the original fully connected layer
        self.classifier_fc = nn.Linear(in_features, self.num_classes)  # classification head
        self.bound_fc = nn.Linear(in_features, 4)  # localization head: 4 box coordinates

    def forward(self, x):
        features = self.backbone(x)  # pooled features, shape (N, in_features)
        class_logits = self.classifier_fc(features)
        bbox = self.bound_fc(features)

        return bbox, class_logits

You can test the network to make sure that preprocessed images passed through forward produce outputs of the correct shape, for example:

import torch

net = PetsNet(37)
data = torch.randn(32, 3, 256, 256)
bbox, class_logits = net(data)  # forward returns (bbox, class_logits)
print(bbox.shape, class_logits.shape)  # torch.Size([32, 4]) torch.Size([32, 37])

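1.3 Adjusting the training code

Training-code adjustments are made in Application.py. Since the current version was built around this example, you can skip this step: simply edit the hyperparameters in main.py and start training.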
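Because the model returns two outputs, each training step has to combine a box-regression loss with a classification loss. This README does not show which losses Application.py uses; the following is a minimal sketch assuming Smooth L1 for the box and cross-entropy for the class, weighted by a hypothetical bbox_weight parameter. train_step is likewise a hypothetical name.

```python
import torch
import torch.nn.functional as F


def train_step(model, optimizer, images, label_bd, label_class, bbox_weight=1.0):
    """Hypothetical single optimization step for a (bbox, class_logits) model."""
    model.train()
    bbox, class_logits = model(images)
    loss_bbox = F.smooth_l1_loss(bbox, label_bd)          # localization head
    loss_cls = F.cross_entropy(class_logits, label_class)  # classification head
    loss = loss_cls + bbox_weight * loss_bbox
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Usage, assuming PetsNet and a batch from the data loader:
# model = PetsNet(37)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# loss = train_step(model, optimizer, images, label_bd, label_class)
```

The relative weight of the two terms matters: with boxes normalized to [0, 1] the Smooth L1 term is typically small compared with the cross-entropy term, so bbox_weight may need tuning.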

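[Figure: evolution of the parameters during training]

[Figure: confusion matrix on the validation set]

[Figure: localization and recognition results]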
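A validation confusion matrix like the one in the figure above can be computed from the class predictions with a few tensor operations. A minimal sketch, assuming predicted and true class indices in the range [0, num_classes); the function name is hypothetical. Rows index the true class and columns the predicted class.

```python
import torch


def confusion_matrix(pred, target, num_classes):
    """Hypothetical helper: counts[i, j] = number of samples of true class i predicted as j."""
    pred = pred.flatten().long()
    target = target.flatten().long()
    # Encode each (true, predicted) pair as one flat index, then count occurrences
    index = target * num_classes + pred
    counts = torch.bincount(index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


# Usage with the model's outputs:
# pred = class_logits.argmax(dim=1)
# cm = confusion_matrix(pred, label_class, num_classes=37)
```

Overall accuracy is `cm.diag().sum() / cm.sum()`, and per-class recall is `cm.diag() / cm.sum(dim=1)`. For the localization results, `torchvision.ops.box_iou` can compare predicted and ground-truth boxes given in (x1, y1, x2, y2) format.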

Contributors: dlwangsan
