pytorch-msssim's Introduction

pytorch-msssim

Differentiable Multi-Scale Structural Similarity (MS-SSIM) index

This small utility provides a differentiable MS-SSIM implementation for PyTorch, based on Po-Hsun Su's SSIM implementation at https://github.com/Po-Hsun-Su/pytorch-ssim. At the moment, only the product method for MS-SSIM is supported.

Installation

The master branch now only supports PyTorch 0.4 or higher. All development occurs in the dev branch (run git checkout dev after cloning the repository to get the latest development version).

To install the current version of pytorch_msssim:

  1. Clone this repo.
  2. Go to the repo directory.
  3. Run python setup.py install

or

  1. Clone this repo.
  2. Copy the pytorch_msssim folder into your project.

To install a version of pytorch_msssim that runs on PyTorch 0.3.1 or lower, use the checkpoint-0.3 tag. To do so, run the following commands after cloning the repository:

git fetch --all --tags
git checkout tags/checkpoint-0.3

Example

Basic usage

import pytorch_msssim
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
m = pytorch_msssim.MSSSIM()

img1 = torch.rand(1, 1, 256, 256).to(device)
img2 = torch.rand(1, 1, 256, 256).to(device)

print(pytorch_msssim.msssim(img1, img2))
print(m(img1, img2))

Training

For a detailed example of how to use MS-SSIM for optimization, take a look at the file max_ssim.py.

Stability and normalization

MS-SSIM is a particularly unstable metric for some architectures and may produce NaN values early in training. The msssim function provides a normalize parameter to help in these cases. There are three possible values; we recommend normalize="relu" when training.

  • None : no normalization is applied; this should be used for evaluation
  • "relu" : the ssim and mcs values of each level are rectified with a ReLU during the calculation, ensuring that negative values are zeroed
  • "simple" : the ssim result of each iteration is averaged with 1, giving an expected lower bound of 0.5 - this should ONLY be used for the initial iterations of training, or while the normalized score averages below 0.6

For backward compatibility, a value of True is equivalent to "simple" normalization.
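
For illustration, here is a minimal training sketch that uses msssim with normalize="relu" as a loss; model, optimizer and loader are hypothetical placeholders for your own network, optimizer and data loader:

import pytorch_msssim

for img_batch, target_batch in loader:
    optimizer.zero_grad()
    output = model(img_batch)
    # MS-SSIM is a similarity (1 is a perfect match), so minimize its negation;
    # "relu" normalization guards against NaN values early in training
    loss = -pytorch_msssim.msssim(output, target_batch, normalize="relu")
    loss.backward()
    optimizer.step()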

Reference

https://ece.uwaterloo.ca/~z70wang/research/ssim/

https://github.com/Po-Hsun-Su/pytorch-ssim

Thanks to z70wang for proposing MS-SSIM and providing the initial implementation, and to Po-Hsun-Su for the initial differentiable SSIM implementation for PyTorch.

pytorch-msssim's People

Contributors

gregjohnso, henrych4, ir1d, jorge-pessoa, maekawataiki, po-hsun-su, serkansulun


pytorch-msssim's Issues

Contrast Sensitivity (CS) for batch size > 1

Thanks for making this implementation.

I found that the current implementation may have a small bug, if I am not mistaken.
When passing batches of more than one image, the cs returned from the ssim function has size 1.
This may be the expected behaviour when the user wants the average MS-SSIM score over the batch (i.e. size_average=True), but it is not ideal when the user wants cs for each image.
As a result, images within the same batch get very similar MS-SSIM scores even with size_average=False, since they all share the same mcs.

Here is a minimal reproduction on Google Colab.
https://colab.research.google.com/drive/1tNWb0QTqn3clnKcMlFeA8QOGJZDDRJe3?usp=sharing

I would appreciate it if it were possible to specify whether to average cs, as is done for ret in the implementation here.

Proposed change:
Use the size_average flag to control the mean operation applied to cs.

i.e.

    cs = v1 / v2  # contrast sensitivity
    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        cs = cs.mean()
        ret = ssim_map.mean()
    else:
        cs = cs.mean(1).mean(1).mean(1)
        ret = ssim_map.mean(1).mean(1).mean(1)

Additional Reference:
TensorFlow implementation of MS-SSIM. They keep cs for each batch element (here they take the mean over width and height, and take the mean over each channel later).
https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/ops/image_ops_impl.py#L3581

Arbitrary NAN for very low MS-SSIM comparisons

Currently, the MS-SSIM calculation may return NaN when comparing two images with very low MS-SSIM scores, breaking the training process unless accounted for.

This is easy to reproduce and can be avoided, but the root cause should be found and fixed.

Bug in SSIM implementation, don't use this code for perceptual quality estimation

Hi
This code contains the same error as skimage; you can read the full description here: scikit-image/scikit-image#5192

In short: when SSIM is used to estimate perceptual quality, the authors of the original paper proposed downsampling the images first, so that SSIM focuses on the major differences between the reference and distorted inputs.

So what?
If you are using this implementation as a loss function for a CNN, you are likely leading it in the wrong direction.
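
For reference, the downsampling described above amounts to average-pooling the inputs by a size-dependent factor before computing SSIM. A simplified sketch assuming (N, C, H, W) tensors; downsample_for_ssim is a hypothetical helper, not part of this repo:

import torch.nn.functional as F

def downsample_for_ssim(img1, img2):
    # the reference Matlab code rescales images so the smaller side is ~256 px,
    # using the factor f = max(1, round(min(H, W) / 256))
    f = max(1, round(min(img1.shape[-2:]) / 256))
    if f > 1:
        img1 = F.avg_pool2d(img1, kernel_size=f)
        img2 = F.avg_pool2d(img2, kernel_size=f)
    return img1, img2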

Alternatives
You can find correct implementations of SSIM, MS-SSIM and some other metrics here:
https://github.com/photosynthesis-team/piq

nan problem

How to deal with the problem of nan loss during training?
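
Per the stability and normalization notes in the README above, one mitigation is the normalize="relu" option; a minimal sketch, where output and target are placeholders for your tensors:

import pytorch_msssim

# "relu" normalization zeroes the negative per-level values that can
# otherwise propagate NaN into the loss
loss = -pytorch_msssim.msssim(output, target, normalize="relu")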

MS-SSIM Usage

Hi @jorge-pessoa,

Thanks a lot for sharing the code. I have some doubts about the usage of your implementation.

This is the way I am using msssim in my code as a loss function.

loss_SSIM_A = pytorch_msssim.msssim(G_BA(real_A), real_A, normalize=True)
loss_SSIM_B = pytorch_msssim.msssim(G_AB(real_B), real_B, normalize=True)
loss_SSIM = (loss_SSIM_A + loss_SSIM_B) / 2

The result is between 0.85 and 0.97 until epoch 20, and then decreases to between 0.65 and 0.81. Is this fine, or do I need to define a threshold to maximize the value of MS-SSIM as you did in max_ssim.py?

  1. Am I using msssim correctly? In max_ssim.py you add a minus sign; should I do the same, and what is the reason for the minus sign?

    msssim_out = -loss_func(img1, img2)

Thanks in advance!

Artifacts on color images

Running the test script generates corrupted results, whereas pytorch-ssim exhibits no such issue.
Tested with:

import torchvision
from PIL import Image

totensor = torchvision.transforms.ToTensor()
img1 = totensor(Image.open('kate.png'))
img1 = img1.reshape([1] + list(img1.shape))
...
torchvision.utils.save_image(img2, 'res.jpg')

[Images: the original kate input and the corrupted res output]

SSIM on MNIST

Hello, I ran into a problem while running your code.
Using the MNIST dataset, the inputs to ssim are two tensors of shape (batch_size, 1, 28, 28).
It seems that the variable "win" by default only supports 3-channel tensors.
Is there any workaround?

Thank you.

P.S. The following is my error:

File "/home/*/anaconda3/lib/python3.7/site-packages/pytorch_msssim/ssim.py", line 37, in gaussian_filter
    out = F.conv2d(out, win.transpose(2, 3), stride=1, padding=0, groups=C)
RuntimeError: Given groups=1, weight of size [3, 1, 11, 1], expected input[128, 3, 28, 18] to have 1 channels, but got 3 channels instead

Inquiry about standard deviations

Based on the wiki and the paper, computing the contrast term requires $\sigma_x \sigma_y$.

In this line you are using $\sigma_{12}$, computed as sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2, instead of $\sigma_x \sigma_y$.

$\sigma_{12} = E(XY) - E(X)E(Y)$, while $\sigma_x \sigma_y = [E((X - E(X))^2) \, E((Y - E(Y))^2)]^{1/2}$.
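
For context, the single-formula form of SSIM used by most implementations folds the contrast and structure terms together, which is why $\sigma_{12}$ appears in place of $\sigma_x \sigma_y$:

$$\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{12} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}$$

Taking $C_3 = C_2 / 2$, the $\sigma_x \sigma_y$ factors cancel when the contrast and structure terms are multiplied, leaving the $(2\sigma_{12} + C_2)$ numerator above.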

Is the threshold necessary?

I have seen that you use while value < threshold: in your code - could you explain if it is necessary and why?

Thanks!

Expected object of scalar type Double but got scalar type Float for argument #2 'weight'

So I implemented this as follows:

import pytorch_msssim
[...]
lr_loss = pytorch_msssim.MSSSIM()
[...]
lr_tensor = torch.tensor(np.expand_dims(lr_img.astype(np.float32), axis=0)).type('torch.DoubleTensor').to(DEVICE)
in_tensor = torch.tensor(np.expand_dims(sr_img.astype(np.float32), axis=0)).type('torch.DoubleTensor').to(DEVICE)
[...]
ds_in_tensor = bds(in_tensor, nhwc=True)
lr_l = lr_loss(ds_in_tensor, lr_tensor)
l2_l = l2_loss(in_tensor, org_tensor)
l = lr_l + LAMBDA * l2_l
l.backward()

And I'm getting this error:

Traceback (most recent call last):
  File "/usr/xtmp/superresoluter/superresolution/tester_msssim.py", line 137, in <module>
    lr_l = lr_loss(ds_in_tensor, lr_tensor)
  File "/home/home5/abarnett/sr/lib/python3.5/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/project/xtmp/superresoluter/superresolution/pytorch_msssim/__init__.py", line 133, in forward
    return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
  File "/usr/project/xtmp/superresoluter/superresolution/pytorch_msssim/__init__.py", line 78, in msssim
    sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
  File "/usr/project/xtmp/superresoluter/superresolution/pytorch_msssim/__init__.py", line 41, in ssim
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'weight'

Any ideas?
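
One likely cause, judging from the traceback: the Gaussian window inside the loss is created as float32, so casting the inputs to double produces the Double/Float mismatch. A sketch of a workaround that keeps the inputs in the default float32, reusing lr_img, sr_img and DEVICE from the snippet above:

import numpy as np
import torch

# drop the .type('torch.DoubleTensor') cast so the input dtype (float32)
# matches the float32 Gaussian window created by the loss
lr_tensor = torch.tensor(np.expand_dims(lr_img.astype(np.float32), axis=0)).to(DEVICE)
in_tensor = torch.tensor(np.expand_dims(sr_img.astype(np.float32), axis=0)).to(DEVICE)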

SSIM result is different from skimage.measure.compare_ssim

Hi,
Thanks for this tool. I used both pytorch_msssim.ssim and skimage.measure.compare_ssim to compute SSIM, but the results are different. For example, an SSIM evaluation on an image sequence:
pytorch_msssim.ssim: [0.9655, 0.9500, 0.9324, 0.9229, 0.9191, 0.9154]
skimage.measure.compare_ssim: [0.97794482, 0.96226299, 0.948432, 0.9386946, 0.93113704, 0.92531453]

Why does this happen?
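
Part of the difference likely comes from differing defaults: this implementation uses an 11x11 Gaussian window with sigma 1.5, while skimage.measure.compare_ssim defaults to a uniform filter and sample covariance. A sketch that brings skimage closer to the Gaussian-window formulation (img_a and img_b are placeholder NumPy images):

from skimage.measure import compare_ssim

score = compare_ssim(
    img_a, img_b,
    gaussian_weights=True,       # 11x11 Gaussian window, as in Wang et al.
    sigma=1.5,
    use_sample_covariance=False,
    data_range=1.0,              # match the value range used on the PyTorch side
)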

Typo in calculation?

Thanks for making this implementation.

I found that the current implementation may have a small typo, if I am not mistaken.

The current implementation of the output calculation seems to have a bracket in the wrong place, changing the resulting MS-SSIM score.
https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py#L104

I think it should be
output = torch.prod(pow1[:-1]) * pow2[-1]
instead of
output = torch.prod(pow1[:-1] * pow2[-1])


Reference:

The original implementation of MS-SSIM in Matlab: https://ece.uwaterloo.ca/~z70wang/research/iwssim/
Their calculation takes the product over pow1[:-1] first and then multiplies by pow2[-1]:
overall_msssim = prod(mcs_array(1:level-1).^weight(1:level-1))*(msssim_array(level).^weight(level));

The original paper of MS-SSIM: https://www.cns.nyu.edu/pub/eero/wang03b.pdf
The calculation is pow2[-1] multiplied by the product over pow1:

$$\mathrm{MS\text{-}SSIM}(x, y) = [l_M(x, y)]^{\alpha_M} \cdot \prod_{j=1}^{M} [c_j(x, y)]^{\beta_j} \, [s_j(x, y)]^{\gamma_j}$$

Index map

Hi, I am wondering how I can compute the residual map for MS-SSIM?

Bug in your MSSSIM implementation

From the TensorFlow implementation (line 182), I saw that they first apply an average-pooling filter and then downsample by slicing with step 2:

for _ in range(levels):
    ssim, cs = _SSIMForMultiScale(...)
    ...
    filtered = [convolve(im, downsample_filter, mode='reflect')
                for im in [im1, im2]]
    im1, im2 = [x[:, ::2, ::2, :] for x in filtered]

But in your implementation, you only apply a 2x2 AvgPool, so the images are not downsampled by 2 as in the paper.

for _ in range(levels):
    sim, cs = ssim(...)
    ...

    img1 = F.avg_pool2d(img1, (2, 2))
    img2 = F.avg_pool2d(img2, (2, 2))
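
One caveat about the claim above: F.avg_pool2d defaults its stride to the kernel size, so a (2, 2) pool does halve the spatial resolution. The remaining difference from the TensorFlow code is boundary handling (a reflect-padded filter followed by strided slicing, versus unpadded pooling). A quick shape check:

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 256, 256)
y = F.avg_pool2d(x, (2, 2))  # stride defaults to the kernel size
print(y.shape)               # torch.Size([1, 3, 128, 128])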

Output is out of range

Sorry to bother you, but I used SSIM as a loss function and found some strange outputs. The gray values of the inputs are normalized to [0, 1], but the output can fall outside this range (e.g. values around -100), which is meaningless. In other words, the network is optimized in the wrong direction. I do not know the reason; could you please give me some suggestions? For what it's worth, the network works with other loss functions.
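
If the network's final layer is unbounded, one common remedy (a generic sketch, not specific to this repo; net_output is a placeholder) is to constrain predictions to the expected range before computing the loss:

import torch

# keep predictions inside the [0, 1] range the inputs are normalized to;
# a sigmoid output layer is an alternative to a hard clamp
pred = torch.clamp(net_output, 0.0, 1.0)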
