Giter Club home page Giter Club logo

ccnet-pure-pytorch's Issues

将Pure-Pytorch的RCCA模块应用到视频任务中loss没有完全收敛

您好,我打算将您写的pytorch版本的RCCA模块应用到视频的不同帧之间,以获得帧与帧之间的注意力进而增强视频帧的特征表示。主要问题是loss没有完全收敛,维持在1-2中间。我想排除一下是不是我网络改的有问题,需要您的帮助!!!

主要任务是视频的显著性检测,取同一视频中任意两帧经过同一ResNet-101,获得 B x 256 x 47 x 47的特征,然后再输入到RCCA模块,先得到 Q_X , K_X , V_X , Q_Y, K_Y, V_Y,即得到两帧映射到Q,K,V空间的特征。然后再用 Q_X 和 K_Y 做相关性矩阵,作用到V_Y,然后是Q_Y 和 K_X 做相关性,作用到 V_X。 代码的实现如下,几乎没怎么改动,希望您能帮我看一眼,感谢!

class RCCAModule(nn.Module):
    """Recurrent criss-cross attention applied jointly to two feature maps.

    The two inputs are passed ``recurrence`` times through one shared
    ``CrissCrossAttention`` module, refined by a per-branch 3x3 conv + BN,
    concatenated with the corresponding original input along the channel
    axis and reduced to ``out_channels`` by a per-branch bottleneck.

    Args:
        in_channels: Channel count of both inputs.
        out_channels: Channel count of both outputs (default 256).
    """

    def __init__(self, in_channels, out_channels=256):
        super(RCCAModule, self).__init__()

        # A single attention module is reused for every recurrence step.
        self.cca = CrissCrossAttention(in_channels)

        self.convbX = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=True),
            nn.BatchNorm2d(in_channels),
        )
        self.convbY = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=True),
            nn.BatchNorm2d(in_channels),
        )

        # The bottlenecks take 2 * in_channels because forward() concatenates
        # the original input with the attended features.
        # NOTE: the author left nn.Dropout2d(0.1) disabled in both
        # bottlenecks, asking whether dropout would help here.
        self.bottleneckX = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, kernel_size=3,
                      padding=1, dilation=1, bias=True),
            nn.BatchNorm2d(out_channels),
        )
        self.bottleneckY = nn.Sequential(
            nn.Conv2d(in_channels * 2, out_channels, kernel_size=3,
                      padding=1, dilation=1, bias=True),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x, y, recurrence=2):
        """Return the pair of attention-enhanced feature maps for ``x`` and ``y``.

        Args:
            x: First feature map; assumed (B, in_channels, H, W) - confirm
                against the caller.
            y: Second feature map, assumed to have the same shape as ``x``.
            recurrence: Number of times the shared attention module is applied.

        Returns:
            Tuple ``(outputX, outputY)``, each with ``out_channels`` channels.
        """
        outputX, outputY = x, y
        for _ in range(recurrence):
            outputX, outputY = self.cca(outputX, outputY)

        outputX = self.convbX(outputX)
        outputY = self.convbY(outputY)

        # Skip connection: fuse each original input with its attended features.
        outputX = self.bottleneckX(torch.cat([x, outputX], 1))
        outputY = self.bottleneckY(torch.cat([y, outputY], 1))

        return outputX, outputY

class CrissCrossAttention(nn.Module):
    """Criss-cross attention computed between two feature maps.

    Queries of one input are matched against keys of the other input along
    the column (H) and row (W) of each position; the resulting attention is
    then used to aggregate value features.

    Args:
        in_dim: Channel count of both inputs.
    """

    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        # 1x1 convolutions producing Q, K and V. Q and K halve the channel
        # count; V keeps it unchanged.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        # Learnable residual scales initialised to zero: per the author's
        # note, the module then starts out passing the (ImageNet-pretrained)
        # input features through unchanged and learns the attention
        # contribution gradually, which smooths training.
        self.gamma1 = nn.Parameter(torch.zeros(1))
        self.gamma2 = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _by_column(feat, batch, height, width):
        """Reshape (B, C, H, W) to (B*W, C, H): one sequence per column."""
        return feat.permute(0, 3, 1, 2).contiguous().view(batch * width, -1, height)

    @staticmethod
    def _by_row(feat, batch, height, width):
        """Reshape (B, C, H, W) to (B*H, C, W): one sequence per row."""
        return feat.permute(0, 2, 1, 3).contiguous().view(batch * height, -1, width)

    def _attention(self, query, key, batch, height, width):
        """Return column/row attention of ``query`` positions over ``key``."""
        query_h = self._by_column(query, batch, height, width).permute(0, 2, 1)  # BW, H, C
        query_w = self._by_row(query, batch, height, width).permute(0, 2, 1)     # BH, W, C
        key_h = self._by_column(key, batch, height, width)                       # BW, C, H
        key_w = self._by_row(key, batch, height, width)                          # BH, C, W

        # self.INF is defined outside this class; its result is added to the
        # column energies only (presumably a diagonal mask - confirm there).
        energy_h = (
            torch.bmm(query_h, key_h) + self.INF(batch, height, width)
        ).view(batch, width, height, height).permute(0, 2, 1, 3)                 # B, H, W, H
        energy_w = torch.bmm(query_w, key_w).view(batch, height, width, width)   # B, H, W, W

        # Joint softmax over the H column entries and W row entries.
        attention = self.softmax(torch.cat([energy_h, energy_w], 3))             # B, H, W, H+W

        att_h = attention[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(
            batch * width, height, height)                                       # BW, H, H
        att_w = attention[:, :, :, height:height + width].contiguous().view(
            batch * height, width, width)                                        # BH, W, W
        return att_h, att_w

    def _aggregate(self, value, att_h, att_w, batch, height, width):
        """Apply column/row attention to ``value``; returns a (B, C, H, W) map."""
        value_h = self._by_column(value, batch, height, width)                   # BW, C, H
        value_w = self._by_row(value, batch, height, width)                      # BH, C, W
        # The attention maps are transposed so that the matrix product sums
        # over the attended (key) positions.
        out_h = torch.bmm(value_h, att_h.permute(0, 2, 1)).view(
            batch, width, -1, height).permute(0, 2, 3, 1)
        out_w = torch.bmm(value_w, att_w.permute(0, 2, 1)).view(
            batch, height, -1, width).permute(0, 2, 1, 3)
        return out_h + out_w

    def forward(self, x, y):
        """Return ``(x', y')``, the two inputs with scaled attention residuals.

        Args:
            x: First feature map, unpacked as (B, C, H, W).
            y: Second feature map; reshaped using the sizes read from ``x``,
                so it is assumed to share ``x``'s shape.
        """
        batch, _, height, width = x.size()

        query_x = self.query_conv(x)
        key_x = self.key_conv(x)
        value_x = self.value_conv(x)

        query_y = self.query_conv(y)
        key_y = self.key_conv(y)
        value_y = self.value_conv(y)

        # Cross-frame attention: queries of one input against keys of the other.
        att_x_h, att_x_w = self._attention(query_x, key_y, batch, height, width)
        att_y_h, att_y_w = self._attention(query_y, key_x, batch, height, width)

        # NOTE(review): the attention whose queries come from x aggregates
        # y's values and the result is added to y (and symmetrically for x).
        # Confirm this pairing is the intended one.
        out_y = self._aggregate(value_y, att_x_h, att_x_w, batch, height, width)
        out_x = self._aggregate(value_x, att_y_h, att_y_w, batch, height, width)

        return (self.gamma1 * out_x + x), (self.gamma2 * out_y + y)
另外,这部分的初始化方式是:卷积权重使用 kaiming 初始化,偏置为 0;BN 层权重初始化为 1,偏置为 0。
@Serge-weihao

why self.gamma =0?

class CC_module(nn.Module):
    """Criss-cross attention over a single feature map.

    Each position attends to the positions in its own column and row; the
    aggregated values are scaled by a learnable ``gamma`` (initialised to
    zero) and added back to the input.

    Args:
        in_dim: Channel count of the input.
    """

    def __init__(self, in_dim):
        super(CC_module, self).__init__()
        # 1x1 convolutions: Q and K reduce channels by 8x, V keeps them.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``.

        Args:
            x: Feature map, unpacked as (B, C, H, W).
        """
        m_batchsize, _, height, width = x.size()

        # "_H" tensors hold one sequence per column (B*W sequences of length
        # H); "_W" tensors hold one sequence per row (B*H of length W).
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(
            m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(
            m_batchsize * height, -1, width).permute(0, 2, 1)

        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(
            m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(
            m_batchsize * height, -1, width)

        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(
            m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(
            m_batchsize * height, -1, width)

        # self.INF is defined outside this class; its result is added to the
        # column energies only (presumably a diagonal mask - confirm there).
        energy_H = (
            torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)
        ).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(
            m_batchsize, height, width, width)

        # Joint softmax over the H column entries and W row entries.
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(
            m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(
            m_batchsize * height, width, width)

        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(
            m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(
            m_batchsize, height, -1, width).permute(0, 2, 1, 3)

        return self.gamma * (out_H + out_W) + x

I am confused about why self.gamma is initialized to torch.zeros(1).

About 3D CC Attention module

can anyone explain what is different between CrissCrossAttention3D, CrissCrossAttention3D1, CrissCrossAttention3D2 ?? Thanks

关于网络最后的return[x, x_dsn]

您好,感谢您的代码实现。有些疑问想请教您一下:

  1. 网络最后的return[x, x_dsn]是什么意思,返回一个list?在论文中,最后不该是一个cat吗?
  2. 网络最后的输出,如果不经过上采样的话,如何恢复原图大小?

precision

Have you compared the mIoU between your cc.py and the original code on Cityscapes?
Because I have tested your cc.py, and .....

Can you offer the scores you got by the net?

Thanks for your insightful work. The original repository confused me for a long time to reproduce the scores. As described in the title, can you offer the Miou and other metrics produced by the net? Thank you so much.

inplace_abn usage

Hi,
thank you a lot for your implementation. I'd like to clarify one moment.

In README file you state that you decided not to use Cuda inplace-abn. However, there're number of imports from inplace_abn (e.g. in networks/ccnet.py) which are used further. What should I do with them?

Thank you in advance!

Apex error

It is stated that we need neither the CUDA extension nor apex, and that the code can run with only Python 3 and PyTorch.
But, when I began used "python train.py --data-dir /data/datasets/Cityscapes/ --random-mirror --random-scale --restore-from ./dataset/resnet101-imagenet.pth --gpu 0,1,2,3 --learning-rate 0.01 --input-size 769,769 --weight-decay 0.0001 --batch-size 4 --num-steps 60000 --recurrence 2 --ohem 1 --ohem-thres 0.7 --ohem-keep 100000 --model ccnet
"

It gave this error stating apex is needed.
Traceback (most recent call last):
File "/home/shrutisth/PycharmProjects/CCNet-Pure-Pytorch/engine.py", line 19, in
from apex.parallel import DistributedDataParallel, SyncBatchNorm
ModuleNotFoundError: No module named 'apex'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "train.py", line 26, in
from engine import Engine
File "/home/shrutisth/PycharmProjects/CCNet-Pure-Pytorch/engine.py", line 22, in
"Please install apex from https://www.github.com/nvidia/apex .")
ImportError: Please install apex from https://www.github.com/nvidia/apex .

How can we use this repo without CUDA extension and apex?

关于类别一致损失CCL

您好,感谢您的代码!
大致浏览了下,没有发现CCL的代码,可以指出其位置所在吗?
感谢!

about_CC.py

Hello, I ran into this problem. Could you please take a look?
1 x = torch.randn(2,64,5,6)
2
----> 3 y = model(x)

1 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
475 result = self._slow_forward(*input, **kwargs)
476 else:
--> 477 result = self.forward(*input, **kwargs)
478 for hook in self._forward_hooks.values():
479 hook_result = hook(self, input, result)

in forward(self, x)
24 proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsizewidth,-1,height)
25 proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize
height,-1,width)
---> 26 energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
27 energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
28 concate = self.softmax(torch.cat([energy_H, energy_W], 3))

RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #3 'other'

Not an issue - question on maths / paper / relation to this line

def INF(B,H,W):

def INF(B, H, W, device=None):
    """Return a (B*W, H, H) tensor with -inf on each diagonal and zeros elsewhere.

    Args:
        B: Batch size.
        H: Side length of each square matrix.
        W: Number of matrices per batch element.
        device: Device to allocate on. Defaults to CUDA when it is available
            (the previous hard-coded behaviour) and to CPU otherwise, so the
            function no longer crashes on CPU-only machines.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    diagonal = torch.tensor(float("inf"), device=device).repeat(H)
    return -torch.diag(diagonal, 0).unsqueeze(0).repeat(B * W, 1, 1)

is this related to the white squares that are not criss crossed?
Screen Shot 2021-02-07 at 10 06 36 pm

does the blue dot in question - presumably - it can't go any further left or right? how does algorithm handle this 'edge' case?
Are these the 'residual connections'? how does the code handle this?

Screen Shot 2021-02-07 at 10 07 58 pm

Were there efforts to change the length of the cross?

If I had to comment the code

// Dense Attention Map - green parts
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)

what is 0,3,1,2 related to ?

The paper mentions a 3D criss-cross implementation - with a T / temporal parameter introduced - does this exist in this code?
Screen Shot 2021-02-07 at 10 48 15 pm

Where is H prime? Is that connected to the energy?

Sorry - all these noob questions - thanks for any help you can shed light on.

Num classes

I saw there are 21 classes including a background in the pyt_utils. why did you use num classes=19? any specific reason?

Calculating loss

In the ccnet code. At the end of resnet there is out which gives an output of 33X33 sized image. But it is then passed into the self.criterion(outs, labels). But the labels are of original size i.e. 769*769. So, how can it be computed? I am getting the "CUDNN_STATUS_INTERNAL_ERROR" error. I have attached the same in the snapshot below.

ccnet

关于评估代码与原始代码有些不同,请大佬帮忙解释一下疑惑?

在原始代码的评估.py代码中此处:
confusion_matrix = torch.from_numpy(confusion_matrix).contiguous().cuda()
confusion_matrix = engine.all_reduce_tensor(confusion_matrix, norm=False).cpu().numpy()
pos = confusion_matrix.sum(1)
res = confusion_matrix.sum(0)
tp = np.diag(confusion_matrix)
而你的:
#confusion_matrix = torch.from_numpy(confusion_matrix).contiguous().cuda()
#confusion_matrix = engine.all_reduce_tensor(confusion_matrix, norm=False).cpu().numpy()
被注释掉了,能不能解释一下啊

can this work with cuda toolkit 11.2 + for nvidia 3090 ?

(torch) ➜ CCNet git:(master) ✗ ./run_local.sh
Linux pop-os 5.8.0-7630-generic #32160919370720.10~781bb80-Ubuntu SMP Tue Jan 5 21:29:56 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
Fri 05 Feb 2021 21:04:13 AEDT
Traceback (most recent call last):
File "train.py", line 16, in
from networks.ccnet import Res_Deeplab
File "/home/jp/Documents/gitWorkspace/CCNet/networks/ccnet.py", line 15, in
from libs import InPlaceABN, InPlaceABNSync
File "/home/jp/Documents/gitWorkspace/CCNet/libs/init.py", line 1, in
from .bn import ABN, InPlaceABN, InPlaceABNWrapper, InPlaceABNSync, InPlaceABNSyncWrapper
File "/home/jp/Documents/gitWorkspace/CCNet/libs/bn.py", line 15, in
from .functions import inplace_abn, inplace_abn_sync
File "/home/jp/Documents/gitWorkspace/CCNet/libs/functions.py", line 5, in
from . import _ext
File "/home/jp/Documents/gitWorkspace/CCNet/libs/_ext/init.py", line 2, in
from torch.utils.ffi import _wrap_function
File "/home/jp/miniconda3/envs/torch/lib/python3.8/site-packages/torch/utils/ffi/init.py", line 1, in
raise ImportError("torch.utils.ffi is deprecated. Please use cpp extensions instead.")
ImportError: torch.utils.ffi is deprecated. Please use cpp extensions instead.
Traceback (most recent call last):
File "evaluate.py", line 14, in
from networks.ccnet import Res_Deeplab
File "/home/jp/Documents/gitWorkspace/CCNet/networks/ccnet.py", line 15, in
from libs import InPlaceABN, InPlaceABNSync
File "/home/jp/Documents/gitWorkspace/CCNet/libs/init.py", line 1, in
from .bn import ABN, InPlaceABN, InPlaceABNWrapper, InPlaceABNSync, InPlaceABNSyncWrapper
File "/home/jp/Documents/gitWorkspace/CCNet/libs/bn.py", line 15, in
from .functions import inplace_abn, inplace_abn_sync
File "/home/jp/Documents/gitWorkspace/CCNet/libs/functions.py", line 5, in
from . import _ext
File "/home/jp/Documents/gitWorkspace/CCNet/libs/_ext/init.py", line 2, in
from torch.utils.ffi import _wrap_function
File "/home/jp/miniconda3/envs/torch/lib/python3.8/site-packages/torch/utils/ffi/init.py", line 1, in
raise ImportError("torch.utils.ffi is deprecated. Please use cpp extensions instead.")
ImportError: torch.utils.ffi is deprecated. Please use cpp extensions instead.

rcca compile error

File "/home/zhangli/anaconda3/envs/py36/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1538, in _run_ninja_build
raise RuntimeError(message) from e
RuntimeError: Error building extension 'rcca'

About 3d version

Hi,

I am concerned about CrissCrossAttention3D, CrissCrossAttention3D1 and CrissCrossAttention3D2 in cc3d.py. What is the difference between them?

Specifically, I am not sure what INF3D will affect them?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.