Locate-Globally-Segment-locally-A-Progressive-Architecture-With-Knowledge-Review-Network-for-SOD

!!!2021-7-3. We have corrected some errors. The pre-trained SGL-KRN model and PA-KRN model will be released soon...

!!!2021-8-12. The pre-trained SGL-KRN model and PA-KRN model have been released.

This repository is the official implementation of PA-KRN and SGL-KRN, which are proposed in "Locate Globally, Segment Locally: A Progressive Architecture With Knowledge Review Network for Salient Object Detection" (PDF).

Prerequisites

  • Python 3.6
  • PyTorch 1.0.0
  • torchvision
  • OpenCV
  • numpy
  • scipy
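
The prerequisites above can be checked before training with a small script. This is a minimal sketch and not part of the repository: the helper name `report_versions` is hypothetical, and the module names (`torch`, `torchvision`, `cv2`, `numpy`, `scipy`) are assumed to be the usual import names for the listed packages.

```python
import importlib


def report_versions(modules=("torch", "torchvision", "cv2", "numpy", "scipy")):
    """Return {module_name: version string, or None if the import fails}."""
    versions = {}
    for name in modules:
        try:
            module = importlib.import_module(name)
        except ImportError:
            # The package is not installed in this environment.
            versions[name] = None
            continue
        # Most packages expose __version__; fall back to "unknown" otherwise.
        versions[name] = getattr(module, "__version__", "unknown")
    return versions
```

Calling `report_versions()` and printing the result shows which packages are missing (value `None`) and which versions are installed, which can then be compared against the list above.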

Usage

1. Install the tools required by the body-attention sampler (MobulaOP)

# Clone the project
git clone https://github.com/wkcn/MobulaOP

# Enter the directory
cd MobulaOP

# Install MobulaOP
pip install -v -e .

2. Clone the repository

git clone https://github.com/bradleybin/Locate-Globally-Segment-locally-A-Progressive-Architecture-With-Knowledge-Review-Network-for-SOD

The directory structure is as follows:

├── demo
│   ├── attention_sampler
│   ├── data
│   ├── dataset
│   ├── networks
│   ├── results
│   ├── KRN.py
│   ├── KRN_edge.py
│   ├── main_clm.py
│   ├── main_fsm.py
│   ├── main_joint.py
│   ├── main_SGL_KRN.py
│   ├── Solver_clm.py
│   ├── Solver_fsm.py
│   └── Solver_joint.py
└── MobulaOP

3. Download datasets

Download DUTS and the other datasets and unzip them into the demo/data folder. (Refer to the PoolNet repository.)

The directory structure is as follows:

├─DUTS
│        └── DUTS-TR
│                  ├── DUTS-TR-Image
│                  ├── DUTS-TR-Mask
│                  └── DUTS-TR-Edge
├─DUTS-TE
│        ├── Imgs
│        └── test.lst
├─PASCALS
│        ├── Imgs
│        └── test.lst
├─DUTOMRON
│        ├── Imgs
│        └── test.lst
├─HKU-IS
│        ├── Imgs
│        └── test.lst
└─ECSSD
         ├── Imgs
         └── test.lst
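
Before training or testing, it may help to confirm that the unzipped data matches the layout above. The following is a minimal sketch under that assumption; `EXPECTED_PATHS` and `missing_dataset_paths` are hypothetical names, not part of the repository, and the paths are copied from the listing.

```python
import os

# Relative paths copied from the directory listing above.
EXPECTED_PATHS = [
    "DUTS/DUTS-TR/DUTS-TR-Image",
    "DUTS/DUTS-TR/DUTS-TR-Mask",
    "DUTS/DUTS-TR/DUTS-TR-Edge",
    "DUTS-TE/Imgs",
    "DUTS-TE/test.lst",
    "PASCALS/Imgs",
    "PASCALS/test.lst",
    "DUTOMRON/Imgs",
    "DUTOMRON/test.lst",
    "HKU-IS/Imgs",
    "HKU-IS/test.lst",
    "ECSSD/Imgs",
    "ECSSD/test.lst",
]


def missing_dataset_paths(data_root):
    """Return the expected entries that do not exist under data_root."""
    missing = []
    for relative in EXPECTED_PATHS:
        full_path = os.path.join(data_root, *relative.split("/"))
        if not os.path.exists(full_path):
            missing.append(relative)
    return missing
```

For example, `missing_dataset_paths("demo/data")` returns an empty list when every folder and `test.lst` file from the listing is in place.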

4. Download Pretrained ResNet-50 Model for backbone

Download the pretrained ResNet-50 model (Google Drive) and save it into the demo/dataset/pretrained folder.

5. Train

5.1 SGL-KRN

cd demo
python main_SGL_KRN.py

After training, the resulting model will be stored under the results/sgl_krn/run-* folder.
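
Because each training run writes to a new run-* folder, a small helper can locate the most recent one. This is a sketch only: `latest_run_dir` is a hypothetical name, and it assumes that the `*` in `run-*` is a non-negative integer that grows with each run, which the README does not state explicitly.

```python
import os
import re


def latest_run_dir(results_root):
    """Return the path of the run-<N> folder with the largest N, or None."""
    if not os.path.isdir(results_root):
        return None
    best_path, best_index = None, -1
    for name in os.listdir(results_root):
        # Assumption: run folders are named run-<integer>.
        match = re.fullmatch(r"run-(\d+)", name)
        path = os.path.join(results_root, name)
        if match and os.path.isdir(path) and int(match.group(1)) > best_index:
            best_path, best_index = path, int(match.group(1))
    return best_path
```

For example, `latest_run_dir("results/sgl_krn")` would return the newest run folder when called from inside demo.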

5.2 PA-KRN

The whole system can be trained in an end-to-end manner. To get finer results, we first train CLM and FSM sequentially and then combine them to fine-tune.

cd demo

  1. Train CLM.
python main_clm.py

After training, the resulting model will be stored under the results/clm/run-* folder.

  2. Train FSM.
python main_fsm.py  --clm_model path/to/pretrained/clm/folder/

After training, the resulting model will be stored under the results/fsm/run-* folder, where * changes accordingly. 'path/to/pretrained/clm/folder/' is the path to the pretrained CLM folder.

  3. Train PA-KRN.
python main_joint.py  --clm_model path/to/pretrained/clm/folder/  --fsm_model path/to/pretrained/fsm/folder/

After training, the resulting model will be stored under the results/joint/run-* folder. 'net_*.pth' holds the parameters of the CLM model and '.pth' holds the parameters of the FSM model.
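
The three PA-KRN training stages can be written down as a list of commands, which makes the dependency between them explicit: FSM training needs the CLM folder, and joint training needs both. This is a minimal sketch; `staged_training_commands` is a hypothetical helper, and only the script names and the `--clm_model` / `--fsm_model` flags come from the commands in this section.

```python
def staged_training_commands(clm_dir, fsm_dir, python="python"):
    """Build the argument lists for the CLM, FSM, and joint training stages."""
    return [
        # Stage 1: train CLM on its own.
        [python, "main_clm.py"],
        # Stage 2: train FSM, starting from the trained CLM folder.
        [python, "main_fsm.py", "--clm_model", clm_dir],
        # Stage 3: fine-tune CLM and FSM jointly.
        [python, "main_joint.py", "--clm_model", clm_dir, "--fsm_model", fsm_dir],
    ]
```

Each list can be passed to `subprocess.run(command, cwd="demo", check=True)`. Note that the CLM and FSM folders only exist after their stage has finished, so in practice the paths are filled in one stage at a time.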

6. Test

Download the pretrained SGL-KRN and PA-KRN models (Google Drive).

6.1 SGL-KRN

To test on the DUTS-TE dataset:

python main_SGL_KRN.py --mode test --test_model path/to/pretrained/SGL_KRN/folder/ --test_fold path/to/test/folder/ --sal_mode t

The 'sal_mode' values for ECSSD, PASCALS, DUT-OMRON, and HKU-IS are 'e', 'p', 'd', and 'h', respectively.
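
The `sal_mode` codes can be kept in one lookup table so that the same test command is built for every dataset. This is a sketch, not repository code: `SAL_MODES` and `sgl_krn_test_command` are hypothetical names, while the flags mirror the SGL-KRN test command above.

```python
# sal_mode code -> dataset name, as listed in this section.
SAL_MODES = {
    "t": "DUTS-TE",
    "e": "ECSSD",
    "p": "PASCALS",
    "d": "DUT-OMRON",
    "h": "HKU-IS",
}


def sgl_krn_test_command(sal_mode, test_model, test_fold):
    """Build the SGL-KRN test command for one dataset code."""
    if sal_mode not in SAL_MODES:
        raise ValueError("unknown sal_mode: %r" % (sal_mode,))
    return [
        "python", "main_SGL_KRN.py",
        "--mode", "test",
        "--test_model", test_model,
        "--test_fold", test_fold,
        "--sal_mode", sal_mode,
    ]
```

Looping over `SAL_MODES` and passing each command to `subprocess.run(command, cwd="demo", check=True)` would test all five datasets in turn, provided a separate output folder is chosen for each.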

6.2 PA-KRN

To test on the DUTS-TE dataset:

python main_joint.py --mode test --clm_model path/to/pretrained/clm/folder/  --fsm_model path/to/pretrained/fsm/folder/ --test_fold path/to/test/folder/ --sal_mode t

The 'sal_mode' values for ECSSD, PASCALS, DUT-OMRON, and HKU-IS are 'e', 'p', 'd', and 'h', respectively.

7. Saliency maps

We provide the pre-computed saliency maps from our paper: Google Drive | Baidu Disk (pwd: 9wxg).

Thanks to PoolNet repository and AttentionSampler repository.

Citing PAKRN

Please cite using the following BibTeX entry:

@inproceedings{xu2021locate,
  title={Locate globally, segment locally: A progressive architecture with knowledge review network for salient object detection},
  author={Xu, Binwei and Liang, Haoran and Liang, Ronghua and Chen, Peng},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={4},
  pages={3004--3012},
  year={2021}
}

locate-globally-segment-locally-a-progressive-architecture-with-knowledge-review-network-for-sod's Issues

RuntimeError: Given input size: (512x2x3). Calculated output size: (512x0x0). Output size is too small

I get an error when executing your test example:

python main_joint.py --mode test --clm_model ./model/joint/clm_final.pth --fsm_model ./model/joint/fsm_final.pth --test_fold ./testRes --sal_mode t

I am running Windows 10 with CUDA 10.1 and PyTorch 1.7.1.

Traceback (most recent call last):
  File "main_joint.py", line 282, in <module>
    main(config)
  File "main_joint.py", line 229, in main
    test.test()
  File "C:\Users\Seppi\Desktop\MobulaOP\Locate-Globally-Segment-locally-A-Progressive-Architecture-With-Knowledge-Review-Network-for-SOD\Solver_joint.py", line 84, in test
    feasum_out, merge_solid, out_merge_solid1, out_merge_solid2, out_merge_solid3, out_merge_solid4 = self.net(images)
  File "C:\Users\Seppi\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Seppi\Desktop\MobulaOP\Locate-Globally-Segment-locally-A-Progressive-Architecture-With-Knowledge-Review-Network-for-SOD\KRN.py", line 155, in forward
    merge_solid1 = self.DeepPool_solid1(conv2merge[0], conv2merge[1])
  File "C:\Users\Seppi\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Seppi\Desktop\MobulaOP\Locate-Globally-Segment-locally-A-Progressive-Architecture-With-Knowledge-Review-Network-for-SOD\KRN.py", line 54, in forward
    y = self.convs[i](self.pools[i](y))
  File "C:\Users\Seppi\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Seppi\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\pooling.py", line 594, in forward
    return F.avg_pool2d(input, self.kernel_size, self.stride,
RuntimeError: Given input size: (512x2x3). Calculated output size: (512x0x0). Output size is too small
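
The message in this traceback is what PyTorch raises when a pooling window is larger than the feature map it is applied to. A minimal sketch of the output-size arithmetic follows. The helper `pooled_size` is hypothetical, and the kernel sizes 2, 4, and 8 are assumptions for illustration, since the traceback does not show which kernel size `self.pools[i]` uses.

```python
def pooled_size(size, kernel, stride=None, padding=0):
    """Output length along one axis for pooling without ceil_mode."""
    stride = kernel if stride is None else stride
    # Standard formula: floor((size + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


# The failing feature map is 512 x 2 x 3 (channels x height x width).
height, width = 2, 3
for kernel in (2, 4, 8):  # assumed kernel sizes, for illustration only
    out_h = pooled_size(height, kernel)
    out_w = pooled_size(width, kernel)
    # A result below 1 on either axis triggers "Output size is too small".
    print(kernel, (out_h, out_w))
```

Under these assumptions, a 2x3 feature map survives a kernel of 2 but collapses to 0x0 with a kernel of 4 or larger. A feature map this small usually points to a very small input image, so checking the resolution of the test images may be a reasonable first step.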

AttributeError: module 'mobula.op' has no attribute 'AttSamplerGrid'

Hi! I followed your instructions to install this software. When I run the code, I get this error:

File "/home/xxx/xxx/attention_sampler/attsampler_th.py", line 19, in forward
    grid = mobula.op.AttSamplerGrid(data.detach(),
AttributeError: module 'mobula.op' has no attribute 'AttSamplerGrid'

Can you kindly give me some suggestions? Any comments will be highly appreciated!

ValueError: cannot reshape array of size 86400 into shape (360,355,1)

where:
label = np.reshape(label, (h, w, 1))
edge = np.reshape(edge, (h, w, 1))

question:
Traceback (most recent call last):
File "/home/PAKRN/main_SGL_KRN.py", line 556, in
main(config)
File "/home/PAKRN/main_SGL_KRN.py", line 498, in main
train.train()
File "/home/main_SGL_KRN.py", line 388, in train
for i, data_batch in enumerate(self.train_loader):
File "/home/utils/data/dataloader.py", line 521, in next
data = self._next_data()
File "/home/utils/data/dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "/home//utils/data/dataloader.py", line 1229, in _process_data
data.reraise()
File "/home//anaconda3//_utils.py", line 434, in reraise
raise exception
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home///utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/home//a/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home//a/_utils/fetch.py", line 49, in
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home//PAKRN/dataset/dataset_edge_augment.py", line 48, in getitem
sal_image, sal_label, sal_edge = generate_scale_label(sal_image, sal_label, sal_edge)
File "/home/PAKRN/dataset/dataset_edge_augment.py", line 151, in generate_scale_label
label = np.reshape(label, (h, w, 1))
File "<array_function internals>", line 5, in reshape
File "/home//fromnumeric.py", line 298, in reshape
return _wrapfunc(a, 'reshape', newshape, order=order)
File "/home//numpy/core/fromnumeric.py", line 57, in _wrapfunc
return bound(*args, **kwds)
ValueError: cannot reshape array of size 86400 into shape (360,355,1)
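
The numbers in this error can be checked by hand: the target shape (360, 355, 1) needs 127800 elements, while the array holds 86400, which is exactly 360 x 240. That suggests, without proving it, that the label or edge map does not have the same width as its image. A minimal sketch of the arithmetic, using the hypothetical helper `reshape_mismatch`:

```python
def reshape_mismatch(array_size, target_shape):
    """Return array_size minus the element count target_shape requires.

    A result of 0 means the reshape is possible; any other value
    means np.reshape would raise a ValueError like the one above.
    """
    expected = 1
    for dim in target_shape:
        expected *= dim
    return array_size - expected


# Values taken from the error message above.
print(reshape_mismatch(86400, (360, 355, 1)))
# One shape that does fit the array: 360 rows of 240 columns.
print(reshape_mismatch(86400, (360, 240, 1)))
```

If this is the cause, comparing the width and height of every image with its mask and edge file in the training set should reveal the mismatched pair.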
