RGBD_Semantic_Segmentation_PyTorch

Implementations of several state-of-the-art RGBD semantic segmentation methods in PyTorch.

Currently, we provide code for:

  • SA-Gate, ECCV 2020 [arXiv]
  • Malleable 2.5D Convolution, ECCV 2020 [arXiv]

News

  • 2020/08/16

Official code release for the paper Malleable 2.5D Convolution: Learning Receptive Fields along the Depth-axis for RGB-D Scene Parsing, ECCV 2020. [arXiv], [code]

Thanks to aurora95 for his open-source code!

  • 2020/07/20

Official code release for the paper Bi-directional Cross-Modality Feature Propagation with Separation-and-Aggregation Gate for RGB-D Semantic Segmentation, ECCV 2020. [arXiv], [code]

Main Results

Results on NYU Depth V2 Test Set with Multi-scale Inference

Method mIoU (%)
3DGNN 43.1
ACNet 48.3
RDFNet-101 49.1
PADNet 50.2
PAP 50.4
Malleable 2.5D 50.9
SA-Gate 52.4

Results on Cityscapes Test Set with Multi-scale Inference (our method uses output stride = 16 and does not use the coarsely labeled data)

Method mIoU (%)
PADNet 80.3
DANet 81.5
GALD 81.8
ACFNet 81.8
SA-Gate 82.8

For more details, please refer to our paper.
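The mIoU figures above average the per-class intersection over union, IoU = TP / (TP + FP + FN), over the classes. A common way to compute it is from a confusion matrix accumulated over the whole test set. The sketch below is an illustration, not the repo's evaluation code: it assumes NumPy arrays of integer labels and treats 255 as the ignore label, which is an assumption rather than something stated in this README.

```python
import numpy as np


def confusion_matrix(pred, gt, num_classes, ignore_index=255):
    """num_classes x num_classes counts; rows are ground truth, columns are predictions."""
    mask = gt != ignore_index  # drop unlabeled pixels (ignore value is an assumption)
    idx = num_classes * gt[mask].astype(np.int64) + pred[mask].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def mean_iou(conf):
    """Per-class IoU = TP / (TP + FP + FN) and their mean over classes that occur."""
    tp = np.diag(conf).astype(np.float64)
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp  # TP + FP + FN
    valid = union > 0
    iou = np.zeros_like(tp)
    iou[valid] = tp[valid] / union[valid]
    return iou, iou[valid].mean()


gt = np.array([[0, 0, 1], [1, 2, 255]])
pred = np.array([[0, 1, 1], [1, 2, 0]])
iou, miou = mean_iou(confusion_matrix(pred, gt, num_classes=3))
print(iou, miou)
```

Summing confusion matrices over all test images before calling `mean_iou` gives a dataset-level mIoU, which is the usual convention for benchmark tables like the ones above.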

Directory Tree

Your directory tree should look like this:

./
|-- furnace
|-- model
|-- DATA
|   |-- pytorch-weight
|   |-- NYUDepthv2
|   |   |-- ColoredLabel
|   |   |-- Depth
|   |   |-- HHA
|   |   |-- Label
|   |   |-- RGB
|   |   |-- test.txt
|   |   |-- train.txt
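A small, hypothetical helper for checking that a checkout matches this layout before training. The path names come from the directory tree; the function name and the check itself are illustrative assumptions, not part of the repo.

```python
from pathlib import Path

# Expected layout, taken from the directory tree in this README.
EXPECTED_DIRS = [
    "furnace",
    "model",
    "DATA/pytorch-weight",
    "DATA/NYUDepthv2/ColoredLabel",
    "DATA/NYUDepthv2/Depth",
    "DATA/NYUDepthv2/HHA",
    "DATA/NYUDepthv2/Label",
    "DATA/NYUDepthv2/RGB",
]
EXPECTED_FILES = ["DATA/NYUDepthv2/train.txt", "DATA/NYUDepthv2/test.txt"]


def missing_entries(root="."):
    """Return the expected directories and files that do not exist under `root`."""
    root = Path(root)
    missing = [p for p in EXPECTED_DIRS if not (root / p).is_dir()]
    missing += [p for p in EXPECTED_FILES if not (root / p).is_file()]
    return missing
```

Calling `missing_entries(".")` from the repository root should return an empty list once the data and pretrained weights are in place.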

Installation

The code was developed with Python 3.6 and PyTorch 1.0.0, and was tested on 4 or 8 NVIDIA TITAN V GPUs. You can change the input size (image_height and image_width) or batch_size in config.py according to your available resources.

  1. Clone this repo.

    $ git clone https://github.com/charlesCXK/RGBD_Semantic_Segmentation_PyTorch.git
    $ cd RGBD_Semantic_Segmentation_PyTorch
  2. Install dependencies.

    (1) Create a conda environment:

    $ conda env create -f rgbd.yaml
    $ conda activate rgbd

    (2) Install apex 0.1 (requires CUDA):

    $ cd ./furnace/apex
    $ python setup.py install --cpp_ext --cuda_ext

Data preparation

Pretrained ResNet-101

Please download the pretrained ResNet-101 and then put it into ./DATA/pytorch-weight.

  • Baidu Cloud: https://pan.baidu.com/s/1Zc_ed9zdgzHiIkARp2tCcw (password: f3ew)
  • Google Drive: https://drive.google.com/drive/folders/1_1HpmoCsshNCMQdXhSNOq8Y-deIDcbKS?usp=sharing

NYU Depth V2

You can download the official NYU Depth V2 data here. After downloading it, reorganize the data to match the directory structure above. We also provide the processed data. We will remove the link at any time if the owners of NYU Depth V2 request it.

  • Baidu Cloud: https://pan.baidu.com/s/1iU8m20Jv9shG_wEvwpwSOQ (password: 27uj)
  • Google Drive: https://drive.google.com/drive/folders/1_1HpmoCsshNCMQdXhSNOq8Y-deIDcbKS?usp=sharing

How to generate HHA maps?

If you want to generate HHA maps from Depth maps, please refer to https://github.com/charlesCXK/Depth2HHA-python.

Training and Inference

We take SA-Gate as an example; other models can be run in a similar way.

Training

Training on NYU Depth V2:

$ cd ./model/SA-Gate.nyu
$ export NGPUS=8
$ python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py

If you only have 4 GPUs, you can run:

$ cd ./model/SA-Gate.nyu.432
$ export NGPUS=4
$ python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py
  • Note that the only difference between SA-Gate.nyu/ and SA-Gate.nyu.432/ is the training/inference image crop size.
  • TensorBoard logs are saved in the log/tb/ directory.
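`torch.distributed.launch` starts NGPUS processes and exports environment variables such as WORLD_SIZE and RANK to each of them (it also passes a `--local_rank` argument to the script). The sketch below shows how a training script can use those variables to choose between distributed and single-process training. The function is hypothetical, and whether the repo's furnace engine applies exactly this rule is an assumption to verify in its source.

```python
import os


def distributed_settings(env=None):
    """Return (is_distributed, world_size, rank) from launcher environment variables.

    torch.distributed.launch exports WORLD_SIZE and RANK for every process it
    starts; a plain `python train.py` run has neither, which maps to one process.
    """
    env = os.environ if env is None else env
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    return world_size > 1, world_size, rank


if __name__ == "__main__":
    print(distributed_settings())
```

With NGPUS=8 each process would see WORLD_SIZE=8 and a distinct RANK from 0 to 7, while a run without the launcher falls back to (False, 1, 0).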

Inference

Inference on NYU Depth V2:

$ cd ./model/SA-Gate.nyu
$ python eval.py -e 300-400 -d 0-7 --save_path results
  • Here, 300-400 means we evaluate on checkpoints whose ID is in [300, 400], such as epoch-300.pth, epoch-310.pth, etc.
  • The segmentation predictions are saved in results/ and results_color/: the former stores the raw predictions and the latter stores colored versions. The mIoU is written to log/*.log. You can expect ~51.4% mIoU with SA-Gate.nyu and ~51.5% mIoU with SA-Gate.nyu.432 (single-scale inference without flipping).
  • For multi-scale and flip inference, set C.eval_flip = True and C.eval_scale_array = [1, 0.75, 1.25] in config.py. Different eval_scale_array settings may give different results.
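Multi-scale and flip inference typically works as follows: the image is evaluated at each scale in the scale list (and, when flipping is enabled, also mirrored horizontally), every score map is resized back to the input resolution, and the scores are summed before taking the per-pixel argmax. The sketch below illustrates that scheme with NumPy only. `predict_fn` is a hypothetical stand-in for the network, and nearest-neighbour resizing replaces the bilinear interpolation a real evaluator would more likely use, so this is not the repo's evaluator.

```python
import numpy as np


def resize_nearest(x, out_h, out_w):
    """Nearest-neighbour resize of a (C, H, W) array to (C, out_h, out_w)."""
    _, h, w = x.shape
    rows = (np.arange(out_h) * h / out_h).astype(np.int64)
    cols = (np.arange(out_w) * w / out_w).astype(np.int64)
    return x[:, rows[:, None], cols[None, :]]


def multi_scale_flip_predict(image, predict_fn, scales=(1.0, 0.75, 1.25), flip=True):
    """Fuse class scores over rescaled (and optionally mirrored) copies of `image`.

    image:      (C, H, W) array.
    predict_fn: hypothetical callable mapping a (C, h, w) array to
                (num_classes, h, w) scores at the same resolution.
    Returns an (H, W) array of predicted class indices.
    """
    _, h, w = image.shape
    total = None
    for s in scales:
        sh, sw = max(1, int(round(h * s))), max(1, int(round(w * s)))
        scaled = resize_nearest(image, sh, sw)
        scores = predict_fn(scaled)
        if flip:
            # Predict on the horizontally mirrored input, then mirror the scores back.
            scores = scores + predict_fn(scaled[:, :, ::-1])[:, :, ::-1]
        scores = resize_nearest(scores, h, w)  # back to the input resolution
        total = scores if total is None else total + scores
    return total.argmax(axis=0)
```

Summing raw scores is one choice; averaging softmax probabilities instead gives the same argmax only when every view is weighted equally and the scores are already probabilities.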

Citation

Please consider citing this project in your publications if it helps your research.

@inproceedings{chen2020-SAGate,
  title={Bi-directional Cross-Modality Feature Propagation with Separation-and-Aggregation Gate for RGB-D Semantic Segmentation},
  author={Chen, Xiaokang and Lin, Kwan-Yee and Wang, Jingbo and Wu, Wayne and Qian, Chen and Li, Hongsheng and Zeng, Gang},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2020}
}
@inproceedings{xing2020-melleable,
  title={Malleable 2.5D Convolution: Learning Receptive Fields along the Depth-axis for RGB-D Scene Parsing},
  author={Xing, Yajie and Wang, Jingbo and Zeng, Gang},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2020}
}

Acknowledgement

Thanks to TorchSeg for their excellent project!

TODO

  • More encoders such as HRNet.
  • Code and data for Cityscapes.
  • More RGBD Semantic Segmentation models

Issues

About saving the model

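I'd like to know where the model is saved after training. I'm not training locally but on a rented AutoDL instance, and I want to copy the trained model to my own computer to run it.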

Request for the Cityscapes RGB-D dataset

First, thank you for sharing this excellent work. For RGB-D segmentation research we really need the Cityscapes RGB-D dataset, or the method used to obtain the depth maps. Would you be able to share it soon?

Installing apex failed

Hi, thank you for your work. I hit an error when installing apex. Here are the details:

/root/miniconda3/envs/myconda/lib/python3.6/site-packages/torch/lib/include/ATen/cuda/CUDAContext.h:12:10: fatal error: cusparse.h: No such file or directory
 #include "cusparse.h"
          ^~~~~~~~~~~~
compilation terminated.
error: command '/usr/local/cuda/bin/nvcc' failed with exit status 1

BTW, my environment: Python 3.6, PyTorch 1.0, CUDA 10.0.

A small question

Hi,
Thank you for your great work!
Did you use multigrid, as in the code, to get the 52.4% mIoU on the NYUD dataset?

Runtime error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

File "train.py", line 173, in
loss.backward()
File "/home/anaconda3/envs/rgbd/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/anaconda3/envs/rgbd/lib/python3.6/site-packages/torch/autograd/init.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
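Hello, and thank you very much for your open-source contribution. While debugging your malleable2_5d code I got the error above. Some people say the inplace argument in nn.ReLU(inplace=True) should be changed to False, but that doesn't seem to help. How should I modify the code to fix this? Thanks a lot.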

Switching from multi-GPU to single-GPU training

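Hello, and thank you for open-sourcing your code! Our group has limited GPUs, so we can't do multi-GPU training. Is there a way to modify the code to support single-GPU training? Is there any documentation on this that I could refer to?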

Drawing figures

Hi Chen, I'd like to ask what software you used to draw the network diagrams in your paper.

Can this be adapted to object detection?

Hello, can this code be used to train multi-class object detection? Which part of your code should I modify?

How do I convert Cityscapes disparity maps to HHA format?

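Hello! Thank you very much for your open-source contribution. I'm quite confused about converting Cityscapes disparity maps to HHA format, because I don't know whether computing it with the following code, using the camera parameters provided with the disparity data, is correct. The HHA images I get this way look completely different from the ones in your paper! If you could release the code for converting Cityscapes disparity to HHA format, I would be very grateful!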

Why is label 0 invalid?

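I noticed that the data augmentation code subtracts 1 from every label. What is the reasoning behind this? I tried to turn the network into a binary segmentation network whose labels contain only two classes, 0 and 1, and set the number of classes to 2, but after training the test output is not meaningful. Could you give me some guidance? Thanks!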
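A likely explanation, sketched under assumptions: in the common 40-class NYU Depth V2 labelling, 0 marks unlabeled pixels and 1 to 40 are the object classes. Subtracting 1 from a uint8 label map moves the classes to 0..39 and wraps 0 around to 255, which a loss built with an ignore index of 255 (for example PyTorch's `nn.CrossEntropyLoss(ignore_index=255)`) skips. Whether the repo uses exactly 255 should be checked in its config and training code. The snippet also shows what the same shift does to binary masks labelled 0 and 1.

```python
import numpy as np

IGNORE_INDEX = 255  # assumption: the segmentation loss ignores label 255


def shift_labels(label):
    """Subtract 1 from a uint8 label map; 0 wraps around to 255 (uint8 arithmetic)."""
    return label.astype(np.uint8) - np.uint8(1)


# NYU Depth V2 (40 classes): 0 = unlabeled, 1..40 = object classes.
nyu = np.array([[0, 1], [2, 40]], dtype=np.uint8)
print(shift_labels(nyu))  # unlabeled becomes 255, classes become 0..39

# Binary masks that already use 0 and 1 as the two real classes:
binary = np.array([[0, 1], [1, 0]], dtype=np.uint8)
shifted = shift_labels(binary)
trainable = np.unique(shifted[shifted != IGNORE_INDEX])
print(trainable)  # only one class is left; the former class 0 is now ignored
```

If this is the cause, storing binary masks as 1 and 2 (keeping 0 for "unlabeled") or skipping the shift for your own data would be the natural fixes.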

niters_per_epoch

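Hello, and thank you for open-sourcing your work. While reproducing it, I noticed that the computation of niters_per_epoch also multiplies by (800 // C.nepochs). What is the purpose of this?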
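Whatever the intended motivation, the factor has a simple arithmetic effect: nepochs * (800 // nepochs) equals 800 whenever nepochs divides 800, so the total number of iterations stays at roughly 800 passes over the data regardless of nepochs. The sketch below reconstructs the formula only partially: the base term `num_train_imgs // batch_size` is an assumption, as is the batch size of 16; 795 is the size of the standard NYU Depth V2 training split.

```python
def niters_per_epoch(num_train_imgs, batch_size, nepochs):
    """Hypothetical reconstruction: one pass over the data, times 800 // nepochs."""
    base = num_train_imgs // batch_size  # assumed base term
    return base * (800 // nepochs)


def total_iters(num_train_imgs, batch_size, nepochs):
    """Total optimisation steps over the whole schedule."""
    return nepochs * niters_per_epoch(num_train_imgs, batch_size, nepochs)


# 795 = standard NYU Depth V2 training split; batch size 16 is an arbitrary example.
for nepochs in (100, 200, 300, 400, 800):
    print(nepochs, niters_per_epoch(795, 16, nepochs), total_iters(795, 16, nepochs))
```

With these numbers every divisor of 800 yields the same total, while nepochs = 300 (which does not divide 800) yields fewer, because the integer division rounds 800 / 300 down to 2.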

Feature map visualization

A very inspiring paper. How did you visualize the feature maps in your paper? Could you share the corresponding code?
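A common generic approach, not necessarily the one used in the paper, is to collapse the channels of a feature map into one activation-magnitude map and rescale it to an 8-bit image that can be shown with a colormap. The helper below is a hypothetical sketch of that idea using NumPy.

```python
import numpy as np


def feature_to_heatmap(feat):
    """Collapse a (C, H, W) feature map into an (H, W) uint8 image in [0, 255]."""
    energy = np.abs(feat).mean(axis=0)  # channel-averaged activation magnitude
    lo, hi = energy.min(), energy.max()
    if hi > lo:
        norm = (energy - lo) / (hi - lo)
    else:
        norm = np.zeros_like(energy, dtype=np.float64)  # flat map: nothing to highlight
    return np.round(norm * 255).astype(np.uint8)
```

For a PyTorch feature tensor of shape (N, C, H, W), `feature_to_heatmap(feat.detach().cpu().numpy()[0])` gives the map for the first image, which can then be displayed with matplotlib, e.g. `plt.imshow(heatmap, cmap="jet")`, usually after upsampling it to the input size.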

About the RGB-D baseline

Hello, and thanks for open-sourcing! Your paper mentions that the RGB-D baseline uses a dual-branch DeepLab v3+. Could you explain what this baseline looks like?

About Equation 8

Hello, Equation (8) in your paper uses RGB_in and HHA_in to generate M_ij. Why not use RGB_rec and HHA_rec, given that their details have already been refined?

Error about data loading

root@47cc999396b9:/mnt/txf/codes/SA-Gate/RGBD_Semantic_Segmentation_PyTorch/model/SA-Gate.nyu.432# python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py
19 04:15:06 PyTorch Version 1.0.0, Furnace Version 0.1.1
19 04:15:06 PyTorch Version 1.0.0, Furnace Version 0.1.1
19 04:15:07 WRN A exception occurred during Engine initialization, give up running process
Traceback (most recent call last):
File "train.py", line 64, in
norm_layer=BatchNorm2d)
File "/mnt/txf/codes/SA-Gate/RGBD_Semantic_Segmentation_PyTorch/model/SA-Gate.nyu.432/network.py", line 33, in init
deep_stem=True, stem_width=64)
File "/mnt/txf/codes/SA-Gate/RGBD_Semantic_Segmentation_PyTorch/model/SA-Gate.nyu.432/dual_resnet.py", line 395, in resnet101
model = load_dualpath_model(model, pretrained_model)
File "/mnt/txf/codes/SA-Gate/RGBD_Semantic_Segmentation_PyTorch/model/SA-Gate.nyu.432/dual_resnet.py", line 304, in load_dualpath_model
raw_state_dict = torch.load(model_file)
File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 367, in load
return _load(f, map_location, pickle_module)
File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 528, in _load
magic_number = pickle_module.load(f)
_pickle.UnpicklingError: pickle data was truncated
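Is this caused by a problem with the data or by insufficient memory? The server currently has only two GPUs. If you have time, I'd really appreciate your help. Thank you very much!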
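`pickle data was truncated` is raised when a file ends before the pickled object is complete, so it usually points to an incomplete or corrupted copy of the pretrained weights (for example, an interrupted download) rather than to running out of memory. No reference checksum is published in this README, so one way to check, offered here as an assumption about how to verify, is to compare the size and SHA-256 digest of your copy with a second download. The helper name is hypothetical.

```python
import hashlib
from pathlib import Path


def file_fingerprint(path, chunk_size=1 << 20):
    """Return (size_in_bytes, sha256_hexdigest) of a file, reading it in chunks."""
    digest = hashlib.sha256()
    path = Path(path)
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return path.stat().st_size, digest.hexdigest()
```

If two downloads give different fingerprints, or `torch.load(path, map_location="cpu")` fails on a machine with plenty of free memory, re-downloading the weights is the first thing to try.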

Resnet50

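Hi, I used the ResNet-50 pretrained model provided by TorchSeg and switched the backbone to ResNet-50, but the server crashes after two epochs of training, whereas ResNet-101 works fine. Could you provide a ResNet-50 pretrained model?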

Pretrained weights

Hi

Thanks for sharing the code. Is it possible to obtain the pre-trained weights for the model on NYUD-v2 for academic purposes?

Thanks.

Where's DeepLabv3+ model Decoder?

Hello, I'm studying your code, but I can't find the DeepLabV3+ decoder in it. Could you please tell me where the DeepLabV3+ decoder is?

Some questions.

Hello. First of all, thank you for sharing this repository.
I am interested in semantic segmentation.
The data I'm trying to segment has additional depth information (from stereo or LiDAR).
Since I have this information, I think it could help produce better segment boundaries.
I'm looking for ways to use it.
I have some questions.

  1. Is there any difference between depth-assisted segmentation and 3D segmentation? If there is, which one is closer to what I want, i.e. sharper segmentation boundaries from added depth information?
    I think they are similar in general, but 3D segmentation always seems to produce output that looks like 3D reconstruction plus segmentation...

  2. Models such as RDFNet, Malleable 2.5D Convolution, and SA-Gate require the HHA format.
    Converting depth images to HHA took a long time, so I'm not sure this is practical.
    Is the HHA format still useful?

I would appreciate it if you could answer me!

Has anyone successfully reproduced the results? What mIoU did you get?

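I modified the code after reproducing it because I only have one GPU and can't use distributed training. I planned to run 800 epochs, but after 200 epochs the mIoU is still only 0.something. Has anyone managed to reach the results reported in the paper?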

Question about the code in train.py

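Hi, I'd like to ask whether train.py is missing some argument definitions.
The code keeps raising an error, and after searching online it seems that an argument is missing.
How should this argument be added? Thank you very much!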

Question about installing packages

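Hello, and thank you very much for your open-source contribution. While debugging your code I got the error No module named 'engine'. I then installed the pyttsx3 package, but I still get the same error. Which package do I need to install to fix this?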

Install apex error

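Hi! Thank you very much for your open-source work. I ran into an apex installation error while running your code; could you please take a look?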
/usr/local/lib/python3.6/dist-packages/torch/include/c10/util/Exception.h:354:3: error: expected ';' before 'do'
do {
^
csrc/scale_check_overflow.cpp:23:3: note: in expansion of macro 'AT_CHECK'
AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");
^
/usr/local/lib/python3.6/dist-packages/torch/include/c10/util/Exception.h:355:20: warning: 'void c10::detail::deprecated_AT_CHECK()' is deprecated [-Wdeprecated-declarations]
::c10::detail::deprecated_AT_CHECK();
^
csrc/scale_check_overflow.cpp:23:3: note: in expansion of macro 'AT_CHECK'
AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");
^
In file included from /usr/local/lib/python3.6/dist-packages/torch/include/c10/core/Device.h:5:0,
from /usr/local/lib/python3.6/dist-packages/torch/include/c10/core/Allocator.h:6,
from /usr/local/lib/python3.6/dist-packages/torch/include/ATen/ATen.h:3,
from /usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include/torch/types.h:3,
from /usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4,
from /usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3,
from /usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3,
from /usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3,
from /usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include/torch/data.h:3,
from /usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include/torch/all.h:4,
from /usr/local/lib/python3.6/dist-packages/torch/include/torch/extension.h:4,
from csrc/scale_check_overflow.cpp:1:
/usr/local/lib/python3.6/dist-packages/torch/include/c10/util/Exception.h:330:13: note: declared here
inline void deprecated_AT_CHECK() {}
^~~~~~~~~~~~~~~~~~~
/usr/local/lib/python3.6/dist-packages/torch/include/c10/util/Exception.h:355:40: warning: 'void c10::detail::deprecated_AT_CHECK()' is deprecated [-Wdeprecated-declarations]
::c10::detail::deprecated_AT_CHECK();
^
csrc/scale_check_overflow.cpp:23:3: note: in expansion of macro 'AT_CHECK'
AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");
^
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1
My build environment:
ubuntu 16.04
pytorch 1.4.0
cuda10.0
cudnn 7.6.5
gcc 7.5.0

Pretrained Resnet101 load failed

Hi, a problem is confusing me: the resnet101_v1c.pth I downloaded doesn't match the network's parameter sizes. The details are as follows. Could you please tell me how to load the pretrained model into the ResNet-101? Thanks
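To see exactly which entries disagree, the pretrained state dict can be compared key by key with the model's own `state_dict()`. The helper below is a hypothetical diagnostic, not part of the repo. With PyTorch it could be called as `split_state_dict(torch.load(path, map_location="cpu"), model.state_dict())`, and the matched subset loaded with `model.load_state_dict(matched, strict=False)`.

```python
def split_state_dict(pretrained, target):
    """Compare a pretrained state dict with a model's state dict.

    Both arguments map parameter names to array-like values with a `.shape`
    attribute (torch tensors and NumPy arrays both qualify). Returns
    (matched, mismatched, unexpected, missing).
    """
    matched, mismatched, unexpected = {}, [], []
    for name, value in pretrained.items():
        if name not in target:
            unexpected.append(name)  # present in the checkpoint only
        elif tuple(value.shape) != tuple(target[name].shape):
            mismatched.append((name, tuple(value.shape), tuple(target[name].shape)))
        else:
            matched[name] = value
    missing = [name for name in target if name not in pretrained]  # model only
    return matched, mismatched, unexpected, missing
```

Loading only the matched subset leaves the mismatched and missing layers at their random initialization, so the `mismatched` list (name, checkpoint shape, model shape) is worth inspecting before training.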

Question about multi-GPU training

If two GPUs on a machine are already running a job and I try to start another job on the remaining GPUs, I get this error:
Have you run into this?
