
Comments (15)

JingyunLiang commented on May 5, 2024

This line is for training. We initialize the mask as self.attn_mask so that we don't need to calculate the mask for every iteration. Note that the training image size is fixed, e.g., 64x64.

if self.input_resolution == x_size:

This line is for testing. We calculate the mask for a given testing image. Note that the testing image is generally not 64x64.

JingyunLiang commented on May 5, 2024

See #9 for a discussion

jiaaihhy commented on May 5, 2024

But this still doesn't solve my problem. I still don't understand what the mask added when input_resolution != x_size is, or why this mask is needed.

if self.input_resolution == x_size:
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
else:
    attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

jiaaihhy commented on May 5, 2024

Thank you very much.

jiaaihhy commented on May 5, 2024

At test time, what is the meaning of the mask? Thank you.

JingyunLiang commented on May 5, 2024

Similar to training, we pad the image after shifting it. You can try not using padding during testing and share the results with me. Thank you.
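For context, a minimal sketch of that test-time padding, assuming it mirrors the reflect-padding to a multiple of the window size done in the released SwinIR code (pad_to_window_multiple is a hypothetical helper name used here for illustration):

import torch
import torch.nn.functional as F

def pad_to_window_multiple(x: torch.Tensor, window_size: int) -> torch.Tensor:
    # Reflect-pad an (N, C, H, W) tensor so H and W become multiples of
    # window_size; the padded borders are cropped from the output again
    # after the network runs.
    _, _, h, w = x.size()
    mod_pad_h = (window_size - h % window_size) % window_size
    mod_pad_w = (window_size - w % window_size) % window_size
    return F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")

# e.g. a 61x75 test image becomes 64x80 for window_size=8
x = torch.randn(1, 3, 61, 75)
print(pad_to_window_multiple(x, 8).shape)  # torch.Size([1, 3, 64, 80])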

paragon1234 commented on May 5, 2024

I had the same issue. Can you please elaborate on why we require attn_mask?
From the code, it seems that it is required only for those transformer blocks that operate on shifted windows.

if self.shift_size > 0:
    attn_mask = self.calculate_mask(self.input_resolution)
else:
    attn_mask = None

JingyunLiang commented on May 5, 2024

Yes, attn_mask is only used for those transformer blocks that operate on shifted windows. Imagine that for a 64x64 input, after shifting by 4 pixels towards the top-left corner (using torch.roll), pixels within [0:4,:] and [:,0:4] are moved to [60:64,:] and [:,60:64], respectively. In that case, pixels within [56:60,:] and [:,56:60] would be forced to attend to those unrelated pixels after the new window partition. This is not what we want, so we use attn_mask to mask them out.
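To make this concrete, below is a self-contained paraphrase of the calculate_mask logic from the Swin/SwinIR block, rewritten as free functions for illustration (the standalone signatures are mine; the repo's method takes only x_size and reads window_size/shift_size from self):

import torch

def window_partition(x, window_size):
    # (1, H, W, 1) -> (num_windows, window_size, window_size, 1)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def calculate_mask(x_size, window_size, shift_size):
    # Label each pixel with the index of the region it belonged to before the
    # cyclic shift; pixels rolled in from the opposite border get a different
    # label than their new neighbours inside the same window.
    H, W = x_size
    img_mask = torch.zeros((1, H, W, 1))
    h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
    # Position pairs with different labels get -100 (so their attention weight
    # vanishes after softmax); same-label pairs stay at 0 and attend normally.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

For example, calculate_mask((64, 64), window_size=8, shift_size=4) returns a tensor of shape (64, 64, 64): one 64x64 mask (8x8 positions squared) per window, where only the windows along the right and bottom borders contain non-zero entries.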

paragon1234 commented on May 5, 2024

Thank you for the response. This is an interesting point from an implementation perspective. I am wondering why the transformer should operate only on [56:60,:] and [:,56:60], but not on [0:4,:] and [:,0:4]. Either:

  1. What if we simply do not do self-attention on pixels within [56:64,:] and [:,56:64] for shifted windows? OR
  2. What if the transformer operates on both, without mixing them, i.e., an even smaller window size for this corner case?

In the current implementation, the top 4 rows and last 4 columns are attended to only 50% of the time.

JingyunLiang commented on May 5, 2024

I didn't test these two cases, but I guess the first case may lead to slightly worse performance (this part of the data is discarded), while the second one may lead to slightly better performance (making full use of this part of the data). The current implementation is just for simplicity and efficiency.

JingyunLiang commented on May 5, 2024

Feel free to reopen the issue if you have more questions.

paragon1234 commented on May 5, 2024

Can you please explain why we need the mask for testing, when the input resolution is not 48x48?
I tried to change the attention mechanism in the code; my method did not require a mask. I removed calculate_mask from the code and completed the training phase, but during testing it gave an error.
My concerns are: 1) Do I require a mask in my attention? 2) How do I change my attention mechanism to incorporate the mask?

JingyunLiang commented on May 5, 2024

SwinIR needs the mask whether the input resolution is 48x48 or not. The difference is that we use the pre-calculated mask for 48x48 images because we store it in the .pth file during training. For non-48x48 images, we need to recompute it.

For the second concern, what is your error in testing? If you don't need a mask in your own attention, there is no need for masking during testing either. Sorry that I cannot give more help, because I don't know what your changed attention mechanism is.
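If it helps, a quick sketch for checking that the pre-calculated masks are really stored in the checkpoint ("model.pth" is a placeholder path, and the assumption that the weights may be wrapped under a "params"/"params_ema" key depends on how the model was saved):

import torch

state = torch.load("model.pth", map_location="cpu")  # placeholder path
for key in ("params", "params_ema"):                 # unwrap if necessary
    if isinstance(state, dict) and key in state:
        state = state[key]
        break
mask_keys = [k for k in state if k.endswith("attn_mask")]
print(len(mask_keys))                    # one buffer per shifted-window block
if mask_keys:
    print(state[mask_keys[0]].shape)     # (num_windows, ws*ws, ws*ws) at the training resolution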

paragon1234 commented on May 5, 2024

I am using efficient attention (https://github.com/cmsflash/efficient-attention).
I do not understand why we require the mask, even for a 48x48 patch, as per your last reply.
Also, the training time (using efficient attention) is higher compared to the window attention of SwinIR. Maybe the 8x8 window_size is simply very efficient for low-level image processing.

JingyunLiang commented on May 5, 2024

It seems that the position encoding is not very important for SR, from my experience. You can try to remove it and compare the results. Note that there are two problems you need to address for efficient attention (https://github.com/cmsflash/efficient-attention): 1) Using softmax for q and k separately may reduce the representation power of the attention matrix significantly, since the rank of the resulting matrix is smaller. 2) It may be tricky to apply masks to it (see cmsflash/efficient-attention#4).
