Giter Club home page Giter Club logo

unet's Introduction

U-Net: 使用 PyTorch 进行语义分割

快速开始

  1. 安装 CUDA

  2. 安装 PyTorch 1.12 或更新的版本

  3. 安装依赖

pip install -r requirements.txt
torch                          1.13.0
torchvision                    0.14.0
  1. Download the data and run training:
bash scripts/download_data.sh
python train.py --amp

使用

注意 : 请使用 Python 3.6 或更新的版本

训练

> python train.py -h
usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR]
                [--load LOAD] [--scale SCALE] [--validation VAL] [--amp]

Train the UNet on images and target masks

optional arguments:
  -h, --help            show this help message and exit
  --epochs E, -e E      Number of epochs
  --batch-size B, -b B  Batch size
  --learning-rate LR, -l LR
                        Learning rate
  --load LOAD, -f LOAD  Load model from a .pth file
  --scale SCALE, -s SCALE
                        Downscaling factor of the images
  --validation VAL, -v VAL
                        Percent of the data that is used as validation (0-100)
  --amp                 Use mixed precision

默认情况下,scale为 0.5,因此如果您希望获得更好的结果(会被使用更多内存),请将其设置为 1。

--amp 代表使用自动混合精度。混合精度允许模型使用更少的内存,并通过使用 FP16 算法在 GPU 上更快,建议启用 AMP。

  • 支持训练的格式搭配:
imgs masks
jpg gif
tif gif
  • 实验结果:
model Validation Dice score
UNet 0.99

预测

训练模型并将其保存到MODEL.pth,您可以通过 CLI(命令行) 轻松测试图像上的输出掩码。

要预测单个图像并保存它:

python predict.py -i image.jpg -o output.jpg

要预测多个图像并在不保存的情况下显示它们:

python predict.py -i image1.jpg image2.jpg --viz --no-save

> python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...] 
                  [--output INPUT [INPUT ...]] [--viz] [--no-save]
                  [--mask-threshold MASK_THRESHOLD] [--scale SCALE]

Predict masks from input images

optional arguments:
  -h, --help            show this help message and exit
  --model FILE, -m FILE
                        Specify the file in which the model is stored
  --input INPUT [INPUT ...], -i INPUT [INPUT ...]
                        Filenames of input images
  --output INPUT [INPUT ...], -o INPUT [INPUT ...]
                        Filenames of output images
  --viz, -v             Visualize the images as they are processed
  --no-save, -n         Do not save the output masks
  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
                        Minimum probability value to consider a mask pixel white
  --scale SCALE, -s SCALE
                        Scale factor for the input images

您可以使用 --model MODEL.pth 指定要使用的模型文件。

W&B可视化

可以使用Weights & Biases 实时可视化训练进度。损失曲线、验证曲线、权重和梯度直方图以及预测掩码都记录到平台上。

启动train时,控制台中会打印一个链接。单击它转到您的仪表板。如果您已有 W&B 帐户,则可以通过设置 WANDB_API_KEY 环境变量来链接它。如果没有,它将创建一个匿名运行,并在 7 天后自动删除。

预训练权重

在Carvana 数据集上训练的预训练模型 。也可以从 torch.hub 加载:

net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)

可用scales为 0.5 和 1.0。

数据集

Carvana 数据可在 Kaggle 上获得。

您也可以使用帮助脚本下载它:

bash scripts/download_data.sh

输入图像和目标蒙版应分别位于 data/imgsdata/masks 文件夹中(请注意,由于数据加载器,imgsmasks 文件夹不应包含任何子文件夹或任何其他文件)。对于 Carvana,图像是 RGB,蒙版是黑白的。

- imgs:.jpg
- masks:.gif

只要确保在 utils/data_loading.py中正确加载它,就可以使用自己的数据集。

参考

network architecture

unet's People

Contributors

uppez avatar wangrongsheng avatar

Watchers

 avatar  avatar  avatar  avatar

Forkers

haoyue-code

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.